Skip to content

Commit 79045d5

Browse files
brianwa84 and Google-ML-Automation
authored and committed
Adds a plsc.cummax primitive to compute the cumulative maximum along the (masked) elements of a vector.
Also renames `plsc.masked_cumsum` to `plsc.cumsum`, since the mask is optional. PiperOrigin-RevId: 835144185
1 parent 2fbc76d commit 79045d5

File tree

3 files changed

+65
-6
lines changed

3 files changed

+65
-6
lines changed

jax/_src/pallas/mosaic/sc_primitives.py

Lines changed: 45 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,45 @@ def scan_count(
519519
return scan_count_p.bind(x, lax.full(x.shape, True) if mask is None else mask)
520520

521521

522+
# Primitive computing the cumulative max over the masked elements of a vector.
masked_cummax_p = jax_core.Primitive("masked_cummax")
masked_cummax_p.multiple_results = False


@masked_cummax_p.def_abstract_eval
def _masked_cummax_abstract_eval(x, mask):
  """Shape/dtype rule for `masked_cummax_p`.

  Validates the operands and returns `x`'s aval: the output has the same
  shape and dtype as the input vector.
  """
  # Only the dtypes the SparseCore scan supports.
  if x.dtype not in (jnp.int32, jnp.float32):
    raise NotImplementedError(f"x.dtype={x.dtype} must be int32 or float32")
  if not jnp.issubdtype(mask.dtype, jnp.bool):
    raise TypeError(f"mask.dtype={mask.dtype} is not a boolean dtype")
  if x.shape != mask.shape:
    raise ValueError(f"x.shape={x.shape} != mask.shape={mask.shape}")
  return x
534+
535+
@sc_lowering.register_lowering_rule(masked_cummax_p)
def _masked_cummax_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, mask):
  """Lowers `masked_cummax_p` to a Mosaic `tpu.scan` with a max reduction."""
  del ctx  # Unused.
  max_kind = ir.Attribute.parse("#tpu.reduction_kind<max>")
  return tpu.scan(x.type, x, max_kind, mask=mask)
540+
541+
def cummax(x: jax.Array, *, mask: jax.Array | None = None) -> jax.Array:
542+
"""Returns the cumulative max of the array along its innermost axis.
543+
544+
Elements from `x` will pass through directly to the result until the first
545+
valid value is encountered (`mask[i] == True`). If you would like to specify
546+
a default value for such elements instead, write
547+
`x = jnp.where(mask, x, default_value)` before or after calling this function.
548+
549+
Args:
550+
x: An array of integers or floats.
551+
mask: An optional array of booleans, which specifies which elements of `x`
552+
are eligible for the max. If `None`, all elements are eligible.
553+
"""
554+
if x.ndim != 1:
555+
raise NotImplementedError(f"masked_cummax: x={x.aval} must be rank 1")
556+
if mask is None:
557+
mask = lax.full(x.shape, True)
558+
return masked_cummax_p.bind(x, mask)
559+
560+
522561
# Primitive computing the cumulative sum over the masked elements of a vector.
masked_cumsum_p = jax_core.Primitive("masked_cumsum")
masked_cumsum_p.multiple_results = False
524563

@@ -553,18 +592,20 @@ def _lax_cumsum_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, axis,
553592
return tpu.scan(
554593
x.type, x, ir.Attribute.parse("#tpu.reduction_kind<sum>"), mask=c1v)
555594

556-
def masked_cumsum(x: jax.Array, mask: jax.Array) -> jax.Array:
595+
def cumsum(x: jax.Array, *, mask: jax.Array | None = None) -> jax.Array:
557596
"""Returns the cumulative sum of the array along its innermost axis.
558597
559598
This differs from `jnp.cumsum` in that it takes an additional `mask` argument.
560599
561600
Args:
562601
x: An array of integers or floats.
563-
mask: An optional array of booleans, which specifies which elements ``x``
564-
are eligible for summing. If ``None``, all elements are eligible.
602+
mask: An optional array of booleans, which specifies which elements of `x`
603+
are eligible for summing. If `None`, all elements are eligible.
565604
"""
566605
if x.ndim != 1:
567-
raise NotImplementedError(f"masked_cumsum: x={x.aval} must be rank 1")
606+
raise NotImplementedError(f"cumsum: x={x.aval} must be rank 1")
607+
if mask is None:
608+
mask = lax.full(x.shape, True)
568609
return masked_cumsum_p.bind(x, mask)
569610

570611

jax/experimental/pallas/tpu_sc.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@
2424
from jax._src.pallas.mosaic.sc_primitives import all_reduce_ffs as all_reduce_ffs
2525
from jax._src.pallas.mosaic.sc_primitives import all_reduce_population_count as all_reduce_population_count
2626
from jax._src.pallas.mosaic.sc_primitives import bitcast as bitcast
27+
from jax._src.pallas.mosaic.sc_primitives import cummax as cummax
28+
from jax._src.pallas.mosaic.sc_primitives import cumsum as cumsum
2729
from jax._src.pallas.mosaic.sc_primitives import load_expanded as load_expanded
2830
from jax._src.pallas.mosaic.sc_primitives import load_gather as load_gather
29-
from jax._src.pallas.mosaic.sc_primitives import masked_cumsum as masked_cumsum
3031
from jax._src.pallas.mosaic.sc_primitives import pack as pack
3132
from jax._src.pallas.mosaic.sc_primitives import PackFormat as PackFormat
3233
from jax._src.pallas.mosaic.sc_primitives import parallel_loop as parallel_loop

tests/pallas/tpu_sparsecore_pallas_test.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1219,10 +1219,27 @@ def test_masked_cumsum(self, dtype):
12191219

12201220
@self.vector_subcore_kernel(out_shape=x)
12211221
def kernel(x_ref, o_ref):
1222-
o_ref[...] = plsc.masked_cumsum(x_ref[...], mask=(x_ref[...] % 2) == 1)
1222+
o_ref[...] = plsc.cumsum(x_ref[...], mask=(x_ref[...] % 2) == 1)
12231223

12241224
np.testing.assert_array_equal(kernel(x), np.cumsum(x * (x % 2)))
12251225

1226+
@parameterized.product(dtype=[jnp.int32, jnp.float32])
1227+
def test_masked_cummax(self, dtype):
1228+
x = np.arange(self.sc_info.num_lanes, dtype=dtype)
1229+
np.random.shuffle(x)
1230+
1231+
@self.vector_subcore_kernel(out_shape=x)
1232+
def kernel(x_ref, o_ref):
1233+
o_ref[...] = plsc.cummax(x_ref[...], mask=(x_ref[...] % 2) == 1)
1234+
1235+
row = np.arange(self.sc_info.num_lanes)[:, np.newaxis]
1236+
col = np.arange(self.sc_info.num_lanes)[np.newaxis, :]
1237+
mask = x % 2
1238+
expected = (x * mask * (col <= row)).max(axis=1)
1239+
has_valid_value_so_far = np.cumsum(mask) > 0
1240+
expected = np.where(has_valid_value_so_far, expected, x)
1241+
np.testing.assert_array_equal(kernel(x), expected)
1242+
12261243
def test_parallel_loop_with_carry(self):
12271244
chunk_size = self.sc_info.num_lanes
12281245
nchunks = 4

0 commit comments

Comments
 (0)