diff --git a/secp256kfun/src/scalar.rs b/secp256kfun/src/scalar.rs index e38edcbe..c0ab64da 100644 --- a/secp256kfun/src/scalar.rs +++ b/secp256kfun/src/scalar.rs @@ -482,6 +482,30 @@ mod conversion_impls { #[cfg(feature = "std")] impl std::error::Error for ScalarTooLarge {} + /// Error returned when trying to convert a zero value into a NonZero scalar + pub struct ZeroScalar(PhantomData); + + impl core::fmt::Display for ZeroScalar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "cannot convert zero {} to NonZero scalar", + type_name::() + ) + } + } + + impl core::fmt::Debug for ZeroScalar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("ZeroScalar") + .field(&type_name::()) + .finish() + } + } + + #[cfg(feature = "std")] + impl std::error::Error for ZeroScalar {} + /// Implements `From<$t> for $scalar` **and** /// `TryFrom<$scalar> for $t` for every `$t` supplied. macro_rules! impl_scalar_conversions { @@ -497,6 +521,27 @@ mod conversion_impls { } } + impl TryFrom<$t> for Scalar { + type Error = ZeroScalar<$t>; + + fn try_from(value: $t) -> Result { + // big-endian integer → 32-byte array + let mut bytes = [0u8; 32]; + let int_bytes = value.to_be_bytes(); + bytes[32 - int_bytes.len() ..].copy_from_slice(&int_bytes); + let scalar = Scalar::::from_bytes(bytes).unwrap(); + + // Check if value is zero + if value == 0 { + Err(ZeroScalar(PhantomData)) + } else { + Ok(scalar.non_zero().unwrap()) + } + } + } + + + impl TryFrom> for $t { type Error = ScalarTooLarge<$t>; @@ -759,4 +804,34 @@ mod test { assert!(Scalar::::from(41u32) < Scalar::::from(42u32)); assert!(Scalar::::from(42u32) <= Scalar::::from(42u32)); } + + #[test] + fn try_from_zero_to_nonzero() { + use core::convert::TryFrom; + + // Test that converting zero to NonZero fails + let result = Scalar::::try_from(0u32); + assert!(result.is_err()); + + // Test that converting non-zero to NonZero succeeds + let result = Scalar::::try_from(42u32); + assert!(result.is_ok()); + assert_eq!( + result.unwrap(), + Scalar::::from(42u32).non_zero().unwrap() + ); + + // Test with different integer types + assert!(Scalar::::try_from(0u8).is_err()); + assert!(Scalar::::try_from(0u16).is_err()); + assert!(Scalar::::try_from(0u64).is_err()); + + assert!(Scalar::::try_from(1u8).is_ok()); + assert!(Scalar::::try_from(1u16).is_ok()); + assert!(Scalar::::try_from(1u64).is_ok()); + + // Test that infallible From still works for Zero + let _zero_scalar: Scalar = 0u32.into(); + let _nonzero_scalar: Scalar = 42u32.into(); + } }