Skip to content

Commit 46eb77b

Browse files
superbobryGoogle-ML-Automation
authored andcommitted
[pallas:mosaic_gpu] Use jax.tree_util.register_dataclass for transforms
PiperOrigin-RevId: 702733084
1 parent 12b45b3 commit 46eb77b

File tree

1 file changed

+5
-31
lines changed

1 file changed

+5
-31
lines changed

jax/_src/pallas/mosaic_gpu/core.py

Lines changed: 5 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,6 @@ def __str__(self) -> str:
104104
return self.value
105105

106106
def __call__(
107-
108107
self,
109108
shape: tuple[int, ...],
110109
dtype: jnp.dtype,
@@ -161,7 +160,6 @@ class TilingTransform(MemoryRefTransform):
161160
shape of (M // X, N // Y, X, Y). Ex. A (256, 256) block that is tiled with a
162161
tiling of (64, 32) will be tiled as (4, 8, 64, 32).
163162
"""
164-
165163
tiling: tuple[int, ...]
166164

167165
def undo(self, ref: pallas_core.TransformedRef) -> pallas_core.TransformedRef:
@@ -176,10 +174,10 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform:
176174
return mgpu.TileTransform(self.tiling)
177175

178176

179-
@tree_util.register_pytree_node_class
177+
@tree_util.register_dataclass
180178
@dataclasses.dataclass(frozen=True)
181179
class UntileRef(state_types.Transform):
182-
tiling: tuple[int, ...]
180+
tiling: tuple[int, ...] = dataclasses.field(metadata=dict(static=True))
183181

184182
def transform_shape(self, shape):
185183
if shape is None:
@@ -214,14 +212,6 @@ def untransform_index(
214212
def undo_to_gpu_transform(self) -> mgpu.MemRefTransform:
215213
return mgpu.TileTransform(self.tiling)
216214

217-
def tree_flatten(self):
218-
return (), (self.tiling,)
219-
220-
@classmethod
221-
def tree_unflatten(cls, metadata, arrays):
222-
assert not arrays
223-
return cls(*metadata)
224-
225215

226216
def _perm_inverse(permutation: tuple[int, ...]) -> tuple[int, ...]:
227217
inverse = [-1] * len(permutation)
@@ -257,7 +247,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform:
257247
return mgpu.TransposeTransform(self.permutation)
258248

259249

260-
@tree_util.register_pytree_node_class
250+
@tree_util.register_dataclass
261251
@dataclasses.dataclass(frozen=True)
262252
class TransposeRef(state_types.Transform):
263253
permutation: tuple[int, ...]
@@ -287,14 +277,6 @@ def untransform_index(
287277
def undo_to_gpu_transform(self) -> mgpu.MemRefTransform:
288278
return mgpu.TransposeTransform(_perm_inverse(self.permutation))
289279

290-
def tree_flatten(self):
291-
return (), (self.permutation,)
292-
293-
@classmethod
294-
def tree_unflatten(cls, metadata, arrays):
295-
assert not arrays
296-
return cls(*metadata)
297-
298280

299281
def transpose_ref(
300282
ref: pallas_core.TransformedRef | Any,
@@ -345,10 +327,10 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
345327
return aval
346328

347329

348-
@tree_util.register_pytree_node_class
330+
@tree_util.register_dataclass
349331
@dataclasses.dataclass(frozen=True)
350332
class UnswizzleRef(state_types.Transform):
351-
swizzle: int
333+
swizzle: int = dataclasses.field(metadata=dict(static=True))
352334

353335
def untransform_index(
354336
self, idxs: tuple[Index, ...]
@@ -369,14 +351,6 @@ def untransform_index(
369351
raise ValueError("Swizzled dims cannot be sliced")
370352
return idxs, self
371353

372-
def tree_flatten(self):
373-
return (), (self.swizzle,)
374-
375-
@classmethod
376-
def tree_unflatten(cls, metadata, arrays):
377-
assert not arrays
378-
return cls(*metadata)
379-
380354

381355
@dataclasses.dataclass
382356
class GPUBlockSpec(pallas_core.BlockSpec):

0 commit comments

Comments
 (0)