Skip to content

Commit 3045147

Browse files
apaszkeGoogle-ML-Automation
authored andcommitted
[Pallas][NFC] Move the remainder of Semaphore-related extended dtypes to Pallas core
This completes the move started in jax-ml#26673. PiperOrigin-RevId: 741487331
1 parent efa5ae8 commit 3045147

File tree

3 files changed

+54
-41
lines changed

3 files changed

+54
-41
lines changed

jax/_src/pallas/core.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,9 +67,55 @@ def __repr__(self):
6767
SEMAPHORE_INTERPRET_DTYPE = jnp.int16
6868
SEMAPHORE_MAX_VALUE = jnp.iinfo(SEMAPHORE_INTERPRET_DTYPE).max
6969

70-
class semaphore_dtype(dtypes.extended): pass
71-
class semaphore(semaphore_dtype): pass
72-
class barrier_semaphore(semaphore_dtype): pass
70+
class AbstractSemaphoreTyRules:
71+
@staticmethod
72+
def pallas_interpret_element_aval(_) -> jax_core.ShapedArray:
73+
return jax_core.ShapedArray((), SEMAPHORE_INTERPRET_DTYPE)
74+
75+
@staticmethod
76+
def physical_element_aval(_) -> jax_core.ShapedArray:
77+
return jax_core.ShapedArray((), jnp.int32)
78+
79+
# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy
80+
class AbstractSemaphoreTy(dtypes.ExtendedDType):
81+
name: str
82+
_rules = AbstractSemaphoreTyRules
83+
84+
def __repr__(self) -> str:
85+
return self.name
86+
87+
def __eq__(self, other):
88+
return self.__class__ == other.__class__
89+
90+
def __hash__(self) -> int:
91+
return hash(self.__class__)
92+
93+
class semaphore_dtype(dtypes.extended):
94+
"""Common dtype for all kinds of semaphore dtypes.
95+
96+
This is an abstract class that should never be instantiated, but rather
97+
exists for the sake of `jnp.issubdtype`.
98+
"""
99+
100+
class semaphore(semaphore_dtype):
101+
"""Regular semaphore dtype.
102+
103+
Like its superclass, this class should never be instantiated.
104+
"""
105+
106+
class Semaphore(AbstractSemaphoreTy):
107+
name = "semaphore"
108+
type = semaphore
109+
110+
class barrier_semaphore(semaphore_dtype):
111+
"""Barrier semaphore dtype.
112+
113+
Like its superclass, this class should never be instantiated.
114+
"""
115+
116+
class BarrierSemaphore(AbstractSemaphoreTy):
117+
name = "barrier_semaphore"
118+
type = barrier_semaphore
73119

74120
@runtime_checkable
75121
class CompilerParams(Protocol):

jax/_src/pallas/mosaic/core.py

Lines changed: 4 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import jax
2626
from jax._src import config
2727
from jax._src import core as jax_core
28-
from jax._src import dtypes
2928
from jax._src import util
3029
from jax._src.pallas import core as pallas_core
3130
import jax.numpy as jnp
@@ -114,42 +113,10 @@ def __call__(self, shape: tuple[int, ...], dtype: jnp.dtype):
114113

115114
class dma_semaphore(pallas_core.semaphore_dtype): pass
116115

117-
class AbstractSemaphoreTyRules:
118-
@staticmethod
119-
def pallas_interpret_element_aval(_) -> jax_core.ShapedArray:
120-
return jax_core.ShapedArray((), pallas_core.SEMAPHORE_INTERPRET_DTYPE)
121-
122-
@staticmethod
123-
def physical_element_aval(_) -> jax_core.ShapedArray:
124-
return jax_core.ShapedArray((), jnp.int32)
125-
126-
class AbstractSemaphoreTy(dtypes.ExtendedDType):
127-
name: str
128-
_rules = AbstractSemaphoreTyRules
129-
130-
def __repr__(self) -> str:
131-
return self.name
132-
133-
def __eq__(self, other):
134-
return self.__class__ == other.__class__
135-
136-
def __hash__(self) -> int:
137-
return hash(self.__class__)
138-
139-
# TODO(sharadmv): implement dtype rules for AbstractSemaphoreTy
140-
141-
class SemaphoreTy(AbstractSemaphoreTy):
142-
type = pallas_core.semaphore
143-
name = "sem"
144-
145-
class DmaSemaphoreTy(AbstractSemaphoreTy):
116+
class DMASemaphore(pallas_core.AbstractSemaphoreTy):
146117
type = dma_semaphore
147118
name = "dma_sem"
148119

149-
class BarrierSemaphoreTy(AbstractSemaphoreTy):
150-
type = pallas_core.barrier_semaphore
151-
name = "barrier_sem"
152-
153120
class SemaphoreType(enum.Enum):
154121
REGULAR = "regular"
155122
DMA = "dma"
@@ -158,11 +125,11 @@ class SemaphoreType(enum.Enum):
158125
def __call__(self, shape: tuple[int, ...]):
159126
dtype: Any
160127
if self == SemaphoreType.DMA:
161-
dtype = DmaSemaphoreTy()
128+
dtype = DMASemaphore()
162129
elif self == SemaphoreType.BARRIER:
163-
dtype = BarrierSemaphoreTy()
130+
dtype = pallas_core.BarrierSemaphore()
164131
else:
165-
dtype = SemaphoreTy()
132+
dtype = pallas_core.Semaphore()
166133
return pallas_core.MemoryRef(shape, dtype, TPUMemorySpace.SEMAPHORE)
167134

168135
def get_array_aval(self) -> pallas_core.ShapedArrayWithMemorySpace:

jax/_src/pallas/mosaic/primitives.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -623,7 +623,7 @@ def async_remote_copy(src_ref, dst_ref, send_sem, recv_sem, device_id,
623623
@get_barrier_semaphore_p.def_abstract_eval
624624
def _get_barrier_semaphore_abstract_eval():
625625
return pl_core.AbstractMemoryRef(
626-
jax_core.ShapedArray((), tpu_core.BarrierSemaphoreTy()),
626+
jax_core.ShapedArray((), pl_core.BarrierSemaphore()),
627627
tpu_core.TPUMemorySpace.SEMAPHORE,
628628
)
629629

0 commit comments

Comments
 (0)