@@ -519,6 +519,45 @@ def scan_count(
519519 return scan_count_p .bind (x , lax .full (x .shape , True ) if mask is None else mask )
520520
521521
522+ masked_cummax_p = jax_core .Primitive ("masked_cummax" )
523+ masked_cummax_p .multiple_results = False
524+
525+ @masked_cummax_p .def_abstract_eval
526+ def _masked_cummax_abstract_eval (x , mask ):
527+ if x .dtype != jnp .int32 and x .dtype != jnp .float32 :
528+ raise NotImplementedError (f"x.dtype={ x .dtype } must be int32 or float32" )
529+ if not jnp .issubdtype (mask .dtype , jnp .bool ):
530+ raise TypeError (f"mask.dtype={ mask .dtype } is not a boolean dtype" )
531+ if x .shape != mask .shape :
532+ raise ValueError (f"x.shape={ x .shape } != mask.shape={ mask .shape } " )
533+ return x
534+
535+ @sc_lowering .register_lowering_rule (masked_cummax_p )
536+ def _masked_cummax_lowering_rule (ctx : sc_lowering .LoweringRuleContext , x , mask ):
537+ del ctx # Unused.
538+ return tpu .scan (
539+ x .type , x , ir .Attribute .parse ("#tpu.reduction_kind<max>" ), mask = mask )
540+
541+ def cummax (x : jax .Array , * , mask : jax .Array | None = None ) -> jax .Array :
542+ """Returns the cumulative max of the array along its innermost axis.
543+
544+ Elements from `x` will pass through directly to the result until the first
545+ valid value is encountered (`mask[i] == True`). If you would like to specify
546+ a default value for such elements instead, write
547+ `x = jnp.where(mask, x, default_value)` before or after calling this function.
548+
549+ Args:
550+ x: An array of integers or floats.
551+ mask: An optional array of booleans, which specifies which elements of `x`
552+ are eligible for the max. If `None`, all elements are eligible.
553+ """
554+ if x .ndim != 1 :
555+ raise NotImplementedError (f"masked_cummax: x={ x .aval } must be rank 1" )
556+ if mask is None :
557+ mask = lax .full (x .shape , True )
558+ return masked_cummax_p .bind (x , mask )
559+
560+
522561masked_cumsum_p = jax_core .Primitive ("masked_cumsum" )
523562masked_cumsum_p .multiple_results = False
524563
@@ -553,18 +592,20 @@ def _lax_cumsum_lowering_rule(ctx: sc_lowering.LoweringRuleContext, x, axis,
553592 return tpu .scan (
554593 x .type , x , ir .Attribute .parse ("#tpu.reduction_kind<sum>" ), mask = c1v )
555594
556- def masked_cumsum (x : jax .Array , mask : jax .Array ) -> jax .Array :
595+ def cumsum (x : jax .Array , * , mask : jax .Array | None = None ) -> jax .Array :
557596 """Returns the cumulative sum of the array along its innermost axis.
558597
559598 This differs from `jnp.cumsum` in that it takes an additional `mask` argument.
560599
561600 Args:
562601 x: An array of integers or floats.
563- mask: An optional array of booleans, which specifies which elements ``x` `
564- are eligible for summing. If `` None` `, all elements are eligible.
602+ mask: An optional array of booleans, which specifies which elements of `x `
603+ are eligible for summing. If `None`, all elements are eligible.
565604 """
566605 if x .ndim != 1 :
567- raise NotImplementedError (f"masked_cumsum: x={ x .aval } must be rank 1" )
606+ raise NotImplementedError (f"cumsum: x={ x .aval } must be rank 1" )
607+ if mask is None :
608+ mask = lax .full (x .shape , True )
568609 return masked_cumsum_p .bind (x , mask )
569610
570611
0 commit comments