Skip to content

Commit 7b0d0f1

Browse files
committed
cleanups2
1 parent c8c4ea3 commit 7b0d0f1

File tree

2 files changed

+91
-76
lines changed

2 files changed

+91
-76
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,45 @@ pub enum DiffMode {
1313
Reverse,
1414
}
1515

16+
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
17+
match mode {
18+
DiffMode::Inactive => false,
19+
DiffMode::Source => false,
20+
DiffMode::Forward => {
21+
// Doesn't recognize all illegal cases (insufficient information)
22+
activity != DiffActivity::Active && activity != DiffActivity::ActiveOnly
23+
&& activity != DiffActivity::Duplicated && activity != DiffActivity::DuplicatedOnly
24+
}
25+
DiffMode::Reverse => {
26+
// Doesn't recognize all illegal cases (insufficient information)
27+
activity != DiffActivity::Duplicated && activity != DiffActivity::DuplicatedOnly
28+
&& activity != DiffActivity::Dual && activity != DiffActivity::DualOnly
29+
}
30+
}
31+
}
32+
33+
pub fn valid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> bool {
34+
for &activity in activity_vec {
35+
let valid = match mode {
36+
DiffMode::Inactive => false,
37+
DiffMode::Source => false,
38+
DiffMode::Forward => {
39+
// These are the only valid cases
40+
activity == DiffActivity::Dual || activity == DiffActivity::DualOnly || activity == DiffActivity::Const
41+
}
42+
DiffMode::Reverse => {
43+
// These are the only valid cases
44+
activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly || activity == DiffActivity::Const
45+
|| activity == DiffActivity::Duplicated || activity == DiffActivity::DuplicatedOnly
46+
}
47+
};
48+
if !valid {
49+
return false;
50+
}
51+
}
52+
true
53+
}
54+
1655
#[allow(dead_code)]
1756
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
1857
pub enum DiffActivity {

compiler/rustc_codegen_ssa/src/codegen_attrs.rs

Lines changed: 52 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
1+
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_ret_activity, valid_input_activities};
22
use rustc_ast::{ast, attr, MetaItem, MetaItemKind, NestedMetaItem};
33
use rustc_attr::{list_contains_name, InlineAttr, InstructionSetAttr, OptimizeAttr};
44
use rustc_errors::struct_span_err;
@@ -692,6 +692,11 @@ fn check_link_name_xor_ordinal(
692692
}
693693
}
694694

695+
/// We now check the #[rustc_autodiff] attributes which we generated from the #[autodiff(...)]
696+
/// macros. There are two forms. The pure one without args to mark primal functions (the functions
697+
/// being differentiated). The other form is #[rustc_autodiff(Mode, ActivityList)] on top of the
698+
/// placeholder functions. We wrote the rustc_autodiff attributes ourself, so this should never
699+
/// panic, unless we introduced a bug when parsing the autodiff macro.
695700
fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
696701
let attrs = tcx.get_attrs(id, sym::rustc_autodiff);
697702

@@ -726,20 +731,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
726731
return AutoDiffAttrs::source();
727732
}
728733

729-
let msg_ad_mode = "autodiff attribute must contain autodiff mode";
730-
let (mode, list) = match list.split_first() {
731-
Some((
732-
NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. }),
733-
list,
734-
)) => (p1.segments.first().unwrap().ident, list),
735-
_ => {
736-
tcx.sess
737-
.struct_span_err(attr.span, msg_ad_mode)
738-
.span_label(attr.span, "empty argument list")
739-
.emit();
740-
741-
return AutoDiffAttrs::inactive();
742-
}
734+
let [mode, input_activities @ .., ret_activity] = &list[..] else {
735+
tcx.sess
736+
.struct_span_err(attr.span, msg_once)
737+
.span_label(attr.span, "Implementation bug in autodiff_attrs. Please report this!")
738+
.emit();
739+
return AutoDiffAttrs::inactive();
740+
};
741+
let mode = if let NestedMetaItem::MetaItem(MetaItem { path: ref p1, .. }) = mode {
742+
p1.segments.first().unwrap().ident
743+
} else {
744+
let msg = "autodiff attribute must contain autodiff mode";
745+
tcx.sess
746+
.struct_span_err(attr.span, msg)
747+
.span_label(attr.span, "empty argument list")
748+
.emit();
749+
return AutoDiffAttrs::inactive();
743750
};
744751

745752
// parse mode
@@ -752,27 +759,23 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
752759
.struct_span_err(attr.span, msg_mode)
753760
.span_label(attr.span, "invalid mode")
754761
.emit();
755-
756762
return AutoDiffAttrs::inactive();
757763
}
758764
};
759765

