Skip to content

Commit d49947e

Browse files
authored
handle case where a "valid" pkey does not contain a valid EC key (#12101)
* handle case where a "valid" pkey does not contain a valid EC key * add test * skip the test in some scenarios
1 parent 235f991 commit d49947e

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

src/rust/src/backend/ec.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,11 @@ pub(crate) fn private_key_from_pkey(
135135
py: pyo3::Python<'_>,
136136
pkey: &openssl::pkey::PKeyRef<openssl::pkey::Private>,
137137
) -> CryptographyResult<ECPrivateKey> {
138-
let curve = py_curve_from_curve(py, pkey.ec_key().unwrap().group())?;
139-
check_key_infinity(&pkey.ec_key().unwrap())?;
138+
let ec_key = pkey
139+
.ec_key()
140+
.map_err(|_| pyo3::exceptions::PyValueError::new_err("Invalid EC key"))?;
141+
let curve = py_curve_from_curve(py, ec_key.group())?;
142+
check_key_infinity(&ec_key)?;
140143
Ok(ECPrivateKey {
141144
pkey: pkey.to_owned(),
142145
curve: curve.into(),

tests/hazmat/primitives/test_ec.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -466,6 +466,25 @@ def test_load_invalid_ec_key_from_pem(self, backend):
466466
backend=backend,
467467
)
468468

469+
@pytest.mark.supported(
470+
only_if=(
471+
lambda backend: rust_openssl.CRYPTOGRAPHY_OPENSSL_300_OR_GREATER
472+
or rust_openssl.CRYPTOGRAPHY_IS_BORINGSSL
473+
),
474+
skip_message="LibreSSL and OpenSSL 1.1.1 handle this differently",
475+
)
476+
def test_load_invalid_private_scalar_pem(self, backend):
477+
_skip_curve_unsupported(backend, ec.SECP256R1())
478+
479+
data = load_vectors_from_file(
480+
os.path.join(
481+
"asymmetric", "PKCS8", "ec-invalid-private-scalar.pem"
482+
),
483+
lambda pemfile: pemfile.read().encode(),
484+
)
485+
with pytest.raises(ValueError):
486+
serialization.load_pem_private_key(data, None)
487+
469488
def test_signatures(self, backend, subtests):
470489
vectors = itertools.chain(
471490
load_vectors_from_file(

0 commit comments

Comments
 (0)