Skip to content

Commit 212fa58

Browse files
authored
fixes #9403 -- re-allow copying public keys (#9433)
they are immutable, so this is trivial
1 parent 986c4b5 commit 212fa58

File tree

15 files changed

+143
-1
lines changed

15 files changed

+143
-1
lines changed

src/rust/src/backend/dh.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -352,6 +352,10 @@ impl DHPublicKey {
352352
_ => Err(pyo3::exceptions::PyTypeError::new_err("Cannot be ordered")),
353353
}
354354
}
355+
356+
fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
357+
slf
358+
}
355359
}
356360

357361
#[pyo3::prelude::pymethods]

src/rust/src/backend/dsa.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -295,6 +295,10 @@ impl DsaPublicKey {
295295
_ => Err(pyo3::exceptions::PyTypeError::new_err("Cannot be ordered")),
296296
}
297297
}
298+
299+
fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
300+
slf
301+
}
298302
}
299303

300304
#[pyo3::prelude::pymethods]

src/rust/src/backend/ec.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -558,7 +558,12 @@ impl ECPublicKey {
558558
_ => Err(pyo3::exceptions::PyTypeError::new_err("Cannot be ordered")),
559559
}
560560
}
561+
562+
fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
563+
slf
564+
}
561565
}
566+
562567
pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {
563568
let m = pyo3::prelude::PyModule::new(py, "ec")?;
564569
m.add_function(pyo3::wrap_pyfunction!(curve_supported, m)?)?;

src/rust/src/backend/ed25519.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,10 @@ impl Ed25519PublicKey {
160160
_ => Err(pyo3::exceptions::PyTypeError::new_err("Cannot be ordered")),
161161
}
162162
}
163+
164+
fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
165+
slf
166+
}
163167
}
164168

165169
pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {

src/rust/src/backend/ed448.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,10 @@ impl Ed448PublicKey {
158158
_ => Err(pyo3::exceptions::PyTypeError::new_err("Cannot be ordered")),
159159
}
160160
}
161+
162+
fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
163+
slf
164+
}
161165
}
162166

163167
pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {

src/rust/src/backend/x25519.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -149,6 +149,10 @@ impl X25519PublicKey {
149149
_ => Err(pyo3::exceptions::PyTypeError::new_err("Cannot be ordered")),
150150
}
151151
}
152+
153+
fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
154+
slf
155+
}
152156
}
153157

154158
pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {

src/rust/src/backend/x448.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,10 @@ impl X448PublicKey {
148148
_ => Err(pyo3::exceptions::PyTypeError::new_err("Cannot be ordered")),
149149
}
150150
}
151+
152+
fn __copy__(slf: pyo3::PyRef<'_, Self>) -> pyo3::PyRef<'_, Self> {
153+
slf
154+
}
151155
}
152156

153157
pub(crate) fn create_module(py: pyo3::Python<'_>) -> pyo3::PyResult<&pyo3::prelude::PyModule> {

tests/hazmat/primitives/test_dh.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55

66
import binascii
7+
import copy
78
import itertools
89
import os
910
import typing
@@ -489,6 +490,21 @@ def test_public_key_equality(self, backend):
489490
with pytest.raises(TypeError):
490491
key1 < key2 # type: ignore[operator]
491492

493+
@pytest.mark.supported(
494+
only_if=lambda backend: backend.dh_x942_serialization_supported(),
495+
skip_message="DH X9.42 not supported",
496+
)
497+
def test_public_key_copy(self):
498+
key_bytes = load_vectors_from_file(
499+
os.path.join("asymmetric", "DH", "dhpub.pem"),
500+
lambda pemfile: pemfile.read(),
501+
mode="rb",
502+
)
503+
key1 = serialization.load_pem_public_key(key_bytes)
504+
key2 = copy.copy(key1)
505+
506+
assert key1 == key2
507+
492508

493509
@pytest.mark.supported(
494510
only_if=lambda backend: backend.dh_supported(),

tests/hazmat/primitives/test_dsa.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
# for complete details.
44

55

6+
import copy
67
import itertools
78
import os
89
import typing
@@ -398,6 +399,16 @@ def test_public_key_equality(self, backend):
398399
with pytest.raises(TypeError):
399400
key1 < key2 # type: ignore[operator]
400401

402+
def test_public_key_copy(self):
403+
key_bytes = load_vectors_from_file(
404+
os.path.join("asymmetric", "PKCS8", "unenc-dsa-pkcs8.pem"),
405+
lambda pemfile: pemfile.read().encode(),
406+
)
407+
key1 = serialization.load_pem_private_key(key_bytes, None).public_key()
408+
key2 = copy.copy(key1)
409+
410+
assert key1 == key2
411+
401412

402413
@pytest.mark.supported(
403414
only_if=lambda backend: backend.dsa_supported(),

tests/hazmat/primitives/test_ec.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,8 @@
22
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
33
# for complete details.
44

5-
65
import binascii
6+
import copy
77
import itertools
88
import os
99
import textwrap
@@ -617,6 +617,17 @@ def test_public_key_equality(self, backend):
617617
with pytest.raises(TypeError):
618618
key1 < key2 # type: ignore[operator]
619619

620+
def test_public_key_copy(self, backend):
621+
_skip_curve_unsupported(backend, ec.SECP256R1())
622+
key_bytes = load_vectors_from_file(
623+
os.path.join("asymmetric", "PKCS8", "ec_private_key.pem"),
624+
lambda pemfile: pemfile.read().encode(),
625+
)
626+
key1 = serialization.load_pem_private_key(key_bytes, None).public_key()
627+
key2 = copy.copy(key1)
628+
629+
assert key1 == key2
630+
620631

621632
class TestECSerialization:
622633
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)