File tree Expand file tree Collapse file tree 1 file changed +4
-7
lines changed
Expand file tree Collapse file tree 1 file changed +4
-7
lines changed Original file line number Diff line number Diff line change @@ -3964,7 +3964,7 @@ Index ad_coop_vec_pack(uint32_t n, const Index *in) {
39643964 for (uint32_t i = 0 ; i < n; ++i) {
39653965 Index index = in[i];
39663966 tmp[i] = jit_index (index);
3967- attached |= ad_index (index) != 0 ;
3967+ attached |= ad_grad_enabled (index);
39683968 }
39693969
39703970 JitVar result = JitVar::steal (jit_coop_vec_pack (n, tmp));
@@ -3973,14 +3973,11 @@ Index ad_coop_vec_pack(uint32_t n, const Index *in) {
39733973 VarInfo vi = jit_set_backend (result.index ());
39743974
39753975 ref<CoopVecPack> ps = new CoopVecPack ();
3976- for (size_t i = 0 ; i < n; ++i) {
3977- if (ad_grad_enabled (in[i]))
3978- ps->add_input (vi.backend , ad_index (in[i]));
3979- }
3976+ for (size_t i = 0 ; i < n; ++i)
3977+ ps->add_input (vi.backend , ad_index (in[i]));
39803978
39813979 uint64_t ad_result = ad_var_new (result.index ());
3982- if (ad_grad_enabled (ad_result))
3983- ps->add_output (vi.backend , ad_index (ad_result));
3980+ ps->add_output (vi.backend , ad_index (ad_result));
39843981
39853982 if (ad_custom_op (ps.get ()))
39863983 return ad_result;
You can’t perform that action at this time.
0 commit comments