Skip to content

Commit bdc17f6

Browse files
DoeringChristiannjroussel
authored andcommitted
ad_coop_vec_pack: test attach with ad_grad_enabled
1 parent a5e9145 commit bdc17f6

File tree

1 file changed

+4
-7
lines changed

1 file changed

+4
-7
lines changed

src/extra/autodiff.cpp

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3964,7 +3964,7 @@ Index ad_coop_vec_pack(uint32_t n, const Index *in) {
39643964
for (uint32_t i = 0; i < n; ++i) {
39653965
Index index = in[i];
39663966
tmp[i] = jit_index(index);
3967-
attached |= ad_index(index) != 0;
3967+
attached |= ad_grad_enabled(index);
39683968
}
39693969

39703970
JitVar result = JitVar::steal(jit_coop_vec_pack(n, tmp));
@@ -3973,14 +3973,11 @@ Index ad_coop_vec_pack(uint32_t n, const Index *in) {
39733973
VarInfo vi = jit_set_backend(result.index());
39743974

39753975
ref<CoopVecPack> ps = new CoopVecPack();
3976-
for (size_t i = 0; i < n; ++i) {
3977-
if (ad_grad_enabled(in[i]))
3978-
ps->add_input(vi.backend, ad_index(in[i]));
3979-
}
3976+
for (size_t i = 0; i < n; ++i)
3977+
ps->add_input(vi.backend, ad_index(in[i]));
39803978

39813979
uint64_t ad_result = ad_var_new(result.index());
3982-
if (ad_grad_enabled(ad_result))
3983-
ps->add_output(vi.backend, ad_index(ad_result));
3980+
ps->add_output(vi.backend, ad_index(ad_result));
39843981

39853982
if (ad_custom_op(ps.get()))
39863983
return ad_result;

0 commit comments

Comments
 (0)