We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 8813973 commit 6dbafedCopy full SHA for 6dbafed
jax/_src/interpreters/ad.py
@@ -275,8 +275,8 @@ def write_primal(v, val):
275
ct_out = core.freeze(ref)
276
write_cotangent(eqn.primitive, val_var, ct_out)
277
elif eqn.primitive is core.freeze_p:
278
- val_var, = eqn.outvars
279
- ref_var, = eqn.invars
+ val_var, = eqn.outvars # type: ignore
+ ref_var, = eqn.invars # type: ignore
280
ct_in = instantiate_zeros(read_cotangent(val_var))
281
write_primal(ref_var, core.mutable_array(ct_in))
282
continue
0 commit comments