Skip to content

Commit e394546

Browse files
committed
small cleanups, make Enzyme config nicer
1 parent 769622d commit e394546

File tree

9 files changed

+57
-49
lines changed

9 files changed

+57
-49
lines changed

compiler/rustc_builtin_macros/messages.ftl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ builtin_macros_alloc_error_must_be_fn = alloc_error_handler must be a function
22
builtin_macros_alloc_must_statics = allocators must be statics
33
44
builtin_macros_autodiff = autodiff must be applied to function
5+
builtin_macros_autodiff_not_build = this rustc version does not support autodiff
56
67
builtin_macros_asm_clobber_abi = clobber_abi
78
builtin_macros_asm_clobber_no_reg = asm with `clobber_abi` must specify explicit registers for outputs

compiler/rustc_builtin_macros/src/autodiff.rs

Lines changed: 28 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,4 @@
11
#![allow(unused_imports)]
2-
#![allow(unused_variables)]
3-
#![allow(unused_mut)]
42
//use crate::util::check_builtin_macro_attribute;
53
//use crate::util::check_autodiff;
64

@@ -20,12 +18,25 @@ use rustc_span::Symbol;
2018
use std::string::String;
2119
use thin_vec::{thin_vec, ThinVec};
2220

21+
#[cfg(llvm_enzyme)]
2322
fn first_ident(x: &NestedMetaItem) -> rustc_span::symbol::Ident {
2423
let segments = &x.meta_item().unwrap().path.segments;
2524
assert!(segments.len() == 1);
2625
segments[0].ident
2726
}
2827

