@@ -5,7 +5,7 @@ use proc_macro::TokenStream;
5
5
use quote:: format_ident;
6
6
use syn:: {
7
7
parse:: { Parse , ParseStream } ,
8
- parse_macro_input, Expr , ExprPath , Path , Token ,
8
+ parse_macro_input, ExprPath , LitStr , Token ,
9
9
} ;
10
10
11
11
/// This macro generates the code to setup a Twisted Edwards elliptic curve for a given modular
@@ -86,7 +86,7 @@ pub fn te_declare(input: TokenStream) -> TokenStream {
86
86
let result = TokenStream :: from ( quote:: quote_spanned! { span. into( ) =>
87
87
extern "C" {
88
88
fn #te_add_extern_func( rd: usize , rs1: usize , rs2: usize ) ;
89
- fn #te_setup_extern_func( ) ;
89
+ fn #te_setup_extern_func( uninit : * mut core :: ffi :: c_void , p1 : * const u8 , p2 : * const u8 ) ;
90
90
}
91
91
92
92
#[ derive( Eq , PartialEq , Clone , Debug , serde:: Serialize , serde:: Deserialize ) ]
@@ -143,7 +143,15 @@ pub fn te_declare(input: TokenStream) -> TokenStream {
143
143
fn set_up_once( ) {
144
144
static is_setup: :: openvm_ecc_guest:: once_cell:: race:: OnceBool = :: openvm_ecc_guest:: once_cell:: race:: OnceBool :: new( ) ;
145
145
is_setup. get_or_init( || {
146
- unsafe { #te_setup_extern_func( ) ; }
146
+ let modulus_bytes = <<Self as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: MODULUS ;
147
+ let mut zero = [ 0u8 ; <<Self as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: NUM_LIMBS ] ;
148
+ let curve_a_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <Self as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: CURVE_A ) ;
149
+ let curve_d_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <Self as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: CURVE_D ) ;
150
+ let p1 = [ modulus_bytes. as_ref( ) , curve_a_bytes. as_ref( ) ] . concat( ) ;
151
+ let p2 = [ curve_d_bytes. as_ref( ) , zero. as_ref( ) ] . concat( ) ;
152
+ let mut uninit: core:: mem:: MaybeUninit <[ Self ; 2 ] > = core:: mem:: MaybeUninit :: uninit( ) ;
153
+
154
+ unsafe { #te_setup_extern_func( uninit. as_mut_ptr( ) as * mut core:: ffi:: c_void, p1. as_ptr( ) , p2. as_ptr( ) ) ; }
147
155
<#intmod_type as openvm_algebra_guest:: IntMod >:: set_up_once( ) ;
148
156
true
149
157
} ) ;
@@ -266,22 +274,16 @@ pub fn te_declare(input: TokenStream) -> TokenStream {
266
274
}
267
275
268
276
struct TeDefine {
269
- items : Vec < Path > ,
277
+ items : Vec < String > ,
270
278
}
271
279
272
280
impl Parse for TeDefine {
273
281
fn parse ( input : ParseStream ) -> syn:: Result < Self > {
274
- let items = input. parse_terminated ( <Expr as Parse >:: parse, Token ! [ , ] ) ?;
282
+ let items = input. parse_terminated ( <LitStr as Parse >:: parse, Token ! [ , ] ) ?;
275
283
Ok ( Self {
276
284
items : items
277
285
. into_iter ( )
278
- . map ( |e| {
279
- if let Expr :: Path ( p) = e {
280
- p. path
281
- } else {
282
- panic ! ( "expected path" ) ;
283
- }
284
- } )
286
+ . map ( |e| e. value ( ) )
285
287
. collect ( ) ,
286
288
} )
287
289
}
@@ -295,17 +297,11 @@ pub fn te_init(input: TokenStream) -> TokenStream {
295
297
296
298
let span = proc_macro:: Span :: call_site ( ) ;
297
299
298
- for ( ec_idx, item) in items. into_iter ( ) . enumerate ( ) {
299
- let str_path = item
300
- . segments
301
- . iter ( )
302
- . map ( |x| x. ident . to_string ( ) )
303
- . collect :: < Vec < _ > > ( )
304
- . join ( "_" ) ;
300
+ for ( ec_idx, struct_id) in items. into_iter ( ) . enumerate ( ) {
305
301
let add_extern_func =
306
- syn:: Ident :: new ( & format ! ( "te_add_extern_func_{}" , str_path ) , span. into ( ) ) ;
302
+ syn:: Ident :: new ( & format ! ( "te_add_extern_func_{}" , struct_id ) , span. into ( ) ) ;
307
303
let setup_extern_func =
308
- syn:: Ident :: new ( & format ! ( "te_setup_extern_func_{}" , str_path ) , span. into ( ) ) ;
304
+ syn:: Ident :: new ( & format ! ( "te_setup_extern_func_{}" , struct_id ) , span. into ( ) ) ;
309
305
externs. push ( quote:: quote_spanned! { span. into( ) =>
310
306
#[ no_mangle]
311
307
extern "C" fn #add_extern_func( rd: usize , rs1: usize , rs2: usize ) {
@@ -321,26 +317,19 @@ pub fn te_init(input: TokenStream) -> TokenStream {
321
317
}
322
318
323
319
#[ no_mangle]
324
- extern "C" fn #setup_extern_func( ) {
320
+ extern "C" fn #setup_extern_func( uninit : * mut core :: ffi :: c_void , p1 : * const u8 , p2 : * const u8 ) {
325
321
#[ cfg( target_os = "zkvm" ) ]
326
322
{
327
- use super :: #item;
328
- let modulus_bytes = <<#item as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: MODULUS ;
329
- let mut zero = [ 0u8 ; <<#item as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: Coordinate as openvm_algebra_guest:: IntMod >:: NUM_LIMBS ] ;
330
- let curve_a_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <#item as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: CURVE_A ) ;
331
- let curve_d_bytes = openvm_algebra_guest:: IntMod :: as_le_bytes( & <#item as openvm_ecc_guest:: edwards:: TwistedEdwardsPoint >:: CURVE_D ) ;
332
- let p1 = [ modulus_bytes. as_ref( ) , curve_a_bytes. as_ref( ) ] . concat( ) ;
333
- let p2 = [ curve_d_bytes. as_ref( ) , zero. as_ref( ) ] . concat( ) ;
334
- let mut uninit: core:: mem:: MaybeUninit <[ #item; 2 ] > = core:: mem:: MaybeUninit :: uninit( ) ;
323
+
335
324
openvm:: platform:: custom_insn_r!(
336
325
opcode = :: openvm_ecc_guest:: TE_OPCODE ,
337
326
funct3 = :: openvm_ecc_guest:: TE_FUNCT3 as usize ,
338
327
funct7 = :: openvm_ecc_guest:: TeBaseFunct7 :: TeSetup as usize
339
328
+ #ec_idx
340
329
* ( :: openvm_ecc_guest:: TeBaseFunct7 :: TWISTED_EDWARDS_MAX_KINDS as usize ) ,
341
- rd = In uninit. as_mut_ptr ( ) ,
342
- rs1 = In p1. as_ptr ( ) ,
343
- rs2 = In p2. as_ptr ( ) ,
330
+ rd = In uninit,
331
+ rs1 = In p1,
332
+ rs2 = In p2,
344
333
) ;
345
334
}
346
335
}
0 commit comments