Skip to content

Commit a027803

Browse files
Avaneesh-axiomjonathanpwang
authored andcommitted
Parse struct name in sw_init as string
1 parent 94c0a9a commit a027803

25 files changed

+61
-84
lines changed

extensions/ecc/circuit/src/weierstrass_extension.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ impl WeierstrassExtension {
8484
let supported_curves = self
8585
.supported_curves
8686
.iter()
87-
.map(|curve_config| curve_config.struct_name.to_string())
87+
.map(|curve_config| format!("\"{}\"", curve_config.struct_name))
8888
.collect::<Vec<String>>()
8989
.join(", ");
9090

extensions/ecc/sw-macros/src/lib.rs

Lines changed: 38 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use proc_macro::TokenStream;
55
use quote::format_ident;
66
use syn::{
77
parse::{Parse, ParseStream},
8-
parse_macro_input, Expr, ExprPath, Path, Token,
8+
parse_macro_input, ExprPath, Token,
9+
LitStr,
910
};
1011

1112
/// This macro generates the code to setup the elliptic curve for a given modular type. Also it
@@ -28,9 +29,8 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
2829
let span = proc_macro::Span::call_site();
2930

3031
for item in items.into_iter() {
31-
let struct_name = item.name.to_string();
32-
let struct_name = syn::Ident::new(&struct_name, span.into());
33-
let struct_path: syn::Path = syn::parse_quote!(#struct_name);
32+
let struct_name_str = item.name.to_string();
33+
let struct_name = syn::Ident::new(&struct_name_str, span.into());
3434
let mut intmod_type: Option<syn::Path> = None;
3535
let mut const_a: Option<syn::Expr> = None;
3636
let mut const_b: Option<syn::Expr> = None;
@@ -74,12 +74,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
7474
&format!(
7575
"{}_{}",
7676
stringify!($name),
77-
struct_path
78-
.segments
79-
.iter()
80-
.map(|x| x.ident.to_string())
81-
.collect::<Vec<_>>()
82-
.join("_")
77+
struct_name_str
8378
),
8479
span.into(),
8580
);
@@ -89,13 +84,13 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
8984
create_extern_func!(sw_double_extern_func);
9085
create_extern_func!(sw_setup_extern_func);
9186

92-
let group_ops_mod_name = format_ident!("{}_ops", struct_name.to_string().to_lowercase());
87+
let group_ops_mod_name = format_ident!("{}_ops", struct_name_str.to_lowercase());
9388

9489
let result = TokenStream::from(quote::quote_spanned! { span.into() =>
9590
extern "C" {
9691
fn #sw_add_ne_extern_func(rd: usize, rs1: usize, rs2: usize);
9792
fn #sw_double_extern_func(rd: usize, rs1: usize);
98-
fn #sw_setup_extern_func();
93+
fn #sw_setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8);
9994
}
10095

10196
#[derive(Eq, PartialEq, Clone, Debug, serde::Serialize, serde::Deserialize)]
@@ -196,8 +191,21 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
196191
#[cfg(target_os = "zkvm")]
197192
fn set_up_once() {
198193
static is_setup: ::openvm_ecc_guest::once_cell::race::OnceBool = ::openvm_ecc_guest::once_cell::race::OnceBool::new();
194+
199195
is_setup.get_or_init(|| {
200-
unsafe { #sw_setup_extern_func(); }
196+
// p1 is (x1, y1), and x1 must be the modulus.
197+
// y1 can be anything for SetupEcAdd, but must equal `a` for SetupEcDouble
198+
let modulus_bytes = <<Self as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::MODULUS;
199+
let mut one = [0u8; <<Self as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS];
200+
one[0] = 1;
201+
let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#struct_name as openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A);
202+
// p1 should be (p, a)
203+
let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat();
204+
// (EcAdd only) p2 is (x2, y2), and x1 - x2 has to be non-zero to avoid division over zero in add.
205+
let p2 = [one.as_ref(), one.as_ref()].concat();
206+
let mut uninit: core::mem::MaybeUninit<[Self; 2]> = core::mem::MaybeUninit::uninit();
207+
208+
unsafe { #sw_setup_extern_func(uninit.as_mut_ptr() as *mut core::ffi::c_void, p1.as_ptr(), p2.as_ptr()); }
201209
<#intmod_type as openvm_algebra_guest::IntMod>::set_up_once();
202210
true
203211
});
@@ -410,23 +418,14 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
410418
}
411419

412420
struct SwDefine {
413-
items: Vec<Path>,
421+
items: Vec<String>,
414422
}
415423

416424
impl Parse for SwDefine {
417425
fn parse(input: ParseStream) -> syn::Result<Self> {
418-
let items = input.parse_terminated(<Expr as Parse>::parse, Token![,])?;
426+
let items = input.parse_terminated(<LitStr as Parse>::parse, Token![,])?;
419427
Ok(Self {
420-
items: items
421-
.into_iter()
422-
.map(|e| {
423-
if let Expr::Path(p) = e {
424-
p.path
425-
} else {
426-
panic!("expected path");
427-
}
428-
})
429-
.collect(),
428+
items: items.into_iter().map(|e| e.value()).collect()
430429
})
431430
}
432431
}
@@ -439,19 +438,15 @@ pub fn sw_init(input: TokenStream) -> TokenStream {
439438

440439
let span = proc_macro::Span::call_site();
441440

442-
for (ec_idx, item) in items.into_iter().enumerate() {
443-
let str_path = item
444-
.segments
445-
.iter()
446-
.map(|x| x.ident.to_string())
447-
.collect::<Vec<_>>()
448-
.join("_");
441+
for (ec_idx, struct_id) in items.into_iter().enumerate() {
442+
// Unique identifier shared by sw_define! and sw_init! used for naming the extern funcs.
443+
// Currently it's just the struct type name.
449444
let add_ne_extern_func =
450-
syn::Ident::new(&format!("sw_add_ne_extern_func_{}", str_path), span.into());
445+
syn::Ident::new(&format!("sw_add_ne_extern_func_{}", struct_id), span.into());
451446
let double_extern_func =
452-
syn::Ident::new(&format!("sw_double_extern_func_{}", str_path), span.into());
447+
syn::Ident::new(&format!("sw_double_extern_func_{}", struct_id), span.into());
453448
let setup_extern_func =
454-
syn::Ident::new(&format!("sw_setup_extern_func_{}", str_path), span.into());
449+
syn::Ident::new(&format!("sw_setup_extern_func_{}", struct_id), span.into());
455450

456451
externs.push(quote::quote_spanned! { span.into() =>
457452
#[no_mangle]
@@ -481,41 +476,31 @@ pub fn sw_init(input: TokenStream) -> TokenStream {
481476
}
482477

483478
#[no_mangle]
484-
extern "C" fn #setup_extern_func() {
479+
extern "C" fn #setup_extern_func(uninit: *mut core::ffi::c_void, p1: *const u8, p2: *const u8) {
485480
#[cfg(target_os = "zkvm")]
486481
{
487-
use super::#item;
488-
// p1 is (x1, y1), and x1 must be the modulus.
489-
// y1 can be anything for SetupEcAdd, but must equal `a` for SetupEcDouble
490-
let modulus_bytes = <<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::MODULUS;
491-
let mut one = [0u8; <<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::Coordinate as openvm_algebra_guest::IntMod>::NUM_LIMBS];
492-
one[0] = 1;
493-
let curve_a_bytes = openvm_algebra_guest::IntMod::as_le_bytes(&<#item as openvm_ecc_guest::weierstrass::WeierstrassPoint>::CURVE_A);
494-
// p1 should be (p, a)
495-
let p1 = [modulus_bytes.as_ref(), curve_a_bytes.as_ref()].concat();
496-
// (EcAdd only) p2 is (x2, y2), and x1 - x2 has to be non-zero to avoid division over zero in add.
497-
let p2 = [one.as_ref(), one.as_ref()].concat();
498-
let mut uninit: core::mem::MaybeUninit<[#item; 2]> = core::mem::MaybeUninit::uninit();
499482
openvm::platform::custom_insn_r!(
500483
opcode = ::openvm_ecc_guest::OPCODE,
501484
funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
502485
funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
503486
+ #ec_idx
504487
* (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
505-
rd = In uninit.as_mut_ptr(),
506-
rs1 = In p1.as_ptr(),
507-
rs2 = In p2.as_ptr()
488+
rd = In uninit,
489+
rs1 = In p1,
490+
rs2 = In p2
508491
);
509492
openvm::platform::custom_insn_r!(
510493
opcode = ::openvm_ecc_guest::OPCODE,
511494
funct3 = ::openvm_ecc_guest::SW_FUNCT3 as usize,
512495
funct7 = ::openvm_ecc_guest::SwBaseFunct7::SwSetup as usize
513496
+ #ec_idx
514497
* (::openvm_ecc_guest::SwBaseFunct7::SHORT_WEIERSTRASS_MAX_KINDS as usize),
515-
rd = In uninit.as_mut_ptr(),
516-
rs1 = In p1.as_ptr(),
498+
rd = In uninit,
499+
rs1 = In p1,
517500
rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_EC_DOUBLE
518501
);
502+
503+
519504
}
520505
}
521506
});

extensions/ecc/tests/programs/examples/invalid_setup.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ openvm_algebra_moduli_macros::moduli_init! {
1414

1515
// the order of the curves here does not match the order in supported_curves
1616
openvm_ecc_sw_macros::sw_init! {
17-
P256Point,
18-
Secp256k1Point,
17+
"P256Point",
18+
"Secp256k1Point",
1919
}
2020

2121
openvm::entry!(main);
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
// This file is automatically generated by cargo openvm. Do not rename or edit.
22
openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "115792089237316195423570985008687907853269984665640564039457584007913129639501", "1000000007", "26959946667150639794667015087019630673557916260026308143510066298881", "26959946667150639794667015087019625940457807714424391721682722368061" }
3-
openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point, CurvePoint5mod8, CurvePoint1mod4 }
3+
openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point", "CurvePoint5mod8", "CurvePoint1mod4" }
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
// This file is automatically generated by cargo openvm. Do not rename or edit.
22
openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" }
3-
openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point }
3+
openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" }
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
// This file is automatically generated by cargo openvm. Do not rename or edit.
22
openvm_algebra_guest::moduli_macros::moduli_init! { "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" }
3-
openvm_ecc_guest::sw_macros::sw_init! { P256Point }
3+
openvm_ecc_guest::sw_macros::sw_init! { "P256Point" }
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
// This file is automatically generated by cargo openvm. Do not rename or edit.
22
openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337", "115792089210356248762697446949407573530086143415290314195533631308867097853951", "115792089210356248762697446949407573529996955224135760342422259061068512044369" }
3-
openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point, P256Point }
3+
openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point", "P256Point" }
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
// This file is automatically generated by cargo openvm. Do not rename or edit.
22
openvm_algebra_guest::moduli_macros::moduli_init! { "115792089237316195423570985008687907853269984665640564039457584007908834671663", "115792089237316195423570985008687907852837564279074904382605163141518161494337" }
3-
openvm_ecc_guest::sw_macros::sw_init! { Secp256k1Point }
3+
openvm_ecc_guest::sw_macros::sw_init! { "Secp256k1Point" }

guest-libs/k256/tests/programs/examples/add.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use openvm_k256::Secp256k1Point;
1010
mod test_vectors;
1111
use test_vectors::ADD_TEST_VECTORS;
1212

13-
openvm::init!("openvm_init_simple.rs");
13+
openvm::init!("openvm_init_add.rs");
1414

1515
openvm::entry!(main);
1616

guest-libs/k256/tests/programs/examples/mul.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use openvm_k256::Secp256k1Point;
1010
mod test_vectors;
1111
use test_vectors::{ADD_TEST_VECTORS, MUL_TEST_VECTORS};
1212

13-
openvm::init!("openvm_init_simple.rs");
13+
openvm::init!("openvm_init_mul.rs");
1414

1515
openvm::entry!(main);
1616

0 commit comments

Comments
 (0)