@@ -164,10 +164,10 @@ fn generate_enzyme_call<'ll>(
164
164
let mut activity_pos = 0;
165
165
let outer_args: Vec<&llvm::Value> = get_params(outer_fn);
166
166
while activity_pos < inputs.len() {
167
- let activity = inputs[activity_pos as usize];
167
+ let diff_activity = inputs[activity_pos as usize];
168
168
// Duplicated arguments received a shadow argument, into which enzyme will write the
169
169
// gradient.
170
- let (activity, duplicated): (&Metadata, bool) = match activity {
170
+ let (activity, duplicated): (&Metadata, bool) = match diff_activity {
171
171
DiffActivity::None => panic!("not a valid input activity"),
172
172
DiffActivity::Const => (enzyme_const, false),
173
173
DiffActivity::Active => (enzyme_out, false),
@@ -222,7 +222,12 @@ fn generate_enzyme_call<'ll>(
222
222
// A duplicated pointer will have the following two outer_fn arguments:
223
223
// (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
224
224
// (..., metadata! enzyme_dup, ptr, ptr, ...).
225
- assert!(llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer);
225
+ if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly) {
226
+ assert!(
227
+ llvm::LLVMRustGetTypeKind(next_outer_ty) == llvm::TypeKind::Pointer
228
+ );
229
+ }
230
+ // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
226
231
args.push(next_outer_arg);
227
232
outer_pos += 2;
228
233
activity_pos += 1;
0 commit comments