Skip to content

Commit 6dbafed

Browse files
Fix mypy failure
PiperOrigin-RevId: 704748889
1 parent 8813973 commit 6dbafed

File tree

1 file changed

+2
-2
lines changed
  • jax/_src/interpreters

1 file changed

+2
-2
lines changed

jax/_src/interpreters/ad.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,8 +275,8 @@ def write_primal(v, val):
275275
ct_out = core.freeze(ref)
276276
write_cotangent(eqn.primitive, val_var, ct_out)
277277
elif eqn.primitive is core.freeze_p:
278-
val_var, = eqn.outvars
279-
ref_var, = eqn.invars
278+
val_var, = eqn.outvars # type: ignore
279+
ref_var, = eqn.invars # type: ignore
280280
ct_in = instantiate_zeros(read_cotangent(val_var))
281281
write_primal(ref_var, core.mutable_array(ct_in))
282282
continue

0 commit comments

Comments
 (0)