Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions jax/_src/pallas/mosaic_gpu/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@
Layout = gpu_core.Layout
ParameterizedLayout = gpu_core.ParameterizedLayout
SomeLayout = gpu_core.SomeLayout
ReducedLayout = gpu_core.ReducedLayout


def _check_ref(
Expand Down Expand Up @@ -2216,10 +2217,10 @@ def add_one(ctx, smem_ref):
raise ValueError(
"inline_mgpu_p only supports plgpu.ShapeDtypeStruct return types."
)
if not all(isinstance(r, (Layout, ParameterizedLayout, RefType)) for r in flat_arg_types):
if not all(isinstance(r, (Layout, ParameterizedLayout, ReducedLayout, RefType)) for r in flat_arg_types):
raise ValueError(
"inline_mgpu_p only supports only Layout, ParameterizedLayout and"
" RefType arg types."
"inline_mgpu_p only supports only Layout, ParameterizedLayout,"
" ReducedLayout and RefType arg types."
)

def inner(f):
Expand All @@ -2236,7 +2237,9 @@ def wrapper(*args):
if isinstance(a, state_types.TransformedRef) and isinstance(t, RefType):
raw_flat_args.append(a.ref)
ref_transforms.append(a.transforms)
elif isinstance(aval := jax_core.get_aval(a), jax_core.ShapedArray) and isinstance(t, (ParameterizedLayout, Layout)):
elif isinstance(
aval := jax_core.get_aval(a), jax_core.ShapedArray
) and isinstance(t, (ParameterizedLayout, Layout, ReducedLayout)):
raw_flat_args.append(a)
ref_transforms.append(None)
elif isinstance(aval, state.AbstractRef) and isinstance(t, RefType):
Expand Down Expand Up @@ -2313,7 +2316,11 @@ def _type_check_mgpu_lane_semantics(v, ty):
raise ValueError(
f"Array layout mismatch: expected {v.layout} got {ty.layout.to_mgpu()}."
)
case (Layout() , mgpu.FragmentedArray()) | (ParameterizedLayout(), mgpu.FragmentedArray()):
case (
(Layout(), mgpu.FragmentedArray())
| (ParameterizedLayout(), mgpu.FragmentedArray())
| (ReducedLayout(), mgpu.FragmentedArray())
):
if ty.to_mgpu() != v.layout:
raise ValueError(f"Unexpected layout for {v} (expected: {ty})")
case _:
Expand Down
Loading