Skip to content

Commit 3a9b471

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas:SC] Fix the hypothesis OOM condition to avoid v6e OOMs
PiperOrigin-RevId: 835230881
1 parent 9808285 commit 3a9b471

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,8 @@ def test_block_spec_untiled_slicing(self, data):
323323
slice_shape[-1] *= 128
324324
else:
325325
slice_shape[-1] *= 8
326-
hp.assume(math.prod(slice_shape) <= 25000) # Avoid OOMs.
326+
max_elems = 12000 if jtu.is_device_tpu(6, "e") else 25000
327+
hp.assume(math.prod(slice_shape) <= max_elems) # Avoid OOMs.
327328
rank = len(slice_shape)
328329
offsets = data.draw(
329330
hps.lists(hps.integers(0, 4), min_size=rank, max_size=rank)

0 commit comments

Comments
 (0)