diff --git a/jax/_src/pallas/mosaic_gpu/primitives.py b/jax/_src/pallas/mosaic_gpu/primitives.py index c4a567e688eb..6b6654402c30 100644 --- a/jax/_src/pallas/mosaic_gpu/primitives.py +++ b/jax/_src/pallas/mosaic_gpu/primitives.py @@ -61,6 +61,7 @@ Layout = gpu_core.Layout ParameterizedLayout = gpu_core.ParameterizedLayout SomeLayout = gpu_core.SomeLayout +ReducedLayout = gpu_core.ReducedLayout def _check_ref( @@ -2216,10 +2217,10 @@ def add_one(ctx, smem_ref): raise ValueError( "inline_mgpu_p only supports plgpu.ShapeDtypeStruct return types." ) - if not all(isinstance(r, (Layout, ParameterizedLayout, RefType)) for r in flat_arg_types): + if not all(isinstance(r, (Layout, ParameterizedLayout, ReducedLayout, RefType)) for r in flat_arg_types): raise ValueError( - "inline_mgpu_p only supports only Layout, ParameterizedLayout and" - " RefType arg types." + "inline_mgpu_p only supports only Layout, ParameterizedLayout," + " ReducedLayout and RefType arg types." ) def inner(f): @@ -2236,7 +2237,9 @@ def wrapper(*args): if isinstance(a, state_types.TransformedRef) and isinstance(t, RefType): raw_flat_args.append(a.ref) ref_transforms.append(a.transforms) - elif isinstance(aval := jax_core.get_aval(a), jax_core.ShapedArray) and isinstance(t, (ParameterizedLayout, Layout)): + elif isinstance( + aval := jax_core.get_aval(a), jax_core.ShapedArray + ) and isinstance(t, (ParameterizedLayout, Layout, ReducedLayout)): raw_flat_args.append(a) ref_transforms.append(None) elif isinstance(aval, state.AbstractRef) and isinstance(t, RefType): @@ -2313,7 +2316,11 @@ def _type_check_mgpu_lane_semantics(v, ty): raise ValueError( f"Array layout mismatch: expected {v.layout} got {ty.layout.to_mgpu()}." ) - case (Layout() , mgpu.FragmentedArray()) | (ParameterizedLayout(), mgpu.FragmentedArray()): + case ( + (Layout(), mgpu.FragmentedArray()) + | (ParameterizedLayout(), mgpu.FragmentedArray()) + | (ReducedLayout(), mgpu.FragmentedArray()) + ): if ty.to_mgpu() != v.layout: raise ValueError(f"Unexpected layout for {v} (expected: {ty})") case _: