Skip to content

Commit 73317cb

Browse files
authored
Rollup merge of rust-lang#147390 - ZuseZ4:autodiff-dbg, r=jieyouxu
Use globals instead of metadata for std::autodiff LLVM's Metadata is quite fragile. In debug builds we use incremental compilation, which caused the metadata to be dropped. With this change we use named globals instead of metadata to instruct Enzyme how to differentiate functions. Globals are proper llvm values and thus can't be dropped. Also added an incremental/dbg test which now passes, to unblock the EnzymeAD CI which wants to run Rust autodiff tests. r? compiler
2 parents 2d09935 + 6ce845a commit 73317cb

File tree

4 files changed

+74
-81
lines changed

4 files changed

+74
-81
lines changed

compiler/rustc_codegen_llvm/src/builder/autodiff.rs

Lines changed: 28 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ use crate::builder::{Builder, PlaceRef, UNNAMED};
1313
use crate::context::SimpleCx;
1414
use crate::declare::declare_simple_fn;
1515
use crate::llvm;
16-
use crate::llvm::{Metadata, TRUE, Type};
16+
use crate::llvm::{TRUE, Type};
1717
use crate::value::Value;
1818

1919
pub(crate) fn adjust_activity_to_abi<'tcx>(
@@ -159,32 +159,36 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
159159
let mut outer_pos: usize = 0;
160160
let mut activity_pos = 0;
161161

162-
let enzyme_const = cx.create_metadata(b"enzyme_const");
163-
let enzyme_out = cx.create_metadata(b"enzyme_out");
164-
let enzyme_dup = cx.create_metadata(b"enzyme_dup");
165-
let enzyme_dupv = cx.create_metadata(b"enzyme_dupv");
166-
let enzyme_dupnoneed = cx.create_metadata(b"enzyme_dupnoneed");
167-
let enzyme_dupnoneedv = cx.create_metadata(b"enzyme_dupnoneedv");
162+
// We used to use llvm's metadata to instruct enzyme how to differentiate a function.
163+
// In debug mode we would use incremental compilation which caused the metadata to be
164+
// dropped. This is prevented by now using named globals, which are also understood
165+
// by Enzyme.
166+
let global_const = cx.declare_global("enzyme_const", cx.type_ptr());
167+
let global_out = cx.declare_global("enzyme_out", cx.type_ptr());
168+
let global_dup = cx.declare_global("enzyme_dup", cx.type_ptr());
169+
let global_dupv = cx.declare_global("enzyme_dupv", cx.type_ptr());
170+
let global_dupnoneed = cx.declare_global("enzyme_dupnoneed", cx.type_ptr());
171+
let global_dupnoneedv = cx.declare_global("enzyme_dupnoneedv", cx.type_ptr());
168172

