1+ /*
2+ * The `ColsRef` procedural macro is used in constraint generation to create column structs that
3+ * have dynamic sizes.
4+ *
5+ * Note: this macro was originally created for use in the SHA-2 VM extension, where we reuse the
6+ * same constraint generation code for three different circuits (SHA-256, SHA-512, and SHA-384).
7+ * See the [SHA-2 VM extension](openvm/extensions/sha2/circuit/src/sha2_chip/air.rs) for an
8+ * example of how to use the `ColsRef` macro to reuse constraint generation code over multiple
9+ * circuits.
10+ *
11+ * This macro can also be used in other situations where we want to derive Borrow<T> for &[u8],
12+ * for some complicated struct T.
13+ */
14+ mod utils;
15+
16+ use utils:: * ;
17+
118extern crate proc_macro;
219
320use itertools:: Itertools ;
@@ -169,7 +186,8 @@ fn make_struct(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_mac
169186 }
170187 }
171188
172- // returns number of cells in the struct (where each cell has type T)
189+ // Returns number of cells in the struct (where each cell has type T).
190+ // This method should only be called if the struct has no primitive types (i.e. for columns structs).
173191 pub const fn width<C : #config>( ) -> usize {
174192 0 #( + #length_exprs ) *
175193 }
@@ -227,7 +245,7 @@ fn make_from_mut(struct_info: StructInfo, config: &proc_macro2::Ident) -> proc_m
227245 & other. #ident
228246 }
229247 } else {
230- panic ! ( "Unsupported field type: {:?}" , f. ty) ;
248+ panic ! ( "Unsupported field type (in make_from_mut) : {:?}" , f. ty) ;
231249 }
232250 } )
233251 . collect_vec ( ) ;
@@ -346,8 +364,31 @@ fn get_const_cols_ref_fields(
346364 #slice_var
347365 } ,
348366 }
367+ } else if is_primitive_type ( & elem_type) {
368+ FieldInfo {
369+ ty : parse_quote ! {
370+ & ' a #elem_type
371+ } ,
372+ // Columns structs won't ever have primitive types, but this macro can be used on
373+ // other structs as well, to make it easy to borrow a struct from &[u8].
374+ // We just set length = 0 knowing that calling the width() method is undefined if
375+ // the struct has a primitive type.
376+ length_expr : quote ! {
377+ 0
378+ } ,
379+ prepare_subslice : quote ! {
380+ let ( #slice_var, slice) = slice. split_at( std:: mem:: size_of:: <#elem_type>( ) #( * #dim_exprs) * ) ;
381+ let #slice_var = ndarray:: #ndarray_ident:: from_shape( ( #( #dim_exprs) , * ) , #slice_var) . unwrap( ) ;
382+ } ,
383+ initializer : quote ! {
384+ #slice_var
385+ } ,
386+ }
349387 } else {
350- panic ! ( "Unsupported field type: {:?}" , f. ty) ;
388+ panic ! (
389+ "Unsupported field type (in get_const_cols_ref_fields): {:?}" ,
390+ f. ty
391+ ) ;
351392 }
352393 } else if derives_aligned_borrow {
353394 // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config)
@@ -405,7 +446,10 @@ fn get_const_cols_ref_fields(
405446 } ,
406447 }
407448 } else {
408- panic ! ( "Unsupported field type: {:?}" , f. ty) ;
449+ panic ! (
450+ "Unsupported field type (in get_mut_cols_ref_fields): {:?}" ,
451+ f. ty
452+ ) ;
409453 }
410454}
411455
@@ -485,8 +529,31 @@ fn get_mut_cols_ref_fields(
485529 #slice_var
486530 } ,
487531 }
532+ } else if is_primitive_type ( & elem_type) {
533+ FieldInfo {
534+ ty : parse_quote ! {
535+ & ' a mut #elem_type
536+ } ,
537+ // Columns structs won't ever have primitive types, but this macro can be used on
538+ // other structs as well, to make it easy to borrow a struct from &[u8].
539+ // We just set length = 0 knowing that calling the width() method is undefined if
540+ // the struct has a primitive type.
541+ length_expr : quote ! {
542+ 0
543+ } ,
544+ prepare_subslice : quote ! {
545+ let ( #slice_var, slice) = slice. split_at_mut( std:: mem:: size_of:: <#elem_type>( ) #( * #dim_exprs) * ) ;
546+ let #slice_var = ndarray:: #ndarray_ident:: from_shape( ( #( #dim_exprs) , * ) , #slice_var) . unwrap( ) ;
547+ } ,
548+ initializer : quote ! {
549+ #slice_var
550+ } ,
551+ }
488552 } else {
489- panic ! ( "Unsupported field type: {:?}" , f. ty) ;
553+ panic ! (
554+ "Unsupported field type (in get_mut_cols_ref_fields): {:?}" ,
555+ f. ty
556+ ) ;
490557 }
491558 } else if derives_aligned_borrow {
492559 // treat the field as a struct that derives AlignedBorrow (and doesn't depend on the config)
@@ -544,7 +611,10 @@ fn get_mut_cols_ref_fields(
544611 } ,
545612 }
546613 } else {
547- panic ! ( "Unsupported field type: {:?}" , f. ty) ;
614+ panic ! (
615+ "Unsupported field type (in get_mut_cols_ref_fields): {:?}" ,
616+ f. ty
617+ ) ;
548618 }
549619}
550620
@@ -556,7 +626,7 @@ fn is_columns_struct(ty: &syn::Type) -> bool {
556626 . path
557627 . segments
558628 . iter ( )
559- . last ( )
629+ . next_back ( )
560630 . map ( |s| s. ident . to_string ( ) . ends_with ( "Cols" ) )
561631 . unwrap_or ( false )
562632 } else {
@@ -576,7 +646,7 @@ fn get_const_cols_ref_type(
576646 }
577647
578648 if let syn:: Type :: Path ( type_path) = ty {
579- let s = type_path. path . segments . iter ( ) . last ( ) . unwrap ( ) ;
649+ let s = type_path. path . segments . iter ( ) . next_back ( ) . unwrap ( ) ;
580650 if s. ident . to_string ( ) . ends_with ( "Cols" ) {
581651 let const_cols_ref_ident = format_ident ! ( "{}Ref" , s. ident) ;
582652 let const_cols_ref_type = parse_quote ! {
@@ -602,7 +672,7 @@ fn get_mut_cols_ref_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> syn::
602672 }
603673
604674 if let syn:: Type :: Path ( type_path) = ty {
605- let s = type_path. path . segments . iter ( ) . last ( ) . unwrap ( ) ;
675+ let s = type_path. path . segments . iter ( ) . next_back ( ) . unwrap ( ) ;
606676 if s. ident . to_string ( ) . ends_with ( "Cols" ) {
607677 let mut_cols_ref_ident = format_ident ! ( "{}RefMut" , s. ident) ;
608678 let mut_cols_ref_type = parse_quote ! {
@@ -627,7 +697,7 @@ fn is_generic_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> bool {
627697 . path
628698 . segments
629699 . iter ( )
630- . last ( )
700+ . next_back ( )
631701 . map ( |s| s. ident == generic_type. ident )
632702 . unwrap_or ( false )
633703 } else {
@@ -637,61 +707,3 @@ fn is_generic_type(ty: &syn::Type, generic_type: &syn::TypeParam) -> bool {
637707 false
638708 }
639709}
640-
641- // Type of array dimension
642- enum Dimension {
643- ConstGeneric ( syn:: Expr ) ,
644- Other ( syn:: Expr ) ,
645- }
646-
647- // Describes a nested array
648- struct ArrayInfo {
649- dims : Vec < Dimension > ,
650- elem_type : syn:: Type ,
651- }
652-
653- fn get_array_info ( ty : & syn:: Type , const_generics : & [ & syn:: Ident ] ) -> ArrayInfo {
654- let dims = get_dims ( ty, const_generics) ;
655- let elem_type = get_elem_type ( ty) ;
656- ArrayInfo { dims, elem_type }
657- }
658-
659- fn get_elem_type ( ty : & syn:: Type ) -> syn:: Type {
660- match ty {
661- syn:: Type :: Array ( array) => get_elem_type ( array. elem . as_ref ( ) ) ,
662- syn:: Type :: Path ( _) => ty. clone ( ) ,
663- _ => panic ! ( "Unsupported type: {:?}" , ty) ,
664- }
665- }
666-
667- // Get a vector of the dimensions of the array
668- // Each dimension is either a constant generic or a literal integer value
669- fn get_dims ( ty : & syn:: Type , const_generics : & [ & syn:: Ident ] ) -> Vec < Dimension > {
670- get_dims_impl ( ty, const_generics)
671- . into_iter ( )
672- . rev ( )
673- . collect ( )
674- }
675-
676- fn get_dims_impl ( ty : & syn:: Type , const_generics : & [ & syn:: Ident ] ) -> Vec < Dimension > {
677- match ty {
678- syn:: Type :: Array ( array) => {
679- let mut dims = get_dims_impl ( array. elem . as_ref ( ) , const_generics) ;
680- match & array. len {
681- syn:: Expr :: Path ( syn:: ExprPath { path, .. } ) => {
682- let len_ident = path. get_ident ( ) ;
683- if len_ident. is_some ( ) && const_generics. contains ( & len_ident. unwrap ( ) ) {
684- dims. push ( Dimension :: ConstGeneric ( array. len . clone ( ) ) ) ;
685- } else {
686- dims. push ( Dimension :: Other ( array. len . clone ( ) ) ) ;
687- }
688- }
689- syn:: Expr :: Lit ( expr_lit) => dims. push ( Dimension :: Other ( expr_lit. clone ( ) . into ( ) ) ) ,
690- _ => panic ! ( "Unsupported array length type" ) ,
691- }
692- dims
693- }
694- syn:: Type :: Path ( _) => Vec :: new ( ) ,
695- _ => panic ! ( "Unsupported field type" ) ,
696- }
697- }
0 commit comments