@@ -5,7 +5,8 @@ use proc_macro::TokenStream;
5
5
use quote:: format_ident;
6
6
use syn:: {
7
7
parse:: { Parse , ParseStream } ,
8
- parse_macro_input, Expr , ExprPath , Path , Token ,
8
+ parse_macro_input, ExprPath , Token ,
9
+ LitStr ,
9
10
} ;
10
11
11
12
/// This macro generates the code to setup the elliptic curve for a given modular type. Also it
@@ -28,9 +29,8 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
28
29
let span = proc_macro:: Span :: call_site ( ) ;
29
30
30
31
for item in items. into_iter ( ) {
31
- let struct_name = item. name . to_string ( ) ;
32
- let struct_name = syn:: Ident :: new ( & struct_name, span. into ( ) ) ;
33
- let struct_path: syn:: Path = syn:: parse_quote!( #struct_name) ;
32
+ let struct_name_str = item. name . to_string ( ) ;
33
+ let struct_name = syn:: Ident :: new ( & struct_name_str, span. into ( ) ) ;
34
34
let mut intmod_type: Option < syn:: Path > = None ;
35
35
let mut const_a: Option < syn:: Expr > = None ;
36
36
let mut const_b: Option < syn:: Expr > = None ;
@@ -74,12 +74,7 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
74
74
& format!(
75
75
"{}_{}" ,
76
76
stringify!( $name) ,
77
- struct_path
78
- . segments
79
- . iter( )
80
- . map( |x| x. ident. to_string( ) )
81
- . collect:: <Vec <_>>( )
82
- . join( "_" )
77
+ struct_name_str
83
78
) ,
84
79
span. into( ) ,
85
80
) ;
@@ -89,13 +84,13 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
89
84
create_extern_func ! ( sw_double_extern_func) ;
90
85
create_extern_func ! ( sw_setup_extern_func) ;
91
86
92
- let group_ops_mod_name = format_ident ! ( "{}_ops" , struct_name . to_string ( ) . to_lowercase( ) ) ;
87
+ let group_ops_mod_name = format_ident ! ( "{}_ops" , struct_name_str . to_lowercase( ) ) ;
93
88
94
89
let result = TokenStream :: from ( quote:: quote_spanned! { span. into( ) =>
95
90
extern "C" {
96
91
fn #sw_add_ne_extern_func( rd: usize , rs1: usize , rs2: usize ) ;
97
92
fn #sw_double_extern_func( rd: usize , rs1: usize ) ;
98
- fn #sw_setup_extern_func( ) ;
93
+ fn #sw_setup_extern_func( uninit : * mut core :: ffi :: c_void , p1 : * const u8 , p2 : * const u8 ) ;
99
94
}
100
95
101
96
#[ derive( Eq , PartialEq , Clone , Debug , serde:: Serialize , serde:: Deserialize ) ]
@@ -196,8 +191,21 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
196
191
#[ cfg( target_os = "zkvm" ) ]
197
192
fn set_up_once( ) {
198
193
static is_setup: :: openvm_ecc_guest:: once_cell:: race:: OnceBool = :: openvm_ecc_guest:: once_cell:: race:: OnceBool :: new( ) ;
194
+
199
195
is_setup. get_or_init( || {
200
- unsafe { #sw_setup_extern_func( ) ; }
196
+ // p1 is (x1, y1), and x1 must be the modulus.
197
+ // y1 can be anything for SetupEcAdd, but must equal `a` for SetupEcDouble
198
+ let modulus_bytes = <<Self as openvm_ecc_guest:: weierstrass:: WeierstrassPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: MODULUS ;
199
+ let mut one = [ 0u8 ; <<Self as openvm_ecc_guest:: weierstrass:: WeierstrassPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: NUM_LIMBS ] ;
200
+ one[ 0 ] = 1 ;
201
+ let curve_a_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <#struct_name as openvm_ecc_guest:: weierstrass:: WeierstrassPoint >:: CURVE_A ) ;
202
+ // p1 should be (p, a)
203
+ let p1 = [ modulus_bytes. as_ref( ) , curve_a_bytes. as_ref( ) ] . concat( ) ;
204
+ // (EcAdd only) p2 is (x2, y2), and x1 - x2 has to be non-zero to avoid division over zero in add.
205
+ let p2 = [ one. as_ref( ) , one. as_ref( ) ] . concat( ) ;
206
+ let mut uninit: core:: mem:: MaybeUninit <[ Self ; 2 ] > = core:: mem:: MaybeUninit :: uninit( ) ;
207
+
208
+ unsafe { #sw_setup_extern_func( uninit. as_mut_ptr( ) as * mut core:: ffi:: c_void, p1. as_ptr( ) , p2. as_ptr( ) ) ; }
201
209
<#intmod_type as openvm_algebra_guest:: IntMod >:: set_up_once( ) ;
202
210
true
203
211
} ) ;
@@ -410,23 +418,14 @@ pub fn sw_declare(input: TokenStream) -> TokenStream {
410
418
}
411
419
412
420
struct SwDefine {
413
- items : Vec < Path > ,
421
+ items : Vec < String > ,
414
422
}
415
423
416
424
impl Parse for SwDefine {
417
425
fn parse ( input : ParseStream ) -> syn:: Result < Self > {
418
- let items = input. parse_terminated ( <Expr as Parse >:: parse, Token ! [ , ] ) ?;
426
+ let items = input. parse_terminated ( <LitStr as Parse >:: parse, Token ! [ , ] ) ?;
419
427
Ok ( Self {
420
- items : items
421
- . into_iter ( )
422
- . map ( |e| {
423
- if let Expr :: Path ( p) = e {
424
- p. path
425
- } else {
426
- panic ! ( "expected path" ) ;
427
- }
428
- } )
429
- . collect ( ) ,
428
+ items : items. into_iter ( ) . map ( |e| e. value ( ) ) . collect ( )
430
429
} )
431
430
}
432
431
}
@@ -439,19 +438,15 @@ pub fn sw_init(input: TokenStream) -> TokenStream {
439
438
440
439
let span = proc_macro:: Span :: call_site ( ) ;
441
440
442
- for ( ec_idx, item) in items. into_iter ( ) . enumerate ( ) {
443
- let str_path = item
444
- . segments
445
- . iter ( )
446
- . map ( |x| x. ident . to_string ( ) )
447
- . collect :: < Vec < _ > > ( )
448
- . join ( "_" ) ;
441
+ for ( ec_idx, struct_id) in items. into_iter ( ) . enumerate ( ) {
442
+ // Unique identifier shared by sw_define! and sw_init! used for naming the extern funcs.
443
+ // Currently it's just the struct type name.
449
444
let add_ne_extern_func =
450
- syn:: Ident :: new ( & format ! ( "sw_add_ne_extern_func_{}" , str_path ) , span. into ( ) ) ;
445
+ syn:: Ident :: new ( & format ! ( "sw_add_ne_extern_func_{}" , struct_id ) , span. into ( ) ) ;
451
446
let double_extern_func =
452
- syn:: Ident :: new ( & format ! ( "sw_double_extern_func_{}" , str_path ) , span. into ( ) ) ;
447
+ syn:: Ident :: new ( & format ! ( "sw_double_extern_func_{}" , struct_id ) , span. into ( ) ) ;
453
448
let setup_extern_func =
454
- syn:: Ident :: new ( & format ! ( "sw_setup_extern_func_{}" , str_path ) , span. into ( ) ) ;
449
+ syn:: Ident :: new ( & format ! ( "sw_setup_extern_func_{}" , struct_id ) , span. into ( ) ) ;
455
450
456
451
externs. push ( quote:: quote_spanned! { span. into( ) =>
457
452
#[ no_mangle]
@@ -481,41 +476,31 @@ pub fn sw_init(input: TokenStream) -> TokenStream {
481
476
}
482
477
483
478
#[ no_mangle]
484
- extern "C" fn #setup_extern_func( ) {
479
+ extern "C" fn #setup_extern_func( uninit : * mut core :: ffi :: c_void , p1 : * const u8 , p2 : * const u8 ) {
485
480
#[ cfg( target_os = "zkvm" ) ]
486
481
{
487
- use super :: #item;
488
- // p1 is (x1, y1), and x1 must be the modulus.
489
- // y1 can be anything for SetupEcAdd, but must equal `a` for SetupEcDouble
490
- let modulus_bytes = <<#item as openvm_ecc_guest:: weierstrass:: WeierstrassPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: MODULUS ;
491
- let mut one = [ 0u8 ; <<#item as openvm_ecc_guest:: weierstrass:: WeierstrassPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: NUM_LIMBS ] ;
492
- one[ 0 ] = 1 ;
493
- let curve_a_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <#item as openvm_ecc_guest:: weierstrass:: WeierstrassPoint >:: CURVE_A ) ;
494
- // p1 should be (p, a)
495
- let p1 = [ modulus_bytes. as_ref( ) , curve_a_bytes. as_ref( ) ] . concat( ) ;
496
- // (EcAdd only) p2 is (x2, y2), and x1 - x2 has to be non-zero to avoid division over zero in add.
497
- let p2 = [ one. as_ref( ) , one. as_ref( ) ] . concat( ) ;
498
- let mut uninit: core:: mem:: MaybeUninit <[ #item; 2 ] > = core:: mem:: MaybeUninit :: uninit( ) ;
499
482
openvm:: platform:: custom_insn_r!(
500
483
opcode = :: openvm_ecc_guest:: OPCODE ,
501
484
funct3 = :: openvm_ecc_guest:: SW_FUNCT3 as usize ,
502
485
funct7 = :: openvm_ecc_guest:: SwBaseFunct7 :: SwSetup as usize
503
486
+ #ec_idx
504
487
* ( :: openvm_ecc_guest:: SwBaseFunct7 :: SHORT_WEIERSTRASS_MAX_KINDS as usize ) ,
505
- rd = In uninit. as_mut_ptr ( ) ,
506
- rs1 = In p1. as_ptr ( ) ,
507
- rs2 = In p2. as_ptr ( )
488
+ rd = In uninit,
489
+ rs1 = In p1,
490
+ rs2 = In p2
508
491
) ;
509
492
openvm:: platform:: custom_insn_r!(
510
493
opcode = :: openvm_ecc_guest:: OPCODE ,
511
494
funct3 = :: openvm_ecc_guest:: SW_FUNCT3 as usize ,
512
495
funct7 = :: openvm_ecc_guest:: SwBaseFunct7 :: SwSetup as usize
513
496
+ #ec_idx
514
497
* ( :: openvm_ecc_guest:: SwBaseFunct7 :: SHORT_WEIERSTRASS_MAX_KINDS as usize ) ,
515
- rd = In uninit. as_mut_ptr ( ) ,
516
- rs1 = In p1. as_ptr ( ) ,
498
+ rd = In uninit,
499
+ rs1 = In p1,
517
500
rs2 = Const "x0" // will be parsed as 0 and therefore transpiled to SETUP_EC_DOUBLE
518
501
) ;
502
+
503
+
519
504
}
520
505
}
521
506
} ) ;
0 commit comments