Skip to content

Commit dc4d84a

Browse files
fix(ecdsa): match upstream behavior to error on overflowing add (#1742)
In the case `Is_x_reduced()` is true, the addition of `C::ORDER` to `r` is done in `Uint` and the upstream `ecdsa` crate errors if the addition overflow `Uint`'s fixed size. To match upstream behavior, we switch back to using the same `Uint` logic. Also took some additional care in handling conversions when the `Scalar<C>` num bytes differs from `Coordinate<C>` num bytes. Added a synthetic test case that failed before and now passes.
1 parent 63d21e8 commit dc4d84a

File tree

10 files changed

+216
-47
lines changed

10 files changed

+216
-47
lines changed

Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

extensions/ecc/guest/src/ecdsa.rs

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use alloc::vec::Vec;
2-
use core::ops::{Add, AddAssign, Mul};
2+
use core::ops::{Add, Mul};
33

44
use ecdsa_core::{
55
self,
@@ -12,10 +12,11 @@ use ecdsa_core::{
1212
EncodedPoint, Error, RecoveryId, Result, Signature, SignatureSize,
1313
};
1414
use elliptic_curve::{
15-
generic_array::ArrayLength,
15+
bigint::CheckedAdd,
16+
generic_array::{typenum::Unsigned, ArrayLength},
1617
sec1::{FromEncodedPoint, ModulusSize, Tag, ToEncodedPoint},
1718
subtle::{Choice, ConditionallySelectable, CtOption},
18-
CurveArithmetic, FieldBytesSize, PrimeCurve,
19+
CurveArithmetic, FieldBytes, FieldBytesEncoding, FieldBytesSize, PrimeCurve,
1920
};
2021
use openvm_algebra_guest::{DivUnsafe, IntMod, Reduce};
2122

@@ -378,12 +379,14 @@ where
378379
Coordinate<C>: IntMod,
379380
C::Scalar: IntMod + Reduce,
380381
{
382+
/// ## Assumption
383+
/// To use this implementation, the `Signature<C>`, `Coordinate<C>`, and `FieldBytes<C>` should
384+
/// all be encoded in big endian bytes. The implementation also assumes that
385+
/// `Scalar::<C>::NUM_LIMBS <= FieldBytesSize::<C>::USIZE <= Coordinate::<C>::NUM_LIMBS`.
386+
///
381387
/// Ref: <https://github.com/RustCrypto/signatures/blob/85c984bcc9927c2ce70c7e15cbfe9c6936dd3521/ecdsa/src/recovery.rs#L297>
382388
///
383389
/// Recovery does not require additional signature verification: <https://github.com/RustCrypto/signatures/pull/831>
384-
///
385-
/// ## Panics
386-
/// If the signature is invalid or public key cannot be recovered from the given input.
387390
#[allow(non_snake_case)]
388391
pub fn recover_from_prehash_noverify(
389392
prehash: &[u8],
@@ -411,16 +414,37 @@ where
411414
}
412415

413416
// Perf: don't use bits2field from ::ecdsa
414-
let z = Scalar::<C>::from_be_bytes(bits2field::<C>(prehash).unwrap().as_ref());
417+
let prehash_bytes = bits2field::<C>(prehash)?;
418+
// If prehash is longer than Scalar::NUM_LIMBS, take leftmost bytes
419+
let trim = prehash_bytes.len().saturating_sub(Scalar::<C>::NUM_LIMBS);
420+
// from_be_bytes still works if len < Scalar::NUM_LIMBS
421+
// we don't need to reduce because IntMod is up to modular equivalence
422+
let z = Scalar::<C>::from_be_bytes(&prehash_bytes[..prehash_bytes.len() - trim]);
415423

416424
// `r` is in the Scalar field, we now possibly add C::ORDER to it to get `x`
417425
// in the Coordinate field.
418-
let mut x = Coordinate::<C>::from_le_bytes(r.as_le_bytes());
426+
// We take some extra care for the case when FieldBytesSize<C> may be larger than
427+
// Scalar::<C>::NUM_LIMBS.
428+
let mut r_bytes = {
429+
let mut r_bytes = FieldBytes::<C>::default();
430+
assert!(FieldBytesSize::<C>::USIZE >= Scalar::<C>::NUM_LIMBS);
431+
let offset = r_bytes.len().saturating_sub(r_be.len());
432+
r_bytes[offset..].copy_from_slice(r_be);
433+
r_bytes
434+
};
419435
if recovery_id.is_x_reduced() {
420-
// Copy from slice in case Coordinate has more bytes than Scalar
421-
let order = Coordinate::<C>::from_le_bytes(Scalar::<C>::MODULUS.as_ref());
422-
x.add_assign(order);
436+
match Option::<C::Uint>::from(
437+
C::Uint::decode_field_bytes(&r_bytes).checked_add(&C::ORDER),
438+
) {
439+
Some(restored) => r_bytes = restored.encode_field_bytes(),
440+
// No reduction should happen here if r was reduced
441+
None => {
442+
return Err(Error::new());
443+
}
444+
};
423445
}
446+
assert!(FieldBytesSize::<C>::USIZE <= Coordinate::<C>::NUM_LIMBS);
447+
let x = Coordinate::<C>::from_be_bytes(&r_bytes);
424448
let rec_id = recovery_id.to_byte();
425449
// The point R decompressed from x-coordinate `r`
426450
let R: C::Point = FromCompressed::decompress(x, &rec_id).ok_or_else(Error::new)?;
@@ -463,8 +487,12 @@ where
463487
}
464488

465489
// Perf: don't use bits2field from ::ecdsa
466-
let z =
467-
<C as IntrinsicCurve>::Scalar::from_be_bytes(bits2field::<C>(prehash).unwrap().as_ref());
490+
let prehash_bytes = bits2field::<C>(prehash)?;
491+
// If prehash is longer than Scalar::NUM_LIMBS, take leftmost bytes
492+
let trim = prehash_bytes.len().saturating_sub(Scalar::<C>::NUM_LIMBS);
493+
// from_be_bytes still works if len < Scalar::NUM_LIMBS
494+
// we don't need to reduce because IntMod is up to modular equivalence
495+
let z = Scalar::<C>::from_be_bytes(&prehash_bytes[..prehash_bytes.len() - trim]);
468496

469497
let u1 = z.div_unsafe(&s);
470498
let u2 = (&r).div_unsafe(&s);

extensions/ecc/tests/Cargo.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,14 +11,15 @@ repository.workspace = true
1111
openvm-stark-sdk.workspace = true
1212
openvm-circuit = { workspace = true, features = ["test-utils"] }
1313
openvm-transpiler.workspace = true
14-
openvm-algebra-circuit.workspace = true
1514
openvm-algebra-transpiler.workspace = true
1615
openvm-ecc-transpiler.workspace = true
1716
openvm-ecc-circuit.workspace = true
1817
openvm-rv32im-transpiler.workspace = true
19-
openvm-keccak256-transpiler.workspace = true
2018
openvm-toolchain-tests = { path = "../../../crates/toolchain/tests" }
2119
openvm-sdk.workspace = true
20+
serde.workspace = true
21+
serde_with.workspace = true
22+
toml.workspace = true
2223
eyre.workspace = true
2324
hex-literal.workspace = true
2425
num-bigint.workspace = true

extensions/ecc/tests/programs/Cargo.toml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,13 @@ serde = { version = "1.0", default-features = false, features = [
2727
"alloc",
2828
"derive",
2929
] }
30+
serde_with = { version = "3.13.0", default-features = false, features = [
31+
"alloc",
32+
"macros",
33+
] }
3034
hex = { version = "0.4.3", default-features = false, features = ["alloc"] }
3135
hex-literal = { version = "0.4.1", default-features = false }
36+
ecdsa-core = { version = "0.16.9", package = "ecdsa", default-features = false }
3237

3338
[target.'cfg(not(target_os = "zkvm"))'.dependencies]
3439
num-bigint = "0.4.6"
@@ -64,6 +69,10 @@ required-features = ["k256"]
6469
name = "ecdsa"
6570
required-features = ["k256"]
6671

72+
[[example]]
73+
name = "ecdsa_recover"
74+
required-features = ["p256"]
75+
6776
[[example]]
6877
name = "invalid_setup"
6978
required-features = ["k256", "p256"]
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
#![cfg_attr(not(feature = "std"), no_main)]
2+
#![cfg_attr(not(feature = "std"), no_std)]
3+
4+
extern crate alloc;
5+
6+
use alloc::vec::Vec;
7+
8+
use ecdsa_core::RecoveryId;
9+
use openvm::io::read;
10+
#[allow(unused_imports)]
11+
use openvm_p256::{
12+
ecdsa::{Signature, VerifyingKey},
13+
P256Coord, P256Point,
14+
};
15+
use serde::{Deserialize, Serialize};
16+
use serde_with::{serde_as, Bytes};
17+
18+
openvm::entry!(main);
19+
20+
openvm::init!("openvm_init_ecdsa_recover_p256.rs");
21+
22+
/// Signature recovery test vectors
23+
#[repr(C)]
24+
#[serde_as]
25+
#[derive(Serialize, Deserialize)]
26+
struct RecoveryTestVector {
27+
#[serde_as(as = "Bytes")]
28+
pk: [u8; 33],
29+
#[serde_as(as = "Bytes")]
30+
msg: [u8; 32],
31+
#[serde_as(as = "Bytes")]
32+
sig: [u8; 64],
33+
recid: u8,
34+
ok: bool,
35+
}
36+
37+
pub fn main() {
38+
let test_vectors: Vec<RecoveryTestVector> = read();
39+
for vector in test_vectors {
40+
let sig = match Signature::try_from(vector.sig.as_slice()) {
41+
Ok(_v) => _v,
42+
Err(_) => {
43+
assert_eq!(vector.ok, false);
44+
continue;
45+
}
46+
};
47+
let recid = match RecoveryId::from_byte(vector.recid) {
48+
Some(_v) => _v,
49+
None => {
50+
assert_eq!(vector.ok, false);
51+
continue;
52+
}
53+
};
54+
let _ = match VerifyingKey::recover_from_prehash(&vector.msg, &sig, recid) {
55+
Ok(_v) => _v,
56+
Err(_) => {
57+
openvm::io::println("Recovery failed");
58+
assert_eq!(vector.ok, false);
59+
continue;
60+
}
61+
};
62+
// If reached here, recovery succeeded
63+
assert_eq!(vector.ok, true);
64+
}
65+
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
// This file is automatically generated by cargo openvm. Do not rename or edit.
2+
openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" }
3+
openvm_ecc_guest::sw_macros::sw_init! { P256Point }
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[app_vm_config.rv32i]
2+
[app_vm_config.rv32m]
3+
[app_vm_config.io]
4+
[app_vm_config.keccak]
5+
6+
[app_vm_config.modular]
7+
supported_moduli = [
8+
"115792089237316195423570985008687907853269984665640564039457584007908834671663",
9+
"115792089237316195423570985008687907852837564279074904382605163141518161494337",
10+
]
11+
12+
[[app_vm_config.ecc.supported_curves]]
13+
struct_name = "Secp256k1Point"
14+
modulus = "115792089237316195423570985008687907853269984665640564039457584007908834671663"
15+
scalar = "115792089237316195423570985008687907852837564279074904382605163141518161494337"
16+
a = "0"
17+
b = "7"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
[app_vm_config.rv32i]
2+
[app_vm_config.rv32m]
3+
[app_vm_config.io]
4+
[app_vm_config.modular]
5+
supported_moduli = [
6+
"115792089210356248762697446949407573530086143415290314195533631308867097853951",
7+
"115792089210356248762697446949407573529996955224135760342422259061068512044369",
8+
]
9+
10+
[[app_vm_config.ecc.supported_curves]]
11+
struct_name = "P256Point"
12+
modulus = "115792089210356248762697446949407573530086143415290314195533631308867097853951"
13+
scalar = "115792089210356248762697446949407573529996955224135760342422259061068512044369"
14+
a = "115792089210356248762697446949407573530086143415290314195533631308867097853948"
15+
b = "41058363725152142129326129780047268409114441015993725554835256314039467401291"

extensions/ecc/tests/src/lib.rs

Lines changed: 33 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,34 @@
1+
mod test_vectors;
2+
13
#[cfg(test)]
24
mod tests {
35
use core::str::FromStr;
46

57
use eyre::Result;
68
use hex_literal::hex;
79
use num_bigint::BigUint;
8-
use openvm_algebra_circuit::ModularExtension;
910
use openvm_algebra_transpiler::ModularTranspilerExtension;
1011
use openvm_circuit::{
11-
arch::{instructions::exe::VmExe, SystemConfig},
12+
arch::instructions::exe::VmExe,
1213
utils::{air_test, air_test_with_min_segments},
1314
};
14-
use openvm_ecc_circuit::{
15-
CurveConfig, Rv32WeierstrassConfig, WeierstrassExtension, P256_CONFIG, SECP256K1_CONFIG,
16-
};
15+
use openvm_ecc_circuit::{CurveConfig, Rv32WeierstrassConfig, P256_CONFIG, SECP256K1_CONFIG};
1716
use openvm_ecc_transpiler::EccTranspilerExtension;
18-
use openvm_keccak256_transpiler::Keccak256TranspilerExtension;
1917
use openvm_rv32im_transpiler::{
2018
Rv32ITranspilerExtension, Rv32IoTranspilerExtension, Rv32MTranspilerExtension,
2119
};
22-
use openvm_sdk::config::SdkVmConfig;
20+
use openvm_sdk::{
21+
config::{AppConfig, SdkVmConfig},
22+
StdIn,
23+
};
2324
use openvm_stark_backend::p3_field::FieldAlgebra;
2425
use openvm_stark_sdk::{openvm_stark_backend, p3_baby_bear::BabyBear};
2526
use openvm_toolchain_tests::{
2627
build_example_program_at_path_with_features, get_programs_dir, NoInitFile,
2728
};
2829
use openvm_transpiler::{transpiler::Transpiler, FromElf};
30+
31+
use crate::test_vectors::P256_RECOVERY_TEST_VECTORS;
2932
type F = BabyBear;
3033

3134
#[test]
@@ -163,39 +166,39 @@ mod tests {
163166

164167
#[test]
165168
fn test_ecdsa() -> Result<()> {
166-
let config = SdkVmConfig::builder()
167-
.system(SystemConfig::default().with_continuations().into())
168-
.rv32i(Default::default())
169-
.rv32m(Default::default())
170-
.io(Default::default())
171-
.modular(ModularExtension::new(vec![
172-
SECP256K1_CONFIG.modulus.clone(),
173-
SECP256K1_CONFIG.scalar.clone(),
174-
]))
175-
.keccak(Default::default())
176-
.ecc(WeierstrassExtension::new(vec![SECP256K1_CONFIG.clone()]))
177-
.build();
178-
169+
let config = toml::from_str::<AppConfig<SdkVmConfig>>(include_str!(
170+
"../programs/openvm_k256_keccak.toml"
171+
))?
172+
.app_vm_config;
179173
let elf = build_example_program_at_path_with_features(
180174
get_programs_dir!(),
181175
"ecdsa",
182176
["k256"],
183177
&config,
184178
)?;
185-
let openvm_exe = VmExe::from_elf(
186-
elf,
187-
Transpiler::<F>::default()
188-
.with_extension(Rv32ITranspilerExtension)
189-
.with_extension(Rv32MTranspilerExtension)
190-
.with_extension(Rv32IoTranspilerExtension)
191-
.with_extension(Keccak256TranspilerExtension)
192-
.with_extension(EccTranspilerExtension)
193-
.with_extension(ModularTranspilerExtension),
194-
)?;
179+
let openvm_exe = VmExe::from_elf(elf, config.transpiler())?;
195180
air_test(config, openvm_exe);
196181
Ok(())
197182
}
198183

184+
#[test]
185+
fn test_p256_ecdsa_recover() -> Result<()> {
186+
let config =
187+
toml::from_str::<AppConfig<SdkVmConfig>>(include_str!("../programs/openvm_p256.toml"))?
188+
.app_vm_config;
189+
let elf = build_example_program_at_path_with_features(
190+
get_programs_dir!(),
191+
"ecdsa_recover",
192+
["p256"],
193+
&config,
194+
)?;
195+
let openvm_exe = VmExe::from_elf(elf, config.transpiler())?;
196+
let mut input = StdIn::default();
197+
input.write(&P256_RECOVERY_TEST_VECTORS.to_vec());
198+
air_test_with_min_segments(config, openvm_exe, input, 1);
199+
Ok(())
200+
}
201+
199202
#[test]
200203
#[should_panic]
201204
fn test_invalid_setup() {
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
use hex_literal::hex;
2+
use serde::{Deserialize, Serialize};
3+
use serde_with::{serde_as, Bytes};
4+
5+
#[repr(C)]
6+
#[serde_as]
7+
#[derive(Clone, Debug, Serialize, Deserialize)]
8+
pub struct RecoveryTestVector {
9+
#[serde_as(as = "Bytes")]
10+
pub pk: [u8; 33],
11+
#[serde_as(as = "Bytes")]
12+
pub msg: [u8; 32],
13+
#[serde_as(as = "Bytes")]
14+
pub sig: [u8; 64],
15+
pub recid: u8,
16+
pub ok: bool,
17+
}
18+
19+
#[allow(dead_code)]
20+
pub const P256_RECOVERY_TEST_VECTORS: &[RecoveryTestVector] = &[RecoveryTestVector {
21+
pk: hex!("020000000000000000000000000000000000000000000000000000000000000000"),
22+
msg: hex!("00000000000000000000FFFFFFFF03030BFFFFFFFFFF030BFFFFFFFFFFFFF8FC"),
23+
sig: hex!("00000000ffffffff00000000000000004319055258e8617b0c46353d039cdaaf0000000000000000000000000000000000000000000000000000000000000001"
24+
),
25+
recid: 2,
26+
ok: false,
27+
}];

0 commit comments

Comments
 (0)