760-
let msg_ret_activity = "autodiff attribute must contain the return activity";
761-
let (ret_symbol, list) = match list.split_last() {
762-
Some((
763-
NestedMetaItem::MetaItem(MetaItem { path: ref p1, kind: MetaItemKind::Word, .. }),
764-
list,
765-
)) => (p1.segments.first().unwrap().ident, list),
766-
_ => {
767-
tcx.sess
768-
.struct_span_err(attr.span, msg_ret_activity)
769-
.span_label(attr.span, "missing return activity")
770-
.emit();
771-
772-
return AutoDiffAttrs::inactive();
773-
}
766+
// First read the ret symbol from the attribute
767+
let ret_symbol = if let NestedMetaItem::MetaItem(MetaItem { path: ref p1, .. }) = ret_activity {
768+
p1.segments.first().unwrap().ident
769+
} else {
770+
let msg = "autodiff attribute must contain the return activity";
771+
tcx.sess
772+
.struct_span_err(attr.span, msg)
773+
.span_label(attr.span, "missing return activity")
774+
.emit();
775+
return AutoDiffAttrs::inactive();
774776
};
775777

778+
// Then parse it into an actual DiffActivity
776779
let msg_unknown_ret_activity = "unknown return activity";
777780
let ret_activity = match DiffActivity::from_str(ret_symbol.as_str()) {
778781
Ok(x) => x,
@@ -781,26 +784,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
781784
.struct_span_err(attr.span, msg_unknown_ret_activity)
782785
.span_label(attr.span, "invalid return activity")
783786
.emit();
784-
785787
return AutoDiffAttrs::inactive();
786788
}
787789
};
788790

791+
// Now parse all the intermediate (inptut) activities
789792
let msg_arg_activity = "autodiff attribute must contain the return activity";
790793
let mut arg_activities: Vec<DiffActivity> = vec![];
791-
for arg in list {
792-
let arg_symbol = match arg {
793-
NestedMetaItem::MetaItem(MetaItem {
794-
path: ref p2, kind: MetaItemKind::Word, ..
795-
}) => p2.segments.first().unwrap().ident,
796-
_ => {
797-
tcx.sess
798-
.struct_span_err(attr.span, msg_arg_activity)
799-
.span_label(attr.span, "missing return activity")
800-
.emit();
801-
802-
return AutoDiffAttrs::inactive();
803-
}
794+
for arg in input_activities {
795+
let arg_symbol = if let NestedMetaItem::MetaItem(MetaItem { path: ref p2, .. }) = arg {
796+
p2.segments.first().unwrap().ident
797+
} else {
798+
tcx.sess
799+
.struct_span_err(attr.span, msg_arg_activity)
800+
.span_label(attr.span, "Implementation bug, please report this!")
801+
.emit();
802+
return AutoDiffAttrs::inactive();
804803
};
805804

806805
match DiffActivity::from_str(arg_symbol.as_str()) {
@@ -810,45 +809,22 @@ fn autodiff_attrs(tcx: TyCtxt<'_>, id: DefId) -> AutoDiffAttrs {
810809
.struct_span_err(attr.span, msg_unknown_ret_activity)
811810
.span_label(attr.span, "invalid input activity")
812811
.emit();
813-
814812
return AutoDiffAttrs::inactive();
815813
}
816814
}
817815
}
818816

819-
let msg_fwd_incompatible_ret = "Forward Mode is incompatible with Active ret";
820-
let msg_fwd_incompatible_arg = "Forward Mode is incompatible with Active ret";
821-
let msg_rev_incompatible_arg =
822-
"Reverse Mode is only compatible with Active, None, or Const ret";
823-
if mode == DiffMode::Forward {
824-
if ret_activity == DiffActivity::Active {
825-
tcx.sess
826-
.struct_span_err(attr.span, msg_fwd_incompatible_ret)
827-
.span_label(attr.span, "invalid return activity")
828-
.emit();
829-
return AutoDiffAttrs::inactive();
830-
}
831-
if arg_activities.iter().filter(|&x| *x == DiffActivity::Active).count() > 0 {
832-
tcx.sess
833-
.struct_span_err(attr.span, msg_fwd_incompatible_arg)
834-
.span_label(attr.span, "invalid input activity")
835-
.emit();
836-
return AutoDiffAttrs::inactive();
837-
}
817+
let msg = "Invalid activity for mode";
818+
let valid_input = valid_input_activities(mode, &arg_activities);
819+
let valid_ret = valid_ret_activity(mode, ret_activity);
820+
if !valid_input || !valid_ret {
821+
tcx.sess
822+
.struct_span_err(attr.span, msg)
823+
.span_label(attr.span, "invalid activity")
824+
.emit();
825+
return AutoDiffAttrs::inactive();
838826
}
839827

840-
if mode == DiffMode::Reverse {
841-
if ret_activity == DiffActivity::Duplicated
842-
|| ret_activity == DiffActivity::DuplicatedOnly
843-
{
844-
dbg!("ret_activity = {:?}", ret_activity);
845-
tcx.sess
846-
.struct_span_err(attr.span, msg_rev_incompatible_arg)
847-
.span_label(attr.span, "invalid return activity")
848-
.emit();
849-
return AutoDiffAttrs::inactive();
850-
}
851-
}
852828

853829
AutoDiffAttrs { mode, ret_activity, input_activity: arg_activities }
854830
}

0 commit comments

Comments
 (0)