Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion mithril-stm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ rand = { version = "0.9.2", optional = true }
rand_core = { workspace = true, features = ["std"] }
rayon = { workspace = true }
serde = { workspace = true }
sha2 = "0.10.9"
subtle = { version = "2.5.0", optional = true }
thiserror = { workspace = true }

Expand Down
9 changes: 5 additions & 4 deletions mithril-stm/benches/schnorr_sig.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use midnight_curves::Fq as JubjubBase;
use rand_chacha::ChaCha20Rng;
use rand_core::{RngCore, SeedableRng};

use mithril_stm::{SchnorrSigningKey, SchnorrVerificationKey};
use mithril_stm::{BaseFieldElement, SchnorrSigningKey, SchnorrVerificationKey};

fn midnight_poseidon_hash(c: &mut Criterion, nr_sigs: usize) {
let mut group = c.benchmark_group("Schnorr".to_string());
Expand Down Expand Up @@ -47,13 +47,14 @@ fn sign_and_verify(c: &mut Criterion, nr_sigs: usize) {

let mut msg = [0u8; 32];
rng.fill_bytes(&mut msg);
let base_input = BaseFieldElement::from(msg.as_slice());
let mut mvks = Vec::new();
let mut msks = Vec::new();
let mut sigs = Vec::new();
for _ in 0..nr_sigs {
let sk = SchnorrSigningKey::generate(&mut rng).unwrap();
let vk = SchnorrVerificationKey::new_from_signing_key(sk.clone()).unwrap();
let sig = sk.sign(&msg, &mut rng_sig).unwrap();
let sig = sk.sign(&[base_input], &mut rng_sig).unwrap();
sigs.push(sig);
mvks.push(vk);
msks.push(sk);
Expand All @@ -62,15 +63,15 @@ fn sign_and_verify(c: &mut Criterion, nr_sigs: usize) {
group.bench_function(BenchmarkId::new("Signature", nr_sigs), |b| {
b.iter(|| {
for sk in msks.iter() {
let _sig = sk.sign(&msg, &mut rng_sig).unwrap();
let _sig = sk.sign(&[base_input], &mut rng_sig).unwrap();
}
})
});

group.bench_function(BenchmarkId::new("Verification", nr_sigs), |b| {
b.iter(|| {
for (vk, sig) in mvks.iter().zip(sigs.iter()) {
assert!(sig.verify(&msg, vk).is_ok());
assert!(sig.verify(&[base_input], vk).is_ok());
}
})
});
Expand Down
2 changes: 1 addition & 1 deletion mithril-stm/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ use std::fmt::Debug;

#[cfg(feature = "benchmark-internals")]
pub use signature_scheme::{
BlsProofOfPossession, BlsSignature, BlsSigningKey, BlsVerificationKey,
BaseFieldElement, BlsProofOfPossession, BlsSignature, BlsSigningKey, BlsVerificationKey,
BlsVerificationKeyProofOfPossession,
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use midnight_circuits::{
use midnight_curves::{
EDWARDS_D, Fq as JubjubBase, JubjubAffine as JubjubAffinePoint, JubjubExtended, JubjubSubgroup,
};
use sha2::{Digest, Sha256};
use std::ops::{Add, Mul};

use crate::{StmResult, signature_scheme::UniqueSchnorrSignatureError};
Expand Down Expand Up @@ -63,18 +62,10 @@ impl ProjectivePoint {
/// Hashes input bytes to a projective point on the Jubjub curve
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@damrobi
The doc comment on hash_to_projective_point is now outdated.

/// For now we leave the SHA call in the function since the SHA
/// function is not used anywhere else. This might change in the future.
pub(crate) fn hash_to_projective_point(input: &[u8]) -> StmResult<Self> {
let mut hash = Sha256::new();
hash.update(input);
let mut hashed_input = [0u8; 32];
hashed_input.copy_from_slice(&hash.finalize());
let scalar_input = JubjubBase::from_raw([
u64::from_le_bytes(hashed_input[0..8].try_into()?),
u64::from_le_bytes(hashed_input[8..16].try_into()?),
u64::from_le_bytes(hashed_input[16..24].try_into()?),
u64::from_le_bytes(hashed_input[24..32].try_into()?),
]);
let point = JubjubHashToCurveGadget::hash_to_curve(&[scalar_input]);
pub(crate) fn hash_to_projective_point(input: &[BaseFieldElement]) -> StmResult<Self> {
let point = JubjubHashToCurveGadget::hash_to_curve(
&input.iter().map(|elem| elem.0).collect::<Vec<JubjubBase>>(),
);
Ok(ProjectivePoint(JubjubExtended::from(point)))
}

Expand Down Expand Up @@ -263,13 +254,14 @@ mod tests {
use super::*;

const GOLDEN_BYTES: &[u8] = &[
15, 44, 110, 49, 102, 14, 172, 174, 230, 224, 30, 24, 129, 48, 80, 106, 88, 47, 98,
132, 180, 50, 8, 88, 48, 33, 149, 193, 129, 151, 209, 239,
238, 7, 23, 98, 52, 212, 110, 3, 226, 113, 172, 10, 74, 173, 92, 250, 224, 43, 81, 19,
173, 191, 35, 38, 127, 247, 107, 15, 230, 154, 198, 241,
];

fn golden_value() -> ProjectivePoint {
let msg = [255u8; 32];
ProjectivePoint::hash_to_projective_point(&msg).unwrap()
let base_input = BaseFieldElement::from(&msg[0..32]);
ProjectivePoint::hash_to_projective_point(&[base_input]).unwrap()
}

#[test]
Expand All @@ -288,24 +280,28 @@ mod tests {
let mut rng = ChaCha20Rng::from_seed([1u8; 32]);
let scalar1 = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let scalar2 = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let msg = b"test_point";
let base_input = BaseFieldElement::from(msg.as_slice());
let point = ProjectivePoint::hash_to_projective_point(&[base_input]).unwrap();

let point = ProjectivePoint::hash_to_projective_point(b"test_point").unwrap();
let p1 = scalar1 * point;
let p2 = scalar2 * point;

let result = p1 + p2;

let bytes = result.to_bytes();
let recovered = ProjectivePoint::from_bytes(&bytes).unwrap();

assert_eq!(result, recovered);
}

#[test]
fn test_add_identity() {
let point = ProjectivePoint::hash_to_projective_point(b"test_point").unwrap();
let msg = b"test_point";
let base_input = BaseFieldElement::from(msg.as_slice());
let point = ProjectivePoint::hash_to_projective_point(&[base_input]).unwrap();
let identity = ProjectivePoint(JubjubExtended::identity());

let result = point + identity;

assert_eq!(result, point);
}

Expand All @@ -314,8 +310,10 @@ mod tests {
let mut rng = ChaCha20Rng::from_seed([2u8; 32]);
let scalar1 = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let scalar2 = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let msg = b"test_point";
let base_input = BaseFieldElement::from(msg.as_slice());
let point = ProjectivePoint::hash_to_projective_point(&[base_input]).unwrap();

let point = ProjectivePoint::hash_to_projective_point(b"test_point").unwrap();
let p1 = scalar1 * point;
let p2 = scalar2 * point;

Expand All @@ -328,8 +326,10 @@ mod tests {
let scalar1 = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let scalar2 = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let scalar3 = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let msg = b"test_point";
let base_input = BaseFieldElement::from(msg.as_slice());
let point = ProjectivePoint::hash_to_projective_point(&[base_input]).unwrap();

let point = ProjectivePoint::hash_to_projective_point(b"test_point").unwrap();
let p1 = scalar1 * point;
let p2 = scalar2 * point;
let p3 = scalar3 * point;
Expand All @@ -341,21 +341,27 @@ mod tests {
fn test_scalar_mul() {
let mut rng = ChaCha20Rng::from_seed([4u8; 32]);
let scalar = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let point = ProjectivePoint::hash_to_projective_point(b"test_point").unwrap();
let msg = b"test_point";
let base_input = BaseFieldElement::from(msg.as_slice());
let point = ProjectivePoint::hash_to_projective_point(&[base_input]).unwrap();

let result = scalar * point;

let bytes = result.to_bytes();
let recovered = ProjectivePoint::from_bytes(&bytes).unwrap();

assert_eq!(result, recovered);
}

#[test]
fn test_scalar_mul_distributivity_over_point_addition() {
let mut rng = ChaCha20Rng::from_seed([5u8; 32]);
let scalar = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let point1 = ProjectivePoint::hash_to_projective_point(b"test_point_1").unwrap();
let point2 = ProjectivePoint::hash_to_projective_point(b"test_point_2").unwrap();
let msg1 = b"test_point_1".to_vec();
let base_input1 = BaseFieldElement::from(msg1.as_slice());
let msg2 = b"test_point_2".to_vec();
let base_input2 = BaseFieldElement::from(msg2.as_slice());
let point1 = ProjectivePoint::hash_to_projective_point(&[base_input1]).unwrap();
let point2 = ProjectivePoint::hash_to_projective_point(&[base_input2]).unwrap();

let left = scalar * (point1 + point2);
let right = (scalar * point1) + (scalar * point2);
Expand All @@ -368,7 +374,9 @@ mod tests {
let mut rng = ChaCha20Rng::from_seed([6u8; 32]);
let scalar1 = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let scalar2 = ScalarFieldElement::new_random_nonzero_scalar(&mut rng).unwrap();
let point = ProjectivePoint::hash_to_projective_point(b"test_point").unwrap();
let msg = b"test_point";
let base_input = BaseFieldElement::from(msg.as_slice());
let point = ProjectivePoint::hash_to_projective_point(&[base_input]).unwrap();

let combined_scalar = scalar1 * scalar2;
let left = combined_scalar * point;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,28 @@ impl BaseFieldElement {
}
}

#[cfg(any(test, feature = "benchmark-internals"))]
// Only uses the 32 first bytes of the input or pads it if necessary
// To use only in tests or benchmarks
impl From<&[u8]> for BaseFieldElement {
fn from(value: &[u8]) -> Self {
let bytes: Vec<u8> = if value.len() < 32 {
let mut v = vec![0u8; value.len()];
v.copy_from_slice(value);
v.resize(32, 0);
v
} else {
value.to_vec()
};
BaseFieldElement(JubjubBase::from_raw([
u64::from_le_bytes(bytes[0..8].try_into().unwrap()),
u64::from_le_bytes(bytes[8..16].try_into().unwrap()),
u64::from_le_bytes(bytes[16..24].try_into().unwrap()),
u64::from_le_bytes(bytes[24..32].try_into().unwrap()),
]))
}
}

impl Add for BaseFieldElement {
type Output = BaseFieldElement;

Expand Down
19 changes: 11 additions & 8 deletions mithril-stm/src/signature_scheme/unique_schnorr_signature/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ mod signing_key;
mod verification_key;

pub use error::*;
pub use jubjub::BaseFieldElement;
pub(crate) use jubjub::*;
pub use signature::*;
pub use signing_key::*;
Expand All @@ -25,8 +26,8 @@ mod tests {
use rand_core::SeedableRng;

use crate::signature_scheme::{
PrimeOrderProjectivePoint, ScalarFieldElement, SchnorrSigningKey, SchnorrVerificationKey,
UniqueSchnorrSignature, UniqueSchnorrSignatureError,
BaseFieldElement, PrimeOrderProjectivePoint, ScalarFieldElement, SchnorrSigningKey,
SchnorrVerificationKey, UniqueSchnorrSignature, UniqueSchnorrSignatureError,
};

proptest! {
Expand Down Expand Up @@ -60,15 +61,15 @@ mod tests {
let sk_result = SchnorrSigningKey::generate(&mut ChaCha20Rng::from_seed(seed));
assert!(sk_result.is_ok(), "Signing key generation failed");
let sk = sk_result.unwrap();

let vk = SchnorrVerificationKey::new_from_signing_key(sk.clone()).unwrap();
let base_input = BaseFieldElement::from(msg.as_slice());

let sig_result = sk.sign(&msg, &mut ChaCha20Rng::from_seed(seed));
let sig_result = sk.sign(&[base_input], &mut ChaCha20Rng::from_seed(seed));
assert!(sig_result.is_ok(), "Signature generation failed");

let sig = sig_result.unwrap();

assert!(sig.verify(&msg, &vk).is_ok(), "Verification failed.");
assert!(sig.verify(&[base_input], &vk).is_ok(), "Verification failed.");
}

#[test]
Expand All @@ -77,9 +78,10 @@ mod tests {
let sk1 = SchnorrSigningKey::generate(&mut rng).unwrap();
let vk1 = SchnorrVerificationKey::new_from_signing_key(sk1).unwrap();
let sk2 = SchnorrSigningKey::generate(&mut rng).unwrap();
let fake_sig = sk2.sign(&msg, &mut rng).unwrap();
let base_input = BaseFieldElement::from(msg.as_slice());
let fake_sig = sk2.sign(&[base_input], &mut rng).unwrap();

let error = fake_sig.verify(&msg, &vk1).expect_err("Fake signature should not be verified");
let error = fake_sig.verify(&[base_input], &vk1).expect_err("Fake signature should not be verified");

assert!(
matches!(
Expand Down Expand Up @@ -163,7 +165,8 @@ mod tests {
fn signature_to_from_bytes(msg in prop::collection::vec(any::<u8>(), 1..128), seed in any::<[u8;32]>()) {
let mut rng = ChaCha20Rng::from_seed(seed);
let sk = SchnorrSigningKey::generate(&mut rng).unwrap();
let signature = sk.sign(&msg, &mut ChaCha20Rng::from_seed(seed)).unwrap();
let base_input = BaseFieldElement::from(msg.as_slice());
let signature = sk.sign(&[base_input], &mut ChaCha20Rng::from_seed(seed)).unwrap();
let signature_bytes = signature.to_bytes();

// Valid conversion
Expand Down
Loading