We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
self.accum_ref
1 parent 2287580 commit cb33094Copy full SHA for cb33094
jax/_src/pallas/mosaic/pipeline.py
@@ -1052,10 +1052,10 @@ def set_accumulator(self, init=False):
1052
if self.accum_ref is not None:
1053
accum_dtype = self.accum_ref.dtype
1054
def _init():
1055
- assert self.accum_ref # pyrefly#40
+ assert self.accum_ref is not None # pyrefly#40
1056
self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...])
1057
def _set():
1058
1059
self.accum_ref[...] = self.current_ref[...].astype(accum_dtype)
1060
lax.cond(init, _init, _set)
1061
0 commit comments