28+
#[cfg(not(llvm_enzyme))]
29+
pub fn expand(
30+
ecx: &mut ExtCtxt<'_>,
31+
_expand_span: Span,
32+
meta_item: &ast::MetaItem,
33+
item: Annotatable,
34+
) -> Vec<Annotatable> {
35+
ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
36+
return vec![item];
37+
}
38+
39+
#[cfg(llvm_enzyme)]
2940
pub fn expand(
3041
ecx: &mut ExtCtxt<'_>,
3142
expand_span: Span,
@@ -45,24 +56,16 @@ pub fn expand(
4556
let primal = orig_item.ident.clone();
4657

4758
// Allow using `#[autodiff(...)]` only on a Fn
48-
let (fn_item, has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item
59+
let (has_ret, sig, sig_span) = if let Annotatable::Item(item) = &item
4960
&& let ItemKind::Fn(box ast::Fn { sig, .. }) = &item.kind
5061
{
51-
(item, sig.decl.output.has_ret(), sig, ecx.with_call_site_ctxt(sig.span))
62+
(sig.decl.output.has_ret(), sig, ecx.with_call_site_ctxt(sig.span))
5263
} else {
5364
ecx.sess.dcx().emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
5465
return vec![item];
5566
};
5667
// create TokenStream from vec elemtents:
5768
// meta_item doesn't have a .tokens field
58-
let ts: Vec<Token> = meta_item_vec.clone()[1..]
59-
.iter()
60-
.map(|x| {
61-
let val = first_ident(x);
62-
let t = Token::from_ast_ident(val);
63-
t
64-
})
65-
.collect();
6669
let comma: Token = Token::new(TokenKind::Comma, Span::default());
6770
let mut ts: Vec<TokenTree> = vec![];
6871
for t in meta_item_vec.clone()[1..].iter() {
@@ -77,18 +80,15 @@ pub fn expand(
7780
dbg!(&x);
7881
let span = ecx.with_def_site_ctxt(expand_span);
7982

80-
let (d_sig, old_names, new_args, idents) = gen_enzyme_decl(&sig, &x, span);
83+
let (d_sig, new_args, idents) = gen_enzyme_decl(&sig, &x, span);
8184
let new_decl_span = d_sig.span;
8285
let d_body = gen_enzyme_body(
8386
ecx,
8487
primal,
85-
&old_names,
8688
&new_args,
8789
span,
8890
sig_span,
8991
new_decl_span,
90-
&sig,
91-
&d_sig,
9292
idents,
9393
);
9494
let d_ident = meta_item_vec[0].meta_item().unwrap().path.segments[0].ident;
@@ -102,7 +102,7 @@ pub fn expand(
102102
}));
103103
let mut rustc_ad_attr =
104104
P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
105-
let mut attr: ast::Attribute = ast::Attribute {
105+
let attr: ast::Attribute = ast::Attribute {
106106
kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
107107
id: ast::AttrId::from_u32(0),
108108
style: ast::AttrStyle::Outer,
@@ -116,7 +116,7 @@ pub fn expand(
116116
delim: rustc_ast::token::Delimiter::Parenthesis,
117117
tokens: ts,
118118
});
119-
let mut attr2: ast::Attribute = ast::Attribute {
119+
let attr2: ast::Attribute = ast::Attribute {
120120
kind: ast::AttrKind::Normal(rustc_ad_attr),
121121
id: ast::AttrId::from_u32(0),
122122
style: ast::AttrStyle::Outer,
@@ -131,6 +131,7 @@ pub fn expand(
131131
}
132132

133133
// shadow arguments must be mutable references or ptrs, because Enzyme will write into them.
134+
#[cfg(llvm_enzyme)]
134135
fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
135136
let mut ty = ty.clone();
136137
match ty.kind {
@@ -152,37 +153,21 @@ fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
152153
// The second will just take a tuple containing the new arguments.
153154
// This way we surpress rustc from optimizing any argument away.
154155
// The last line will 'loop {}', to match the return type of the new function
156+
#[cfg(llvm_enzyme)]
155157
fn gen_enzyme_body(
156158
ecx: &ExtCtxt<'_>,
157159
primal: Ident,
158-
old_names: &[String],
159160
new_names: &[String],
160161
span: Span,
161162
sig_span: Span,
162163
new_decl_span: Span,
163-
sig: &ast::FnSig,
164-
d_sig: &ast::FnSig,
165164
idents: Vec<Ident>,
166165
) -> P<ast::Block> {
167166
let blackbox_path = ecx.std_path(&[Symbol::intern("hint"), Symbol::intern("black_box")]);
168-
let zeroed_path = ecx.std_path(&[Symbol::intern("mem"), Symbol::intern("zeroed")]);
169167
let empty_loop_block = ecx.block(span, ThinVec::new());
170168
let loop_expr = ecx.expr_loop(span, empty_loop_block);
171-
172169
let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
173-
let zeroed_call_expr = ecx.expr_path(ecx.path(span, zeroed_path));
174-
175-
let mem_zeroed_call: Stmt =
176-
ecx.stmt_expr(ecx.expr_call(span, zeroed_call_expr.clone(), thin_vec![]));
177-
let unsafe_block_with_zeroed_call: P<ast::Expr> = ecx.expr_block(P(ast::Block {
178-
stmts: thin_vec![mem_zeroed_call],
179-
id: ast::DUMMY_NODE_ID,
180-
rules: ast::BlockCheckMode::Unsafe(ast::UserProvided),
181-
span: sig_span,
182-
tokens: None,
183-
could_be_bare_literal: false,
184-
}));
185-
let primal_call = gen_primal_call(ecx, span, primal, sig, idents);
170+
let primal_call = gen_primal_call(ecx, span, primal, idents);
186171
// create ::core::hint::black_box(array(arr));
187172
let black_box_primal_call =
188173
ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![primal_call.clone()]);
@@ -207,11 +192,11 @@ fn gen_enzyme_body(
207192
body
208193
}
209194

195+
#[cfg(llvm_enzyme)]
210196
fn gen_primal_call(
211197
ecx: &ExtCtxt<'_>,
212198
span: Span,
213199
primal: Ident,
214-
sig: &ast::FnSig,
215200
idents: Vec<Ident>,
216201
) -> P<ast::Expr> {
217202
let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
@@ -226,17 +211,18 @@ fn gen_primal_call(
226211
// zero-initialized by Enzyme). Active arguments are not handled yet.
227212
// Each argument of the primal function (and the return type if existing) must be annotated with an
228213
// activity.
214+
#[cfg(llvm_enzyme)]
229215
fn gen_enzyme_decl(
230216
sig: &ast::FnSig,
231217
x: &AutoDiffAttrs,
232218
span: Span,
233-
) -> (ast::FnSig, Vec<String>, Vec<String>, Vec<Ident>) {
219+
) -> (ast::FnSig, Vec<String>, Vec<Ident>) {
234220
assert!(sig.decl.inputs.len() == x.input_activity.len());
235221
assert!(sig.decl.output.has_ret() == x.has_ret_activity());
236222
let mut d_decl = sig.decl.clone();
237223
let mut d_inputs = Vec::new();
238224
let mut new_inputs = Vec::new();
239-
let mut old_names = Vec::new();
225+
//let mut old_names = Vec::new();
240226
let mut idents = Vec::new();
241227
let mut act_ret = ThinVec::new();
242228
for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
@@ -256,7 +242,7 @@ fn gen_enzyme_decl(
256242
dbg!(&shadow_arg.pat);
257243
panic!("not an ident?");
258244
};
259-
old_names.push(old_name.to_string());
245+
//old_names.push(old_name.to_string());
260246
let name: String = match x.mode {
261247
DiffMode::Reverse => format!("d{}", old_name),
262248
DiffMode::Forward => format!("b{}", old_name),
@@ -320,7 +306,7 @@ fn gen_enzyme_decl(
320306
// return type. This might require changing the return type to a
321307
// tuple.
322308
if act_ret.len() > 0 {
323-
let mut ret_ty = match d_decl.output {
309+
let ret_ty = match d_decl.output {
324310
FnRetTy::Ty(ref ty) => {
325311
act_ret.insert(0, ty.clone());
326312
let kind = TyKind::Tup(act_ret);
@@ -339,5 +325,5 @@ fn gen_enzyme_decl(
339325
}
340326

341327
let d_sig = FnSig { header: sig.header.clone(), decl: d_decl, span };
342-
(d_sig, old_names, new_inputs, idents)
328+
(d_sig, new_inputs, idents)
343329
}

compiler/rustc_builtin_macros/src/errors.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,13 @@ pub(crate) struct AutoDiffInvalidApplication {
171171
pub(crate) span: Span,
172172
}
173173

174+
#[derive(Diagnostic)]
175+
#[diag(builtin_macros_autodiff_not_build)]
176+
pub(crate) struct AutoDiffSupportNotBuild {
177+
#[primary_span]
178+
pub(crate) span: Span,
179+
}
180+
174181
#[derive(Diagnostic)]
175182
#[diag(builtin_macros_concat_bytes_invalid)]
176183
pub(crate) struct ConcatBytesInvalid {

compiler/rustc_codegen_llvm/src/llvm/ffi.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
#![allow(non_camel_case_types)]
22
#![allow(non_upper_case_globals)]
3-
#![allow(unexpected_cfgs)]
3+
//#![allow(unexpected_cfgs)]
44

55
use rustc_ast::expand::autodiff_attrs::DiffActivity;
66

@@ -2609,10 +2609,10 @@ extern "C" {
26092609
) -> *mut c_void;
26102610
}
26112611

2612-
#[cfg(autodiff_fallback)]
2612+
#[cfg(not(llvm_enzyme))]
26132613
pub use self::Fallback_AD::*;
26142614

2615-
#[cfg(autodiff_fallback)]
2615+
#[cfg(not(llvm_enzyme))]
26162616
pub mod Fallback_AD {
26172617
#![allow(unused_variables)]
26182618
use super::*;
@@ -2745,9 +2745,9 @@ pub mod Shared_AD {
27452745
use libc::size_t;
27462746
use super::Context;
27472747

2748-
#[cfg(autodiff_fallback)]
2748+
#[cfg(not(llvm_enzyme))]
27492749
use super::Fallback_AD::*;
2750-
#[cfg(not(autodiff_fallback))]
2750+
#[cfg(llvm_enzyme)]
27512751
use super::Enzyme_AD::*;
27522752

27532753
use core::fmt;
@@ -2931,13 +2931,13 @@ pub mod Shared_AD {
29312931
}
29322932
}
29332933

2934-
#[cfg(not(autodiff_fallback))]
2934+
#[cfg(llvm_enzyme)]
29352935
pub use self::Enzyme_AD::*;
29362936

29372937
// Enzyme is an optional component, so we do need to provide a fallback when it is ont getting
29382938
// compiled. We deny the usage of #[autodiff(..)] on a higher level, so a placeholder implementation
29392939
// here is completely fine.
2940-
#[cfg(not(autodiff_fallback))]
2940+
#[cfg(llvm_enzyme)]
29412941
pub mod Enzyme_AD {
29422942
use super::*;
29432943

compiler/rustc_session/src/config.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1273,6 +1273,7 @@ fn default_configuration(sess: &Session) -> Cfg {
12731273
//
12741274
// NOTE: These insertions should be kept in sync with
12751275
// `CheckCfg::fill_well_known` below.
1276+
ins_none!(sym::autodiff_fallback);
12761277

12771278
if sess.opts.debug_assertions {
12781279
ins_none!(sym::debug_assertions);
@@ -1460,6 +1461,7 @@ impl CheckCfg {
14601461
//
14611462
// When adding a new config here you should also update
14621463
// `tests/ui/check-cfg/well-known-values.rs`.
1464+
ins!(sym::autodiff_fallback, no_values);
14631465

14641466
ins!(sym::debug_assertions, no_values);
14651467

compiler/rustc_span/src/symbol.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ symbols! {
440440
augmented_assignments,
441441
auto_traits,
442442
autodiff,
443+
autodiff_fallback,
443444
automatically_derived,
444445
avx,
445446
avx512_target_feature,
@@ -493,6 +494,7 @@ symbols! {
493494
cfg_accessible,
494495
cfg_attr,
495496
cfg_attr_multi,
497+
cfg_autodiff_fallback,
496498
cfg_doctest,
497499
cfg_eval,
498500
cfg_hide,

src/bootstrap/src/core/build_steps/compile.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1048,6 +1048,9 @@ pub fn rustc_cargo_env(
10481048
if builder.config.rust_verify_llvm_ir {
10491049
cargo.env("RUSTC_VERIFY_LLVM_IR", "1");
10501050
}
1051+
if builder.config.llvm_enzyme {
1052+
cargo.rustflag("--cfg=llvm_enzyme");
1053+
}
10511054

10521055
// Note that this is disabled if LLVM itself is disabled or we're in a check
10531056
// build. If we are in a check build we still go ahead here presuming we've

src/bootstrap/src/core/config/config.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1095,6 +1095,7 @@ define_config! {
10951095
codegen_backends: Option<Vec<String>> = "codegen-backends",
10961096
lld: Option<bool> = "lld",
10971097
lld_mode: Option<LldMode> = "use-lld",
1098+
llvm_enzyme: Option<bool> = "llvm-enzyme",
10981099
llvm_tools: Option<bool> = "llvm-tools",
10991100
deny_warnings: Option<bool> = "deny-warnings",
11001101
backtrace_on_ice: Option<bool> = "backtrace-on-ice",
@@ -1545,6 +1546,7 @@ impl Config {
15451546
save_toolstates,
15461547
codegen_backends,
15471548
lld,
1549+
llvm_enzyme,
15481550
llvm_tools,
15491551
deny_warnings,
15501552
backtrace_on_ice,
@@ -1634,6 +1636,8 @@ impl Config {
16341636
}
16351637

16361638
set(&mut config.llvm_tools_enabled, llvm_tools);
1639+
config.llvm_enzyme =
1640+
llvm_enzyme.unwrap_or(config.channel == "dev" || config.channel == "nightly");
16371641
config.rustc_parallel =
16381642
parallel_compiler.unwrap_or(config.channel == "dev" || config.channel == "nightly");
16391643
config.rustc_default_linker = default_linker;

src/bootstrap/src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,9 @@ const LLD_FILE_NAMES: &[&str] = &["ld.lld", "ld64.lld", "lld-link", "wasm-ld"];
7676
/// (Mode restriction, config name, config values (if any))
7777
const EXTRA_CHECK_CFGS: &[(Option<Mode>, &str, Option<&[&'static str]>)] = &[
7878
(None, "bootstrap", None),
79+
(Some(Mode::Rustc), "llvm_enzyme", None),
80+
(Some(Mode::Codegen), "llvm_enzyme", None),
81+
(Some(Mode::ToolRustc), "llvm_enzyme", None),
7982
(Some(Mode::Rustc), "parallel_compiler", None),
8083
(Some(Mode::ToolRustc), "parallel_compiler", None),
8184
(Some(Mode::ToolRustc), "rust_analyzer", None),

0 commit comments

Comments
 (0)