1
- use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode } ;
1
+ use rustc_ast:: expand:: autodiff_attrs:: { AutoDiffAttrs , DiffActivity , DiffMode , valid_ret_activity , valid_input_activities } ;
2
2
use rustc_ast:: { ast, attr, MetaItem , MetaItemKind , NestedMetaItem } ;
3
3
use rustc_attr:: { list_contains_name, InlineAttr , InstructionSetAttr , OptimizeAttr } ;
4
4
use rustc_errors:: struct_span_err;
@@ -692,6 +692,11 @@ fn check_link_name_xor_ordinal(
692
692
}
693
693
}
694
694
695
+ /// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)]
696
+ /// macros. There are two forms. The pure one without args to mark primal functions (the functions
697
+ /// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
698
+ /// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
699
+ /// panic, unless we introduced a bug when parsing the autodiff macro.
695
700
fn autodiff_attrs ( tcx : TyCtxt < ' _ > , id : DefId ) -> AutoDiffAttrs {
696
701
let attrs = tcx. get_attrs ( id, sym:: rustc_autodiff) ;
697
702
@@ -726,20 +731,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
726
731
return AutoDiffAttrs :: source ( ) ;
727
732
}
728
733
729
- let msg_ad_mode = "autodiff attribute must contain autodiff mode" ;
730
- let ( mode, list) = match list. split_first ( ) {
731
- Some ( (
732
- NestedMetaItem :: MetaItem ( MetaItem { path : ref p1, kind : MetaItemKind :: Word , .. } ) ,
733
- list,
734
- ) ) => ( p1. segments . first ( ) . unwrap ( ) . ident , list) ,
735
- _ => {
736
- tcx. sess
737
- . struct_span_err ( attr. span , msg_ad_mode)
738
- . span_label ( attr. span , "empty argument list" )
739
- . emit ( ) ;
740
-
741
- return AutoDiffAttrs :: inactive ( ) ;
742
- }
734
+ let [ mode, input_activities @ .., ret_activity] = & list[ ..] else {
735
+ tcx. sess
736
+ . struct_span_err ( attr. span , msg_once)
737
+ . span_label ( attr. span , "Implementation bug in autodiff_attrs. Please report this!" )
738
+ . emit ( ) ;
739
+ return AutoDiffAttrs :: inactive ( ) ;
740
+ } ;
741
+ let mode = if let NestedMetaItem :: MetaItem ( MetaItem { path : ref p1, .. } ) = mode {
742
+ p1. segments . first ( ) . unwrap ( ) . ident
743
+ } else {
744
+ let msg = "autodiff attribute must contain autodiff mode" ;
745
+ tcx. sess
746
+ . struct_span_err ( attr. span , msg)
747
+ . span_label ( attr. span , "empty argument list" )
748
+ . emit ( ) ;
749
+ return AutoDiffAttrs :: inactive ( ) ;
743
750
} ;
744
751
745
752
// parse mode
@@ -752,27 +759,23 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
752
759
. struct_span_err ( attr. span , msg_mode)
753
760
. span_label ( attr. span , "invalid mode" )
754
761
. emit ( ) ;
755
-
756
762
return AutoDiffAttrs :: inactive ( ) ;
757
763
}
758
764
} ;
759
765
760
- let msg_ret_activity = "autodiff attribute must contain the return activity" ;
761
- let ( ret_symbol, list) = match list. split_last ( ) {
762
- Some ( (
763
- NestedMetaItem :: MetaItem ( MetaItem { path : ref p1, kind : MetaItemKind :: Word , .. } ) ,
764
- list,
765
- ) ) => ( p1. segments . first ( ) . unwrap ( ) . ident , list) ,
766
- _ => {
767
- tcx. sess
768
- . struct_span_err ( attr. span , msg_ret_activity)
769
- . span_label ( attr. span , "missing return activity" )
770
- . emit ( ) ;
771
-
772
- return AutoDiffAttrs :: inactive ( ) ;
773
- }
766
+ // First read the ret symbol from the attribute
767
+ let ret_symbol = if let NestedMetaItem :: MetaItem ( MetaItem { path : ref p1, .. } ) = ret_activity {
768
+ p1. segments . first ( ) . unwrap ( ) . ident
769
+ } else {
770
+ let msg = "autodiff attribute must contain the return activity" ;
771
+ tcx. sess
772
+ . struct_span_err ( attr. span , msg)
773
+ . span_label ( attr. span , "missing return activity" )
774
+ . emit ( ) ;
775
+ return AutoDiffAttrs :: inactive ( ) ;
774
776
} ;
775
777
778
+ // Then parse it into an actual DiffActivity
776
779
let msg_unknown_ret_activity = "unknown return activity" ;
777
780
let ret_activity = match DiffActivity :: from_str ( ret_symbol. as_str ( ) ) {
778
781
Ok ( x) => x,
@@ -781,26 +784,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
781
784
. struct_span_err ( attr. span , msg_unknown_ret_activity)
782
785
. span_label ( attr. span , "invalid return activity" )
783
786
. emit ( ) ;
784
-
785
787
return AutoDiffAttrs :: inactive ( ) ;
786
788
}
787
789
} ;
788
790
791
+ // Now parse all the intermediate (inptut) activities
789
792
let msg_arg_activity = "autodiff attribute must contain the return activity" ;
790
793
let mut arg_activities: Vec < DiffActivity > = vec ! [ ] ;
791
- for arg in list {
792
- let arg_symbol = match arg {
793
- NestedMetaItem :: MetaItem ( MetaItem {
794
- path : ref p2, kind : MetaItemKind :: Word , ..
795
- } ) => p2. segments . first ( ) . unwrap ( ) . ident ,
796
- _ => {
797
- tcx. sess
798
- . struct_span_err ( attr. span , msg_arg_activity)
799
- . span_label ( attr. span , "missing return activity" )
800
- . emit ( ) ;
801
-
802
- return AutoDiffAttrs :: inactive ( ) ;
803
- }
794
+ for arg in input_activities {
795
+ let arg_symbol = if let NestedMetaItem :: MetaItem ( MetaItem { path : ref p2, .. } ) = arg {
796
+ p2. segments . first ( ) . unwrap ( ) . ident
797
+ } else {
798
+ tcx. sess
799
+ . struct_span_err ( attr. span , msg_arg_activity)
800
+ . span_label ( attr. span , "Implementation bug, please report this!" )
801
+ . emit ( ) ;
802
+ return AutoDiffAttrs :: inactive ( ) ;
804
803
} ;
805
804
806
805
match DiffActivity :: from_str ( arg_symbol. as_str ( ) ) {
@@ -810,45 +809,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
810
809
. struct_span_err ( attr. span , msg_unknown_ret_activity)
811
810
. span_label ( attr. span , "invalid input activity" )
812
811
. emit ( ) ;
813
-
814
812
return AutoDiffAttrs :: inactive ( ) ;
815
813
}
816
814
}
817
815
}
818
816
819
- let msg_fwd_incompatible_ret = "Forward Mode is incompatible with Active ret" ;
820
- let msg_fwd_incompatible_arg = "Forward Mode is incompatible with Active ret" ;
821
- let msg_rev_incompatible_arg =
822
- "Reverse Mode is only compatible with Active, None, or Const ret" ;
823
- if mode == DiffMode :: Forward {
824
- if ret_activity == DiffActivity :: Active {
825
- tcx. sess
826
- . struct_span_err ( attr. span , msg_fwd_incompatible_ret)
827
- . span_label ( attr. span , "invalid return activity" )
828
- . emit ( ) ;
829
- return AutoDiffAttrs :: inactive ( ) ;
830
- }
831
- if arg_activities. iter ( ) . filter ( |& x| * x == DiffActivity :: Active ) . count ( ) > 0 {
832
- tcx. sess
833
- . struct_span_err ( attr. span , msg_fwd_incompatible_arg)
834
- . span_label ( attr. span , "invalid input activity" )
835
- . emit ( ) ;
836
- return AutoDiffAttrs :: inactive ( ) ;
837
- }
817
+ let msg = "Invalid activity for mode" ;
818
+ let valid_input = valid_input_activities ( mode, & arg_activities) ;
819
+ let valid_ret = valid_ret_activity ( mode, ret_activity) ;
820
+ if !valid_input || !valid_ret {
821
+ tcx. sess
822
+ . struct_span_err ( attr. span , msg)
823
+ . span_label ( attr. span , "invalid activity" )
824
+ . emit ( ) ;
825
+ return AutoDiffAttrs :: inactive ( ) ;
838
826
}
839
827
840
- if mode == DiffMode :: Reverse {
841
- if ret_activity == DiffActivity :: Duplicated
842
- || ret_activity == DiffActivity :: DuplicatedOnly
843
- {
844
- dbg ! ( "ret_activity = {:?}" , ret_activity) ;
845
- tcx. sess
846
- . struct_span_err ( attr. span , msg_rev_incompatible_arg)
847
- . span_label ( attr. span , "invalid return activity" )
848
- . emit ( ) ;
849
- return AutoDiffAttrs :: inactive ( ) ;
850
- }
851
- }
852
828
853
829
AutoDiffAttrs { mode, ret_activity, input_activity : arg_activities }
854
830
}
0 commit comments