diff --git a/sm2/Cargo.toml b/sm2/Cargo.toml index ffe2e9f15..8159c46c2 100644 --- a/sm2/Cargo.toml +++ b/sm2/Cargo.toml @@ -35,10 +35,11 @@ proptest = "1" rand_core = { version = "0.9", features = ["os_rng"] } [features] -default = ["arithmetic", "dsa", "pke", "pem", "std"] +default = ["arithmetic", "dsa", "pke", "pem", "std", "getrandom"] alloc = ["elliptic-curve/alloc"] std = ["alloc", "elliptic-curve/std"] +getrandom = ["rand_core/os_rng"] arithmetic = ["dep:primefield", "dep:primeorder", "elliptic-curve/arithmetic"] bits = ["arithmetic", "elliptic-curve/bits"] dsa = ["arithmetic", "dep:rfc6979", "dep:signature", "dep:sm3"] diff --git a/sm2/src/pke.rs b/sm2/src/pke.rs index 80988daae..78d43fd9f 100644 --- a/sm2/src/pke.rs +++ b/sm2/src/pke.rs @@ -11,27 +11,32 @@ //! # fn example() -> Result<(), Box> { //! use rand_core::OsRng; // requires 'os_rng` feature //! use sm2::{ -//! pke::{EncryptingKey, Mode}, -//! {SecretKey, PublicKey} +//! pke::{EncryptingKey, Mode, Cipher}, +//! {SecretKey, PublicKey}, +//! pkcs8::der::{Encode, Decode} //! }; //! //! // Encrypting //! let secret_key = SecretKey::try_from_rng(&mut OsRng).unwrap(); // serialize with `::to_bytes()` //! let public_key = secret_key.public_key(); -//! let encrypting_key = EncryptingKey::new_with_mode(public_key, Mode::C1C2C3); +//! let encrypting_key = EncryptingKey::new(public_key); //! let plaintext = b"plaintext"; -//! let ciphertext = encrypting_key.encrypt(&mut OsRng, plaintext)?; +//! let cipher = encrypting_key.encrypt_rng(&mut OsRng, plaintext)?; +//! let ciphertext = cipher.to_vec(Mode::C1C2C3); //! //! use sm2::pke::DecryptingKey; //! // Decrypting -//! let decrypting_key = DecryptingKey::new_with_mode(secret_key.to_nonzero_scalar(), Mode::C1C2C3); -//! assert_eq!(decrypting_key.decrypt(&ciphertext)?, plaintext); +//! let cipher = Cipher::from_slice(&ciphertext, Mode::C1C2C3)?; +//! let decrypting_key = DecryptingKey::from_nonzero_scalar(secret_key.to_nonzero_scalar()); +//! assert_eq!(decrypting_key.decrypt(&cipher)?, plaintext); //! //! // Encrypting ASN.1 DER -//! let ciphertext = encrypting_key.encrypt_der(&mut OsRng, plaintext)?; //! +//! let cipher = encrypting_key.encrypt_rng(&mut OsRng, plaintext)?; +//! let ciphertext = cipher.to_der()?; //! // Decrypting ASN.1 DER -//! assert_eq!(decrypting_key.decrypt_der(&ciphertext)?, plaintext); +//! let cipher = Cipher::from_der(&ciphertext)?; +//! assert_eq!(decrypting_key.decrypt(&cipher)?, plaintext); //! //! Ok(()) //! # } @@ -42,24 +47,23 @@ use core::cmp::min; -use crate::AffinePoint; - #[cfg(feature = "alloc")] -use alloc::vec; +use alloc::{borrow::Cow, vec::Vec}; use elliptic_curve::{ - bigint::{Encoding, U256, Uint}, + CurveArithmetic, Error, Group, PrimeField, Result, + array::Array, + ops::Reduce, pkcs8::der::{ - Decode, DecodeValue, Encode, Length, Reader, Sequence, Tag, Writer, asn1::UintRef, + self, Decode, DecodeValue, Encode, EncodeValue, Length, Reader, Sequence, Writer, + asn1::{OctetStringRef, UintRef}, }, + sec1::{EncodedPoint, FromEncodedPoint, ModulusSize, Tag, ToEncodedPoint}, }; -use elliptic_curve::{ - Result, - pkcs8::der::{EncodeValue, asn1::OctetStringRef}, - sec1::ToEncodedPoint, -}; -use sm3::digest::DynDigest; +use crate::Sm2; +use sm3::Sm3; +use sm3::digest::{FixedOutputReset, Output, OutputSizeUser, Update, typenum::Unsigned}; #[cfg(feature = "arithmetic")] mod decrypting; @@ -77,78 +81,247 @@ pub enum Mode { /// new mode C1C3C2, } +impl Default for Mode { + fn default() -> Self { + Self::C1C3C2 + } +} /// Represents a cipher structure containing encryption-related data (asn.1 format). /// /// The `Cipher` structure includes the coordinates of the elliptic curve point (`x`, `y`), /// the digest of the message, and the encrypted cipher text. -pub struct Cipher<'a> { - x: U256, - y: U256, - digest: &'a [u8], - cipher: &'a [u8], +#[derive(Debug)] +pub struct Cipher<'a, C: CurveArithmetic = Sm2, D: OutputSizeUser = Sm3> { + c1: C::AffinePoint, + #[cfg(feature = "alloc")] + c2: Cow<'a, [u8]>, + #[cfg(not(feature = "alloc"))] + c2: &'a [u8], + c3: Output, } -impl<'a> Sequence<'a> for Cipher<'a> {} +impl<'a, C, D> Cipher<'a, C, D> +where + C: CurveArithmetic, + C::AffinePoint: FromEncodedPoint + ToEncodedPoint, + C::FieldBytesSize: ModulusSize, + D: OutputSizeUser, +{ + /// Decode from slice + pub fn from_slice(cipher: &'a [u8], mode: Mode) -> Result { + let tag = Tag::from_u8(cipher.first().cloned().ok_or(Error)?)?; + let c1_len = tag.message_len(C::FieldBytesSize::USIZE); + + // B1: get 𝐢1 from 𝐢 + let (c1, c) = cipher.split_at(c1_len); + // verify that point c1 satisfies the elliptic curve + let encoded_c1 = EncodedPoint::::from_bytes(c1)?; + let c1: C::AffinePoint = + Option::from(FromEncodedPoint::from_encoded_point(&encoded_c1)).ok_or(Error)?; + // B2: compute point 𝑆 = [β„Ž]𝐢1 + let scalar: C::Scalar = Reduce::::reduce(C::Uint::from(C::Scalar::S)); + + let s: C::ProjectivePoint = C::ProjectivePoint::from(c1) * scalar; + if s.is_identity().into() { + return Err(Error); + } + + let digest_size = D::output_size(); + let (c2, c3_buf) = match mode { + Mode::C1C3C2 => { + let (c3, c2) = c.split_at(digest_size); + (c2, c3) + } + Mode::C1C2C3 => c.split_at(c.len() - digest_size), + }; -impl EncodeValue for Cipher<'_> { - fn value_len(&self) -> elliptic_curve::pkcs8::der::Result { - UintRef::new(&self.x.to_be_bytes())?.encoded_len()? - + UintRef::new(&self.y.to_be_bytes())?.encoded_len()? - + OctetStringRef::new(self.digest)?.encoded_len()? - + OctetStringRef::new(self.cipher)?.encoded_len()? + let mut c3 = Output::::default(); + c3.clone_from_slice(c3_buf); + + #[cfg(feature = "alloc")] + let c2 = Cow::Borrowed(c2); + + Ok(Self { c1, c2, c3 }) } - fn encode_value(&self, writer: &mut impl Writer) -> elliptic_curve::pkcs8::der::Result<()> { - UintRef::new(&self.x.to_be_bytes())?.encode(writer)?; - UintRef::new(&self.y.to_be_bytes())?.encode(writer)?; - OctetStringRef::new(self.digest)?.encode(writer)?; - OctetStringRef::new(self.cipher)?.encode(writer)?; - Ok(()) + /// Encode to Vec + #[cfg(feature = "alloc")] + pub fn to_vec(&self, mode: Mode) -> Vec { + let point = self.c1.to_encoded_point(false); + let len = point.len() + self.c2.len() + self.c3.len(); + let mut result = Vec::with_capacity(len); + match mode { + Mode::C1C2C3 => { + result.extend(point.as_ref()); + result.extend(self.c2.as_ref()); + result.extend(&self.c3); + } + Mode::C1C3C2 => { + result.extend(point.as_ref()); + result.extend(&self.c3); + result.extend(self.c2.as_ref()); + } + } + + result + } + /// Encode to Vec + #[cfg(feature = "alloc")] + pub fn to_vec_compressed(&self, mode: Mode) -> Vec { + let point = self.c1.to_encoded_point(true); + let len = point.len() + self.c2.len() + self.c3.len(); + let mut result = Vec::with_capacity(len); + match mode { + Mode::C1C2C3 => { + result.extend(point.as_ref()); + result.extend(self.c2.as_ref()); + result.extend(&self.c3); + } + Mode::C1C3C2 => { + result.extend(point.as_ref()); + result.extend(&self.c3); + result.extend(self.c2.as_ref()); + } + } + + result + } + /// Get C1 + pub fn c1(&self) -> &C::AffinePoint { + &self.c1 } + /// Get C2 + pub fn c2(&self) -> &[u8] { + #[cfg(feature = "alloc")] + return &self.c2; + #[cfg(not(feature = "alloc"))] + return self.c2; + } + /// Get C3 + pub fn c3(&self) -> &Output { + &self.c3 + } +} + +impl<'a, C, D> Sequence<'a> for Cipher<'a, C, D> +where + C: CurveArithmetic, + D: OutputSizeUser, + C::AffinePoint: ToEncodedPoint + FromEncodedPoint, + C::FieldBytesSize: ModulusSize, +{ } -impl<'a> DecodeValue<'a> for Cipher<'a> { - type Error = elliptic_curve::pkcs8::der::Error; +#[cfg_attr(not(feature = "alloc"), allow(clippy::useless_asref))] +impl EncodeValue for Cipher<'_, C, D> +where + C: CurveArithmetic, + D: OutputSizeUser, + C::AffinePoint: ToEncodedPoint, + C::FieldBytesSize: ModulusSize, +{ + fn value_len(&self) -> der::Result { + let point = self.c1.to_encoded_point(false); + UintRef::new(point.x().expect("x is None"))?.encoded_len()? + + UintRef::new(point.y().expect("y is None"))?.encoded_len()? + + OctetStringRef::new(&self.c3)?.encoded_len()? + + OctetStringRef::new(self.c2.as_ref())?.encoded_len()? + } + fn encode_value(&self, writer: &mut impl Writer) -> der::Result<()> { + let point = self.c1.to_encoded_point(false); + UintRef::new(point.x().expect("x is None"))?.encode(writer)?; + UintRef::new(point.y().expect("y is None"))?.encode(writer)?; + OctetStringRef::new(&self.c3)?.encode(writer)?; + OctetStringRef::new(self.c2.as_ref())?.encode(writer)?; + Ok(()) + } +} + +impl<'a, C, D> DecodeValue<'a> for Cipher<'a, C, D> +where + C: CurveArithmetic, + D: OutputSizeUser, + C::AffinePoint: FromEncodedPoint, + C::FieldBytesSize: ModulusSize, +{ + type Error = der::Error; fn decode_value>( decoder: &mut R, - header: elliptic_curve::pkcs8::der::Header, - ) -> core::result::Result { + header: der::Header, + ) -> core::result::Result { decoder.read_nested(header.length, |nr| { let x = UintRef::decode(nr)?.as_bytes(); let y = UintRef::decode(nr)?.as_bytes(); - let digest = OctetStringRef::decode(nr)?.into(); - let cipher = OctetStringRef::decode(nr)?.into(); - Ok(Cipher { - x: Uint::from_be_bytes(zero_pad_byte_slice(x)?), - y: Uint::from_be_bytes(zero_pad_byte_slice(y)?), - digest, - cipher, - }) + let digest = OctetStringRef::decode(nr)?.as_bytes(); + let cipher = OctetStringRef::decode(nr)?.as_bytes(); + let size = C::FieldBytesSize::USIZE; + + let num_zeroes = size + .checked_sub(x.len()) + .ok_or_else(|| der::Tag::Integer.length_error())?; + let mut x_arr = Array::default(); + x_arr[num_zeroes..].clone_from_slice(x); + + let num_zeroes = size + .checked_sub(y.len()) + .ok_or_else(|| der::Tag::Integer.length_error())?; + let mut y_arr = Array::default(); + y_arr[num_zeroes..].clone_from_slice(y); + + let point = EncodedPoint::::from_affine_coordinates(&x_arr, &y_arr, false); + let c1 = Option::from(C::AffinePoint::from_encoded_point(&point)).ok_or_else(|| { + der::Error::new( + der::ErrorKind::Value { + tag: der::Tag::Integer, + }, + Length::new(C::FieldBytesSize::U32 * 2), + ) + })?; + + #[cfg(feature = "alloc")] + let c2 = Cow::Borrowed(cipher); + #[cfg(not(feature = "alloc"))] + let c2 = cipher; + // Output::::try_from() + let c3 = Output::::try_from(digest).map_err(|_e| { + der::Error::new( + der::ErrorKind::Value { + tag: der::Tag::OctetString, + }, + Length::new(D::output_size().try_into().expect("usize case error")), + ) + })?; + Ok(Cipher { c1, c2, c3 }) }) } } /// Performs key derivation using a hash function and elliptic curve point. -fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<()> { - let klen = c2.len(); +fn kdf(hasher: &mut D, kpb: C::AffinePoint, msg: &[u8], c2_out: &mut [u8]) -> Result<()> +where + D: Update + FixedOutputReset, + C: CurveArithmetic, + C::FieldBytesSize: ModulusSize, + C::AffinePoint: ToEncodedPoint, +{ + let klen = msg.len(); let mut ct: i32 = 0x00000001; let mut offset = 0; - let digest_size = hasher.output_size(); - let mut ha = vec![0u8; digest_size]; + let digest_size = D::output_size(); + let mut ha = Output::::default(); let encode_point = kpb.to_encoded_point(false); + hasher.reset(); while offset < klen { - hasher.update(encode_point.x().ok_or(elliptic_curve::Error)?); - hasher.update(encode_point.y().ok_or(elliptic_curve::Error)?); + hasher.update(encode_point.x().ok_or(Error)?); + hasher.update(encode_point.y().ok_or(Error)?); hasher.update(&ct.to_be_bytes()); - hasher - .finalize_into_reset(&mut ha) - .map_err(|_e| elliptic_curve::Error)?; + hasher.finalize_into_reset(&mut ha); let xor_len = min(digest_size, klen - offset); - xor(c2, &ha, offset, xor_len); + xor(msg, c2_out, &ha, offset, xor_len); offset += xor_len; ct += 1; } @@ -156,22 +329,8 @@ fn kdf(hasher: &mut dyn DynDigest, kpb: AffinePoint, c2: &mut [u8]) -> Result<() } /// XORs a portion of the buffer `c2` with a hash value. -fn xor(c2: &mut [u8], ha: &[u8], offset: usize, xor_len: usize) { +fn xor(msg: &[u8], c2_out: &mut [u8], ha: &[u8], offset: usize, xor_len: usize) { for i in 0..xor_len { - c2[offset + i] ^= ha[i]; + c2_out[offset + i] = msg[offset + i] ^ ha[i]; } } - -/// Converts a byte slice to a fixed-size array, padding with leading zeroes if necessary. -pub(crate) fn zero_pad_byte_slice( - bytes: &[u8], -) -> elliptic_curve::pkcs8::der::Result<[u8; N]> { - let num_zeroes = N - .checked_sub(bytes.len()) - .ok_or_else(|| Tag::Integer.length_error())?; - - // Copy input into `N`-sized output buffer with leading zeroes - let mut output = [0u8; N]; - output[num_zeroes..].copy_from_slice(bytes); - Ok(output) -} diff --git a/sm2/src/pke/decrypting.rs b/sm2/src/pke/decrypting.rs index 62135ebad..1b842f982 100644 --- a/sm2/src/pke/decrypting.rs +++ b/sm2/src/pke/decrypting.rs @@ -1,47 +1,41 @@ use core::fmt::{self, Debug}; -use crate::{ - AffinePoint, EncodedPoint, FieldBytes, NonZeroScalar, PublicKey, Scalar, SecretKey, - arithmetic::field::FieldElement, -}; +use crate::Sm2; +use crate::{FieldBytes, NonZeroScalar, PublicKey, SecretKey}; +#[cfg(feature = "alloc")] +use alloc::{vec, vec::Vec}; -use alloc::{borrow::ToOwned, vec::Vec}; use elliptic_curve::{ - Error, Group, Result, - bigint::U256, - ops::Reduce, - pkcs8::der::Decode, - sec1::{FromEncodedPoint, ToEncodedPoint}, + CurveArithmetic, CurveGroup, Error, Result, + sec1::{ModulusSize, ToEncodedPoint}, subtle::{Choice, ConstantTimeEq}, }; -use primeorder::PrimeField; -use sm3::{Digest, Sm3, digest::DynDigest}; +use sm3::{ + Sm3, + digest::{Digest, FixedOutputReset, Output, Update}, +}; -use super::{Cipher, Mode, encrypting::EncryptingKey, kdf, vec}; +use super::{Cipher, encrypting::EncryptingKey, kdf}; /// Represents a decryption key used for decrypting messages using elliptic curve cryptography. #[derive(Clone)] pub struct DecryptingKey { secret_scalar: NonZeroScalar, encryting_key: EncryptingKey, - mode: Mode, } impl DecryptingKey { /// Creates a new `DecryptingKey` from a `SecretKey` with the default decryption mode (`C1C3C2`). pub fn new(secret_key: SecretKey) -> Self { - Self::new_with_mode(secret_key.to_nonzero_scalar(), Mode::C1C3C2) + Self::from_nonzero_scalar(secret_key.to_nonzero_scalar()) } + /// Create a signing key from a non-zero scalar. /// Creates a new `DecryptingKey` from a non-zero scalar and sets the decryption mode. - pub fn new_with_mode(secret_scalar: NonZeroScalar, mode: Mode) -> Self { + pub fn from_nonzero_scalar(secret_scalar: NonZeroScalar) -> Self { Self { secret_scalar, - encryting_key: EncryptingKey::new_with_mode( - PublicKey::from_secret_scalar(&secret_scalar), - mode, - ), - mode, + encryting_key: EncryptingKey::new(PublicKey::from_secret_scalar(&secret_scalar)), } } @@ -54,12 +48,7 @@ impl DecryptingKey { /// scalar value. pub fn from_slice(slice: &[u8]) -> Result { let secret_scalar = NonZeroScalar::try_from(slice).map_err(|_| Error)?; - Self::from_nonzero_scalar(secret_scalar) - } - - /// Create a signing key from a non-zero scalar. - pub fn from_nonzero_scalar(secret_scalar: NonZeroScalar) -> Result { - Ok(Self::new_with_mode(secret_scalar, Mode::C1C3C2)) + Ok(Self::from_nonzero_scalar(secret_scalar)) } /// Serialize as bytes. @@ -83,40 +72,38 @@ impl DecryptingKey { &self.encryting_key } - /// Decrypts a ciphertext in-place using the default digest algorithm (`Sm3`). - pub fn decrypt(&self, ciphertext: &[u8]) -> Result> { - self.decrypt_digest::(ciphertext) + /// Decrypt the [`Cipher`] using the default digest algorithm [`Sm3`]. + #[cfg(feature = "alloc")] + pub fn decrypt(&self, cipher: &Cipher<'_, Sm2, Sm3>) -> Result> { + self.decrypt_digest::(cipher) } - /// Decrypts a ciphertext in-place using the specified digest algorithm. - pub fn decrypt_digest(&self, ciphertext: &[u8]) -> Result> - where - D: 'static + Digest + DynDigest + Send + Sync, - { - let mut digest = D::new(); - decrypt(&self.secret_scalar, self.mode, &mut digest, ciphertext) + /// Decrypt the [`Cipher`] using the specified digest algorithm. + #[cfg(feature = "alloc")] + pub fn decrypt_digest( + &self, + cipher: &Cipher<'_, Sm2, D>, + ) -> Result> { + let mut out = vec![0; cipher.c2.len()]; + self.decrypt_digest_into(cipher, &mut out)?; + Ok(out) } - /// Decrypts a ciphertext in-place from ASN.1 format using the default digest algorithm (`Sm3`). - pub fn decrypt_der(&self, ciphertext: &[u8]) -> Result> { - self.decrypt_der_digest::(ciphertext) + /// Decrypt the [`Cipher`] using the default digest algorithm [`Sm3`]. + pub fn decrypt_into(&self, cipher: &Cipher<'_, Sm2, Sm3>, out: &mut [u8]) -> Result { + self.decrypt_digest_into(cipher, out) } - /// Decrypts a ciphertext in-place from ASN.1 format using the specified digest algorithm. - pub fn decrypt_der_digest(&self, ciphertext: &[u8]) -> Result> - where - D: 'static + Digest + DynDigest + Send + Sync, - { - let cipher = Cipher::from_der(ciphertext).map_err(elliptic_curve::pkcs8::Error::from)?; - let prefix: &[u8] = &[0x04]; - let x: [u8; 32] = cipher.x.to_be_bytes(); - let y: [u8; 32] = cipher.y.to_be_bytes(); - let cipher = match self.mode { - Mode::C1C2C3 => [prefix, &x, &y, cipher.cipher, cipher.digest].concat(), - Mode::C1C3C2 => [prefix, &x, &y, cipher.digest, cipher.cipher].concat(), - }; - - Ok(self.decrypt_digest::(&cipher)?.to_vec()) + /// Decrypt the [`Cipher`] to out using the specified digest algorithm. + /// The length of out is equal to the length of C2. + /// * Note: buffer zones are prohibited from overlapping + pub fn decrypt_digest_into( + &self, + cipher: &Cipher<'_, Sm2, D>, + out: &mut [u8], + ) -> Result { + let scalar = self.as_nonzero_scalar(); + decrypt_into(scalar.as_ref(), cipher, out) } } @@ -153,64 +140,48 @@ impl PartialEq for DecryptingKey { } } -fn decrypt( - secret_scalar: &Scalar, - mode: Mode, - hasher: &mut dyn DynDigest, - cipher: &[u8], -) -> Result> { - let q = U256::from_be_hex(FieldElement::MODULUS); - let c1_len = q.bits().div_ceil(8) * 2 + 1; - - // B1: get 𝐢1 from 𝐢 - let (c1, c) = cipher.split_at(c1_len as usize); - let encoded_c1 = EncodedPoint::from_bytes(c1).map_err(Error::from)?; - - // verify that point c1 satisfies the elliptic curve - let mut c1_point = AffinePoint::from_encoded_point(&encoded_c1).unwrap(); - - // B2: compute point 𝑆 = [β„Ž]𝐢1 - let s = c1_point * Scalar::reduce(U256::from_u32(FieldElement::S)); - if s.is_identity().into() { +fn decrypt_into( + secret_scalar: &C::Scalar, + cipher: &Cipher<'_, C, D>, + out: &mut [u8], +) -> Result +where + C: CurveArithmetic, + D: FixedOutputReset + Digest, + C::FieldBytesSize: ModulusSize, + C::AffinePoint: ToEncodedPoint, +{ + if out.len() < cipher.c2.len() { return Err(Error); } + let out = &mut out[..cipher.c2.len()]; + + let mut digest = D::new(); // B3: compute [𝑑𝐡]𝐢1 = (π‘₯2, 𝑦2) - c1_point = (c1_point * secret_scalar).to_affine(); - let digest_size = hasher.output_size(); - let (c2, c3) = match mode { - Mode::C1C3C2 => { - let (c3, c2) = c.split_at(digest_size); - (c2, c3) - } - Mode::C1C2C3 => c.split_at(c.len() - digest_size), - }; + let c1_point = (C::ProjectivePoint::from(cipher.c1) * secret_scalar).to_affine(); + + #[cfg(feature = "alloc")] + let c2 = &cipher.c2; + #[cfg(not(feature = "alloc"))] + let c2 = cipher.c2; // B4: compute 𝑑 = 𝐾𝐷𝐹(π‘₯2 βˆ₯ 𝑦2, π‘˜π‘™π‘’π‘›) // B5: get 𝐢2 from 𝐢 and compute 𝑀′ = 𝐢2 βŠ• t - let mut c2 = c2.to_owned(); - kdf(hasher, c1_point, &mut c2)?; + kdf::(&mut digest, c1_point, c2, out)?; // compute 𝑒 = π»π‘Žπ‘ β„Ž(π‘₯2 βˆ₯ 𝑀′βˆ₯ 𝑦2). - let mut u = vec![0u8; digest_size]; + let mut u = Output::::default(); let encode_point = c1_point.to_encoded_point(false); - hasher.update(encode_point.x().ok_or(Error)?); - hasher.update(&c2); - hasher.update(encode_point.y().ok_or(Error)?); - hasher.finalize_into_reset(&mut u).map_err(|_e| Error)?; - let checked = u - .iter() - .zip(c3) - .fold(0, |mut check, (&c3_byte, &c3checked_byte)| { - check |= c3_byte ^ c3checked_byte; - check - }); + Update::update(&mut digest, encode_point.x().ok_or(Error)?); + Update::update(&mut digest, out); + Update::update(&mut digest, encode_point.y().ok_or(Error)?); + FixedOutputReset::finalize_into_reset(&mut digest, &mut u); // If 𝑒 β‰  𝐢3, output β€œERROR” and exit - if checked != 0 { + if cipher.c3 != u { return Err(Error); } - // B7: output the plaintext 𝑀′. - Ok(c2.to_vec()) + Ok(out.len()) } diff --git a/sm2/src/pke/encrypting.rs b/sm2/src/pke/encrypting.rs index 1ab37040f..b82329dc3 100644 --- a/sm2/src/pke/encrypting.rs +++ b/sm2/src/pke/encrypting.rs @@ -1,45 +1,38 @@ use core::fmt::Debug; -use crate::{ - AffinePoint, ProjectivePoint, PublicKey, Scalar, Sm2, - arithmetic::field::FieldElement, - pke::{kdf, vec}, -}; - #[cfg(feature = "alloc")] -use alloc::{borrow::ToOwned, boxed::Box, vec::Vec}; +use alloc::{borrow::Cow, boxed::Box, vec}; + +use crate::PublicKey; +use crate::Sm2; + +use super::kdf; + +use rand_core::TryCryptoRng; + use elliptic_curve::{ - Curve, Error, Group, Result, - bigint::{RandomBits, U256, Uint, Zero}, + CurveArithmetic, CurveGroup, Error, Group, NonZeroScalar, Result, ops::Reduce, - pkcs8::der::Encode, - rand_core::TryCryptoRng, - sec1::ToEncodedPoint, + sec1::{ModulusSize, ToEncodedPoint}, }; use primeorder::PrimeField; use sm3::{ Sm3, - digest::{Digest, DynDigest}, + digest::{Digest, FixedOutputReset, Output, Update}, }; -use super::{Cipher, Mode}; +use super::Cipher; /// Represents an encryption key used for encrypting messages using elliptic curve cryptography. #[derive(Clone, Debug)] pub struct EncryptingKey { public_key: PublicKey, - mode: Mode, } impl EncryptingKey { /// Initialize [`EncryptingKey`] from PublicKey pub fn new(public_key: PublicKey) -> Self { - Self::new_with_mode(public_key, Mode::C1C2C3) - } - - /// Initialize [`EncryptingKey`] from PublicKey and set Encryption mode - pub fn new_with_mode(public_key: PublicKey, mode: Mode) -> Self { - Self { public_key, mode } + Self { public_key } } /// Initialize [`EncryptingKey`] from a SEC1-encoded public key. @@ -52,13 +45,13 @@ impl EncryptingKey { /// /// Returns an [`Error`] if the given affine point is the additive identity /// (a.k.a. point at infinity). - pub fn from_affine(affine: AffinePoint) -> Result { + pub fn from_affine(affine: crate::AffinePoint) -> Result { let public_key = PublicKey::from_affine(affine).map_err(|_| Error)?; Ok(Self::new(public_key)) } - /// Borrow the inner [`AffinePoint`] for this public key. - pub fn as_affine(&self) -> &AffinePoint { + /// Borrow the inner [`crate::AffinePoint`] for this public key. + pub fn as_affine(&self) -> &crate::AffinePoint { self.public_key.as_affine() } @@ -73,68 +66,94 @@ impl EncryptingKey { self.public_key.to_sec1_bytes() } - /// Encrypts a message using the encryption key. - /// - /// This method calculates the digest using the `Sm3` hash function and then performs encryption. - pub fn encrypt(&self, rng: &mut R, msg: &[u8]) -> Result> { - self.encrypt_digest::(rng, msg) + /// Encrypt into [`Cipher`] using the default digest algorithm [`Sm3`]. + #[cfg(all(feature = "getrandom", feature = "alloc"))] + pub fn encrypt<'a>(&self, msg: &[u8]) -> Result> { + use rand_core::OsRng; + self.encrypt_rng(&mut OsRng, msg) } - /// Encrypts a message and returns the result in ASN.1 format. - /// - /// This method calculates the digest using the `Sm3` hash function and performs encryption, - /// then encodes the result in ASN.1 format. - pub fn encrypt_der( + /// Encrypt into [`Cipher`] using the default digest algorithm [`Sm3`]. + /// Use a custom RNG. + #[cfg(feature = "alloc")] + pub fn encrypt_rng<'a, R: TryCryptoRng>( &self, rng: &mut R, msg: &[u8], - ) -> Result> { - self.encrypt_der_digest::(rng, msg) + ) -> Result> { + self.encrypt_digest_rng::<_, Sm3>(rng, msg) } - /// Encrypts a message using a specified digest algorithm. - pub fn encrypt_digest( + /// Encrypt into [`Cipher`] using the specified digest algorithm. + /// Use a custom RNG. + #[cfg(feature = "alloc")] + pub fn encrypt_digest_rng<'a, R: TryCryptoRng, D: Digest + FixedOutputReset>( &self, rng: &mut R, msg: &[u8], - ) -> Result> - where - D: 'static + Digest + DynDigest + Send + Sync, - { - let mut digest = D::new(); - encrypt(rng, &self.public_key, self.mode, &mut digest, msg) + ) -> Result> { + let mut c1 = ::AffinePoint::default(); + let mut c2 = vec![0; msg.len()]; + let mut c3 = Output::::default(); + self.encrypt_digest_rng_into::(rng, msg, &mut c1, &mut c2, &mut c3)?; + Ok(Cipher { + c1, + c2: c2.into(), + c3, + }) } - /// Encrypts a message using a specified digest algorithm and returns the result in ASN.1 format. - pub fn encrypt_der_digest( + /// Encrypt into [`Cipher`] using the default digest algorithm [`Sm3`]. + /// `c2_out_buf` is the output of c2. + /// Use a custom RNG. + pub fn encrypt_buf_rng<'a, R: TryCryptoRng>( &self, rng: &mut R, msg: &[u8], - ) -> Result> - where - D: 'static + Digest + DynDigest + Send + Sync, - { - let mut digest = D::new(); - let cipher = encrypt(rng, &self.public_key, self.mode, &mut digest, msg)?; - let digest_size = digest.output_size(); - let (_, cipher) = cipher.split_at(1); - let (x, cipher) = cipher.split_at(32); - let (y, cipher) = cipher.split_at(32); - let (digest, cipher) = match self.mode { - Mode::C1C2C3 => { - let (cipher, digest) = cipher.split_at(cipher.len() - digest_size); - (digest, cipher) - } - Mode::C1C3C2 => cipher.split_at(digest_size), - }; - Ok(Cipher { - x: Uint::from_be_slice(x), - y: Uint::from_be_slice(y), - digest, - cipher, - } - .to_der() - .map_err(elliptic_curve::pkcs8::Error::from)?) + c2_out_buf: &'a mut [u8], + ) -> Result> { + self.encrypt_buf_digest_rng::(rng, msg, c2_out_buf) + } + + /// Encrypt into [`Cipher`] using the specified digest algorithm. + /// `c2_out_buf` is the output of c2. + /// Use a custom RNG. + pub fn encrypt_buf_digest_rng<'a, R: TryCryptoRng, D: Digest + FixedOutputReset>( + &self, + rng: &mut R, + msg: &[u8], + c2_out_buf: &'a mut [u8], + ) -> Result> { + let mut c1 = ::AffinePoint::default(); + let mut c3 = Output::::default(); + let len = self.encrypt_digest_rng_into::(rng, msg, &mut c1, c2_out_buf, &mut c3)?; + let c2 = &c2_out_buf[..len]; + + #[cfg(feature = "alloc")] + let c2 = Cow::Borrowed(c2); + + Ok(Cipher { c1, c2, c3 }) + } + + /// Encrypt into the specified buffer using the specified digest algorithm. + /// * Note: buffer zones are prohibited from overlapping + /// * returns c2_out length + pub fn encrypt_digest_rng_into( + &self, + rng: &mut R, + msg: &[u8], + c1_out: &mut ::AffinePoint, + c2_out: &mut [u8], + c3_out: &mut Output, + ) -> Result { + encrypt_into::( + self.public_key.as_affine(), + rng, + msg, + c1_out, + c2_out, + c3_out, + ) } } @@ -144,28 +163,38 @@ impl From for EncryptingKey { } } -/// Encrypts a message using the specified public key, mode, and digest algorithm. -fn encrypt( +fn encrypt_into( + affine_point: &C::AffinePoint, rng: &mut R, - public_key: &PublicKey, - mode: Mode, - digest: &mut dyn DynDigest, msg: &[u8], -) -> Result> { - const N_BYTES: u32 = Sm2::ORDER.bits().div_ceil(8); - let mut c1 = vec![0; (N_BYTES * 2 + 1) as usize]; - let mut c2 = msg.to_owned(); - let mut hpb: AffinePoint; + c1_out: &mut C::AffinePoint, + c2_out: &mut [u8], + c3_out: &mut Output, +) -> Result +where + C: CurveArithmetic, + R: TryCryptoRng, + D: FixedOutputReset + Digest + Update, + C::AffinePoint: ToEncodedPoint, + C::FieldBytesSize: ModulusSize, +{ + if c2_out.len() < msg.len() { + return Err(Error); + } + let c2_out = &mut c2_out[..msg.len()]; + + let mut digest = D::new(); + let mut hpb: C::AffinePoint; loop { // A1: generate a random number π‘˜ ∈ [1, 𝑛 βˆ’ 1] with the random number generator - let k = Scalar::from_uint(next_k(rng, N_BYTES)?).unwrap(); + let k: C::Scalar = C::Scalar::from(NonZeroScalar::try_from_rng(rng).map_err(|_e| Error)?); // A2: compute point 𝐢1 = [π‘˜]𝐺 = (π‘₯1, 𝑦1) - let kg = ProjectivePoint::mul_by_generator(&k).to_affine(); + let kg: C::AffinePoint = C::ProjectivePoint::mul_by_generator(&k).into(); // A3: compute point 𝑆 = [β„Ž]𝑃𝐡 of the elliptic curve - let pb_point = public_key.as_affine(); - let s = *pb_point * Scalar::reduce(U256::from_u32(FieldElement::S)); + let scalar: C::Scalar = Reduce::::reduce(C::Uint::from(C::Scalar::S)); + let s: C::ProjectivePoint = C::ProjectivePoint::from(*affine_point) * scalar; if s.is_identity().into() { return Err(Error); } @@ -175,37 +204,23 @@ fn encrypt( // A5: compute 𝑑 = 𝐾𝐷𝐹(π‘₯2||𝑦2, π‘˜π‘™π‘’π‘›) // A6: compute 𝐢2 = 𝑀 βŠ• t - kdf(digest, hpb, &mut c2)?; + kdf::(&mut digest, hpb, msg, c2_out)?; // // If 𝑑 is an all-zero bit string, go to A1. // if all of t are 0, xor(c2) == c2 - if c2.iter().zip(msg).any(|(pre, cur)| pre != cur) { - let uncompress_kg = kg.to_encoded_point(false); - c1.copy_from_slice(uncompress_kg.as_bytes()); + if c2_out.iter().zip(msg).any(|(pre, cur)| pre != cur) { + *c1_out = kg; break; } } let encode_point = hpb.to_encoded_point(false); // A7: compute 𝐢3 = π»π‘Žπ‘ β„Ž(π‘₯2||𝑀||𝑦2) - let mut c3 = vec![0; digest.output_size()]; - digest.update(encode_point.x().ok_or(Error)?); - digest.update(msg); - digest.update(encode_point.y().ok_or(Error)?); - digest.finalize_into_reset(&mut c3).map_err(|_e| Error)?; - - // A8: output the ciphertext 𝐢 = 𝐢1||𝐢2||𝐢3. - Ok(match mode { - Mode::C1C2C3 => [c1.as_slice(), &c2, &c3].concat(), - Mode::C1C3C2 => [c1.as_slice(), &c3, &c2].concat(), - }) -} + Digest::reset(&mut digest); + Digest::update(&mut digest, encode_point.x().ok_or(Error)?); + Digest::update(&mut digest, msg); + Digest::update(&mut digest, encode_point.y().ok_or(Error)?); + Digest::finalize_into_reset(&mut digest, c3_out); -fn next_k(rng: &mut R, bit_length: u32) -> Result { - loop { - let k = U256::try_random_bits(rng, bit_length).map_err(|_| Error)?; - if !bool::from(k.is_zero()) && k < Sm2::ORDER { - return Ok(k); - } - } + Ok(c2_out.len()) } diff --git a/sm2/tests/sm2pke.rs b/sm2/tests/sm2pke.rs index aa40955a5..8f2e7469b 100644 --- a/sm2/tests/sm2pke.rs +++ b/sm2/tests/sm2pke.rs @@ -3,9 +3,13 @@ use elliptic_curve::{NonZeroScalar, ops::Reduce}; use hex_literal::hex; use proptest::prelude::*; -use rand_core::OsRng; -use sm2::{Scalar, Sm2, U256, pke::DecryptingKey}; +#[allow(unused_imports)] +use sm2::{ + Scalar, Sm2, U256, + pkcs8::der::{Decode, Encode}, + pke::{Cipher, DecryptingKey, Mode}, +}; // private key bytes const PRIVATE_KEY: [u8; 32] = @@ -23,25 +27,28 @@ const ASN1_CIPHER: [u8; 116] = hex!( #[test] fn decrypt_verify() { - assert_eq!( - DecryptingKey::new( - NonZeroScalar::::try_from(PRIVATE_KEY.as_ref() as &[u8]) - .unwrap() - .into() - ) - .decrypt(&CIPHER) - .unwrap(), - MSG - ); + let cipher = Cipher::from_slice(&CIPHER, Mode::default()).expect("Unable to resolve"); + let mut buf = vec![0; MSG.len()]; + + DecryptingKey::new( + NonZeroScalar::::try_from(PRIVATE_KEY.as_ref() as &[u8]) + .unwrap() + .into(), + ) + .decrypt_into(&cipher, &mut buf) + .unwrap(); + assert_eq!(buf, MSG) } #[test] fn decrypt_der_verify() { - let dk = DecryptingKey::new_with_mode( + let cipher = Cipher::from_der(&ASN1_CIPHER).expect("Unable to resolve"); + let dk = DecryptingKey::from_nonzero_scalar( NonZeroScalar::::try_from(PRIVATE_KEY.as_ref() as &[u8]).unwrap(), - sm2::pke::Mode::C1C2C3, ); - assert_eq!(dk.decrypt_der(&ASN1_CIPHER).unwrap(), MSG); + let mut buf = vec![0; MSG.len()]; + dk.decrypt_into(&cipher, &mut buf).unwrap(); + assert_eq!(buf, MSG); } prop_compose! { @@ -49,46 +56,38 @@ prop_compose! { loop { let scalar = >::reduce_bytes(&bytes.into()); if let Some(scalar) = Option::from(NonZeroScalar::new(scalar)) { - return DecryptingKey::from_nonzero_scalar(scalar).unwrap(); - } - } - } -} - -prop_compose! { - fn decrypting_key_c1c2c3()(bytes in any::<[u8; 32]>()) -> DecryptingKey { - loop { - let scalar = >::reduce_bytes(&bytes.into()); - if let Some(scalar) = Option::from(NonZeroScalar::new(scalar)) { - return DecryptingKey::new_with_mode(scalar, sm2::pke::Mode::C1C2C3); + return DecryptingKey::from_nonzero_scalar(scalar); } } } } +#[cfg(all(feature = "alloc", feature = "getrandom"))] proptest! { #[test] fn encrypt_and_decrypt_der(dk in decrypting_key()) { let ek = dk.encrypting_key(); - let cipher_bytes = ek.encrypt_der(&mut OsRng, MSG).unwrap(); - prop_assert!(dk.decrypt_der(&cipher_bytes).is_ok()); + let cipher = ek.encrypt(MSG).unwrap(); + let cipher_bytes = cipher.to_der().unwrap(); + let cipher = Cipher::from_der(&cipher_bytes).unwrap(); + prop_assert!(dk.decrypt(&cipher).is_ok()); } #[test] fn encrypt_and_decrypt(dk in decrypting_key()) { let ek = dk.encrypting_key(); - let cipher_bytes = ek.encrypt(&mut OsRng, MSG).unwrap(); - assert_eq!(dk.decrypt(&cipher_bytes).unwrap(), MSG); + let cipher = ek.encrypt(MSG).unwrap(); + let cipher_bytes = cipher.to_vec(Mode::C1C2C3); + let cipher = Cipher::from_slice(&cipher_bytes, Mode::C1C2C3).unwrap(); + assert_eq!(dk.decrypt(&cipher).unwrap(), MSG); } #[test] - fn encrypt_and_decrypt_mode(dk in decrypting_key_c1c2c3()) { + fn encrypt_and_decrypt_mode(dk in decrypting_key()) { let ek = dk.encrypting_key(); - let cipher_bytes = ek.encrypt(&mut OsRng, MSG).unwrap(); - assert_eq!( - dk.decrypt(&cipher_bytes) - .unwrap(), - MSG - ); + let cipher = ek.encrypt(MSG).unwrap(); + let cipher_bytes = cipher.to_vec(Mode::C1C3C2); + let cipher = Cipher::from_slice(&cipher_bytes, Mode::C1C3C2).unwrap(); + assert_eq!(dk.decrypt(&cipher).unwrap(), MSG); } }