169173
while activity_pos < inputs.len() {
170174
let diff_activity = inputs[activity_pos as usize];
171175
// Duplicated arguments received a shadow argument, into which enzyme will write the
172176
// gradient.
173-
let (activity, duplicated): (&Metadata, bool) = match diff_activity {
177+
let (activity, duplicated): (&llvm::Value, bool) = match diff_activity {
174178
DiffActivity::None => panic!("not a valid input activity"),
175-
DiffActivity::Const => (enzyme_const, false),
176-
DiffActivity::Active => (enzyme_out, false),
177-
DiffActivity::ActiveOnly => (enzyme_out, false),
178-
DiffActivity::Dual => (enzyme_dup, true),
179-
DiffActivity::Dualv => (enzyme_dupv, true),
180-
DiffActivity::DualOnly => (enzyme_dupnoneed, true),
181-
DiffActivity::DualvOnly => (enzyme_dupnoneedv, true),
182-
DiffActivity::Duplicated => (enzyme_dup, true),
183-
DiffActivity::DuplicatedOnly => (enzyme_dupnoneed, true),
184-
DiffActivity::FakeActivitySize(_) => (enzyme_const, false),
179+
DiffActivity::Const => (global_const, false),
180+
DiffActivity::Active => (global_out, false),
181+
DiffActivity::ActiveOnly => (global_out, false),
182+
DiffActivity::Dual => (global_dup, true),
183+
DiffActivity::Dualv => (global_dupv, true),
184+
DiffActivity::DualOnly => (global_dupnoneed, true),
185+
DiffActivity::DualvOnly => (global_dupnoneedv, true),
186+
DiffActivity::Duplicated => (global_dup, true),
187+
DiffActivity::DuplicatedOnly => (global_dupnoneed, true),
188+
DiffActivity::FakeActivitySize(_) => (global_const, false),
185189
};
186190
let outer_arg = outer_args[outer_pos];
187-
args.push(cx.get_metadata_value(activity));
191+
args.push(activity);
188192
if matches!(diff_activity, DiffActivity::Dualv) {
189193
let next_outer_arg = outer_args[outer_pos + 1];
190194
let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
@@ -244,7 +248,7 @@ fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
244248
assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
245249
args.push(next_outer_arg2);
246250
}
247-
args.push(cx.get_metadata_value(enzyme_const));
251+
args.push(global_const);
248252
args.push(next_outer_arg);
249253
outer_pos += 2 + 2 * iterations;
250254
activity_pos += 2;
@@ -353,13 +357,13 @@ pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
353357
let mut args = Vec::with_capacity(num_args as usize + 1);
354358
args.push(fn_to_diff);
355359

356-
let enzyme_primal_ret = cx.create_metadata(b"enzyme_primal_return");
360+
let global_primal_ret = cx.declare_global("enzyme_primal_return", cx.type_ptr());
357361
if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
358-
args.push(cx.get_metadata_value(enzyme_primal_ret));
362+
args.push(global_primal_ret);
359363
}
360364
if attrs.width > 1 {
361-
let enzyme_width = cx.create_metadata(b"enzyme_width");
362-
args.push(cx.get_metadata_value(enzyme_width));
365+
let global_width = cx.declare_global("enzyme_width", cx.type_ptr());
366+
args.push(global_width);
363367
args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
364368
}
365369

tests/ui/autodiff/autodiff_illegal.rs

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -110,15 +110,6 @@ fn f14(x: f32) -> Foo {
110110

111111
type MyFloat = f32;
112112

113-
// We would like to support type alias to f32/f64 in argument type in the future,
114-
// but that requires us to implement our checks at a later stage
115-
// like THIR which has type information available.
116-
#[autodiff_reverse(df15, Active, Active)]
117-
fn f15(x: MyFloat) -> f32 {
118-
//~^^ ERROR failed to resolve: use of undeclared type `MyFloat` [E0433]
119-
unimplemented!()
120-
}
121-
122113
// We would like to support type alias to f32/f64 in return type in the future
123114
#[autodiff_reverse(df16, Active, Active)]
124115
fn f16(x: f32) -> MyFloat {
@@ -136,13 +127,6 @@ fn f17(x: f64) -> F64Trans {
136127
unimplemented!()
137128
}
138129

139-
// We would like to support `#[repr(transparent)]` f32/f64 wrapper in argument type in the future
140-
#[autodiff_reverse(df18, Active, Active)]
141-
fn f18(x: F64Trans) -> f64 {
142-
//~^^ ERROR failed to resolve: use of undeclared type `F64Trans` [E0433]
143-
unimplemented!()
144-
}
145-
146130
// Invalid return activity
147131
#[autodiff_forward(df19, Dual, Active)]
148132
fn f19(x: f32) -> f32 {
@@ -163,11 +147,4 @@ fn f21(x: f32) -> f32 {
163147
unimplemented!()
164148
}
165149

166-
struct DoesNotImplDefault;
167-
#[autodiff_forward(df22, Dual)]
168-
pub fn f22() -> DoesNotImplDefault {
169-
//~^^ ERROR the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
170-
unimplemented!()
171-
}
172-
173150
fn main() {}

tests/ui/autodiff/autodiff_illegal.stderr

Lines changed: 5 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -107,53 +107,24 @@ LL | #[autodiff_reverse(df13, Reverse)]
107107
| ^^^^^^^
108108

109109
error: invalid return activity Active in Forward Mode
110-
--> $DIR/autodiff_illegal.rs:147:1
110+
--> $DIR/autodiff_illegal.rs:131:1
111111
|
112112
LL | #[autodiff_forward(df19, Dual, Active)]
113113
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
114114

115115
error: invalid return activity Dual in Reverse Mode
116-
--> $DIR/autodiff_illegal.rs:153:1
116+
--> $DIR/autodiff_illegal.rs:137:1
117117
|
118118
LL | #[autodiff_reverse(df20, Active, Dual)]
119119
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
120120

121121
error: invalid return activity Duplicated in Reverse Mode
122-
--> $DIR/autodiff_illegal.rs:160:1
122+
--> $DIR/autodiff_illegal.rs:144:1
123123
|
124124
LL | #[autodiff_reverse(df21, Active, Duplicated)]
125125
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
126126

127-
error[E0433]: failed to resolve: use of undeclared type `MyFloat`
128-
--> $DIR/autodiff_illegal.rs:116:1
129-
|
130-
LL | #[autodiff_reverse(df15, Active, Active)]
131-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `MyFloat`
132-
133-
error[E0433]: failed to resolve: use of undeclared type `F64Trans`
134-
--> $DIR/autodiff_illegal.rs:140:1
135-
|
136-
LL | #[autodiff_reverse(df18, Active, Active)]
137-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ use of undeclared type `F64Trans`
138-
139-
error[E0599]: the function or associated item `default` exists for tuple `(DoesNotImplDefault, DoesNotImplDefault)`, but its trait bounds were not satisfied
140-
--> $DIR/autodiff_illegal.rs:167:1
141-
|
142-
LL | struct DoesNotImplDefault;
143-
| ------------------------- doesn't satisfy `DoesNotImplDefault: Default`
144-
LL | #[autodiff_forward(df22, Dual)]
145-
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ function or associated item cannot be called on `(DoesNotImplDefault, DoesNotImplDefault)` due to unsatisfied trait bounds
146-
|
147-
= note: the following trait bounds were not satisfied:
148-
`DoesNotImplDefault: Default`
149-
which is required by `(DoesNotImplDefault, DoesNotImplDefault): Default`
150-
help: consider annotating `DoesNotImplDefault` with `#[derive(Default)]`
151-
|
152-
LL + #[derive(Default)]
153-
LL | struct DoesNotImplDefault;
154-
|
155-
156-
error: aborting due to 21 previous errors
127+
error: aborting due to 18 previous errors
157128

158-
Some errors have detailed explanations: E0428, E0433, E0599, E0658.
129+
Some errors have detailed explanations: E0428, E0658.
159130
For more information about an error, try `rustc --explain E0428`.

tests/ui/autodiff/incremental.rs

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
//@ revisions: DEBUG RELEASE
2+
//@[RELEASE] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=3 -Clto=fat
3+
//@[DEBUG] compile-flags: -Zautodiff=Enable,NoTT -C opt-level=0 -Clto=fat -C debuginfo=2
4+
//@ needs-enzyme
5+
//@ incremental
6+
//@ no-prefer-dynamic
7+
//@ build-pass
8+
#![crate_type = "bin"]
9+
#![feature(autodiff)]
10+
11+
// We used to use llvm's metadata to instruct enzyme how to differentiate a function.
12+
// In debug mode we would use incremental compilation which caused the metadata to be
13+
// dropped. We now use globals instead and add this test to verify that incremental
14+
// keeps working. Also testing debug mode while at it.
15+
16+
use std::autodiff::autodiff_reverse;
17+
18+
#[autodiff_reverse(bar, Duplicated, Duplicated)]
19+
pub fn foo(r: &[f64; 10], res: &mut f64) {
20+
let mut output = [0.0; 10];
21+
output[0] = r[0];
22+
output[1] = r[1] * r[2];
23+
output[2] = r[4] * r[5];
24+
output[3] = r[2] * r[6];
25+
output[4] = r[1] * r[7];
26+
output[5] = r[2] * r[8];
27+
output[6] = r[1] * r[9];
28+
output[7] = r[5] * r[6];
29+
output[8] = r[5] * r[7];
30+
output[9] = r[4] * r[8];
31+
*res = output.iter().sum();
32+
}
33+
fn main() {
34+
let inputs = Box::new([3.1; 10]);
35+
let mut d_inputs = Box::new([0.0; 10]);
36+
let mut res = Box::new(0.0);
37+
let mut d_res = Box::new(1.0);
38+
39+
bar(&inputs, &mut d_inputs, &mut res, &mut d_res);
40+
dbg!(&d_inputs);
41+
}

0 commit comments

Comments
 (0)