Skip to content

Commit f8c263b

Browse files
committed
use propper rustc error handler for mode/act check
1 parent c90ab5a commit f8c263b

File tree

4 files changed

+105
-26
lines changed

4 files changed

+105
-26
lines changed

compiler/rustc_ast/src/expand/autodiff_attrs.rs

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::expand::typetree::TypeTree;
22
use std::str::FromStr;
33
use thin_vec::ThinVec;
4-
4+
use std::fmt::{Display, Formatter};
55
use crate::NestedMetaItem;
66

77
#[allow(dead_code)]
@@ -13,6 +13,17 @@ pub enum DiffMode {
1313
Reverse,
1414
}
1515

16+
impl Display for DiffMode {
17+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
18+
match self {
19+
DiffMode::Inactive => write!(f, "Inactive"),
20+
DiffMode::Source => write!(f, "Source"),
21+
DiffMode::Forward => write!(f, "Forward"),
22+
DiffMode::Reverse => write!(f, "Reverse"),
23+
}
24+
}
25+
}
26+
1627
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
1728
match mode {
1829
DiffMode::Inactive => false,
@@ -30,26 +41,28 @@ pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
3041
}
3142
}
3243

33-
pub fn valid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> bool {
34-
for &activity in activity_vec {
35-
let valid = match mode {
44+
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
45+
return match mode {
3646
DiffMode::Inactive => false,
3747
DiffMode::Source => false,
3848
DiffMode::Forward => {
3949
// These are the only valid cases
40-
activity == DiffActivity::Dual || activity == DiffActivity::DualOnly || activity == DiffActivity::Const
50+
activity == DiffActivity::Dual ||
51+
activity == DiffActivity::DualOnly ||
52+
activity == DiffActivity::Const
4153
}
4254
DiffMode::Reverse => {
4355
// These are the only valid cases
44-
activity == DiffActivity::Active || activity == DiffActivity::ActiveOnly || activity == DiffActivity::Const
45-
|| activity == DiffActivity::Duplicated || activity == DiffActivity::DuplicatedOnly
56+
activity == DiffActivity::Active ||
57+
activity == DiffActivity::ActiveOnly ||
58+
activity == DiffActivity::Const ||
59+
activity == DiffActivity::Duplicated ||
60+
activity == DiffActivity::DuplicatedOnly
4661
}
4762
};
48-
if !valid {
49-
return false;
50-
}
51-
}
52-
true
63+
}
64+
pub fn valid_input_activities(mode: DiffMode, activity_vec: &[DiffActivity]) -> bool {
65+
return activity_vec.iter().any(|&x| !valid_input_activity(mode, x));
5366
}
5467

5568
#[allow(dead_code)]
@@ -65,6 +78,21 @@ pub enum DiffActivity {
6578
DuplicatedOnly,
6679
}
6780

81+
impl Display for DiffActivity {
82+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
83+
match self {
84+
DiffActivity::None => write!(f, "None"),
85+
DiffActivity::Const => write!(f, "Const"),
86+
DiffActivity::Active => write!(f, "Active"),
87+
DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
88+
DiffActivity::Dual => write!(f, "Dual"),
89+
DiffActivity::DualOnly => write!(f, "DualOnly"),
90+
DiffActivity::Duplicated => write!(f, "Duplicated"),
91+
DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
92+
}
93+
}
94+
}
95+
6896
impl FromStr for DiffMode {
6997
type Err = ();
7098

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ builtin_macros_alloc_must_statics = allocators must be statics
33
44
builtin_macros_autodiff = autodiff must be applied to function
55
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
6+
builtin_macros_autodiff_mode_activity = {$act} can not be used in {$mode} Mode
67
78
builtin_macros_asm_clobber_abi = clobber_abi
89
builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//use crate::util::check_autodiff;
44

55
use crate::errors;
6-
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
6+
use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity};
77
use rustc_ast::ptr::P;
88
use rustc_ast::token::{Token, TokenKind};
99
use rustc_ast::tokenstream::*;
@@ -80,7 +80,7 @@ pub fn expand(
8080
dbg!(&x);
8181
let span = ecx.with_def_site_ctxt(expand_span);
8282

83-
let (d_sig, new_args, idents) = gen_enzyme_decl(&sig, &x, span);
83+
let (d_sig, new_args, idents) = gen_enzyme_decl(ecx, &sig, &x, span);
8484
let new_decl_span = d_sig.span;
8585
let d_body = gen_enzyme_body(
8686
ecx,
@@ -175,6 +175,26 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
175175
ty
176176
}
177177

178+
// TODO We should make this more robust to also
179+
// accept aliases of f32 and f64
180+
#[cfg(llvm_enzyme)]
181+
fn is_float(ty: &ast::Ty) -> bool {
182+
match ty.kind {
183+
TyKind::Path(_, ref path) => {
184+
let last = path.segments.last().unwrap();
185+
last.ident.name == sym::f32 || last.ident.name == sym::f64
186+
}
187+
_ => false,
188+
}
189+
}
190+
#[cfg(llvm_enzyme)]
191+
fn is_ptr_or_ref(ty: &ast::Ty) -> bool {
192+
match ty.kind {
193+
TyKind::Ptr(_) | TyKind::Ref(_, _) => true,
194+
_ => false,
195+
}
196+
}
197+
178198
// The body of our generated functions will consist of two black_Box calls.
179199
// The first will call the primal function with the original arguments.
180200
// The second will just take a tuple containing the new arguments.
@@ -259,6 +279,7 @@ fn gen_primal_call(
259279
// activity.
260280
#[cfg(llvm_enzyme)]
261281
fn gen_enzyme_decl(
282+
ecx: &ExtCtxt<'_>,
262283
sig: &ast::FnSig,
263284
x: &AutoDiffAttrs,
264285
span: Span,
@@ -273,31 +294,50 @@ fn gen_enzyme_decl(
273294
let mut act_ret = ThinVec::new();
274295
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
275296
d_inputs.push(arg.clone());
297+
if !valid_input_activity(x.mode, *activity) {
298+
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplicationModeAct {
299+
span,
300+
mode: x.mode.to_string(),
301+
act: activity.to_string()
302+
});
303+
}
276304
match activity {
277305
DiffActivity::Active => {
278-
assert!(x.mode == DiffMode::Reverse);
306+
assert!(is_float(&arg.ty));
279307
act_ret.push(arg.ty.clone());
280308
}
281-
DiffActivity::Duplicated | DiffActivity::Dual => {
309+
DiffActivity::Duplicated => {
310+
assert!(is_ptr_or_ref(&arg.ty));
282311
let mut shadow_arg = arg.clone();
283312
// We += into the shadow in reverse mode.
284-
// Otherwise copy mutability of the original argument.
285-
if activity == &DiffActivity::Duplicated {
286-
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
287-
}
288-
// adjust name depending on mode
313+
shadow_arg.ty = P(assure_mut_ref(&arg.ty));
289314
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
290315
ident.name
291316
} else {
292317
dbg!(&shadow_arg.pat);
293318
panic!("not an ident?");
294319
};
295-
let name: String = match x.mode {
296-
DiffMode::Reverse => format!("d{}", old_name),
297-
DiffMode::Forward => format!("b{}", old_name),
298-
_ => panic!("unsupported mode: {}", old_name),
320+
let name: String = format!("d{}", old_name);
321+
new_inputs.push(name.clone());
322+
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
323+
shadow_arg.pat = P(ast::Pat {
324+
// TODO: Check id
325+
id: ast::DUMMY_NODE_ID,
326+
kind: PatKind::Ident(BindingAnnotation::NONE, ident, None),
327+
span: shadow_arg.pat.span,
328+
tokens: shadow_arg.pat.tokens.clone(),
329+
});
330+
d_inputs.push(shadow_arg);
331+
}
332+
DiffActivity::Dual => {
333+
let mut shadow_arg = arg.clone();
334+
let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
335+
ident.name
336+
} else {
337+
dbg!(&shadow_arg.pat);
338+
panic!("not an ident?");
299339
};
300-
dbg!(&name);
340+
let name: String = format!("b{}", old_name);
301341
new_inputs.push(name.clone());
302342
let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
303343
shadow_arg.pat = P(ast::Pat {
@@ -311,6 +351,7 @@ fn gen_enzyme_decl(
311351
}
312352
_ => {
313353
dbg!(&activity);
354+
panic!("Not implemented");
314355
}
315356
}
316357
if let PatKind::Ident(_, ident, _) = arg.pat.kind {

compiler/rustc_builtin_macros/src/errors.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,6 +164,15 @@ pub(crate) struct AllocMustStatics {
164164
pub(crate) span: Span,
165165
}
166166

167+
#[derive(Diagnostic)]
168+
#[diag(builtin_macros_autodiff_mode_activity)]
169+
pub(crate) struct AutoDiffInvalidApplicationModeAct {
170+
#[primary_span]
171+
pub(crate) span: Span,
172+
pub(crate) mode: String,
173+
pub(crate) act: String,
174+
}
175+
167176
#[derive(Diagnostic)]
168177
#[diag(builtin_macros_autodiff)]
169178
pub(crate) struct AutoDiffInvalidApplication {

0 commit comments

Comments
 (0)