Skip to content

Commit 54fd738

Browse files
Add SMEM as a supported Pallas output memory space.
PiperOrigin-RevId: 712144883
1 parent 9af2970 commit 54fd738

File tree

2 files changed

+5
-0
lines changed

2 files changed

+5
-0
lines changed

jax/_src/pallas/mosaic/pallas_call_registration.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ def _get_memory_space_from_aval(
8484
return None
8585
case tpu_core.TPUMemorySpace.VMEM:
8686
return tpu_custom_call.MemorySpace.VMEM
87+
case tpu_core.TPUMemorySpace.SMEM:
88+
return tpu_custom_call.MemorySpace.SMEM
8789
case tpu_core.TPUMemorySpace.SEMAPHORE:
8890
return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
8991
return None

jax/_src/tpu_custom_call.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ class MemorySpace(enum.Enum):
8383
HBM = enum.auto()
8484
VMEM = enum.auto()
8585
SEMAPHORE_MEM = enum.auto()
86+
SMEM = enum.auto()
8687

8788
@property
8889
def color(self) -> int:
@@ -92,6 +93,8 @@ def color(self) -> int:
9293
return 1
9394
elif self == MemorySpace.SEMAPHORE_MEM:
9495
return 2
96+
elif self == MemorySpace.SMEM:
97+
return 4
9598
else:
9699
raise ValueError("invalid memory space: " + str(self))
97100

0 commit comments

Comments
 (0)