Skip to content

Commit cb33094

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
Fixed self.accum_ref assertion in mosaic/pipeline.py
PiperOrigin-RevId: 874097569
1 parent 2287580 commit cb33094

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

jax/_src/pallas/mosaic/pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,10 +1052,10 @@ def set_accumulator(self, init=False):
10521052
if self.accum_ref is not None:
10531053
accum_dtype = self.accum_ref.dtype
10541054
def _init():
1055-
assert self.accum_ref # pyrefly#40
1055+
assert self.accum_ref is not None # pyrefly#40
10561056
self.accum_ref[...] = jnp.zeros_like(self.accum_ref[...])
10571057
def _set():
1058-
assert self.accum_ref # pyrefly#40
1058+
assert self.accum_ref is not None # pyrefly#40
10591059
self.accum_ref[...] = self.current_ref[...].astype(accum_dtype)
10601060
lax.cond(init, _init, _set)
10611061

0 commit comments

Comments
 (0)