@@ -104,7 +104,6 @@ def __str__(self) -> str:
104104 return self .value
105105
106106 def __call__ (
107-
108107 self ,
109108 shape : tuple [int , ...],
110109 dtype : jnp .dtype ,
@@ -161,7 +160,6 @@ class TilingTransform(MemoryRefTransform):
161160 shape of (M // X, N // Y, X, Y). Ex. A (256, 256) block that is tiled with a
162161 tiling of (64, 32) will be tiled as (4, 8, 64, 32).
163162 """
164-
165163 tiling : tuple [int , ...]
166164
167165 def undo (self , ref : pallas_core .TransformedRef ) -> pallas_core .TransformedRef :
@@ -176,10 +174,10 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform:
176174 return mgpu .TileTransform (self .tiling )
177175
178176
179- @tree_util .register_pytree_node_class
177+ @tree_util .register_dataclass
180178@dataclasses .dataclass (frozen = True )
181179class UntileRef (state_types .Transform ):
182- tiling : tuple [int , ...]
180+ tiling : tuple [int , ...] = dataclasses . field ( metadata = dict ( static = True ))
183181
184182 def transform_shape (self , shape ):
185183 if shape is None :
@@ -214,14 +212,6 @@ def untransform_index(
214212 def undo_to_gpu_transform (self ) -> mgpu .MemRefTransform :
215213 return mgpu .TileTransform (self .tiling )
216214
217- def tree_flatten (self ):
218- return (), (self .tiling ,)
219-
220- @classmethod
221- def tree_unflatten (cls , metadata , arrays ):
222- assert not arrays
223- return cls (* metadata )
224-
225215
226216def _perm_inverse (permutation : tuple [int , ...]) -> tuple [int , ...]:
227217 inverse = [- 1 ] * len (permutation )
@@ -257,7 +247,7 @@ def to_gpu_transform(self) -> mgpu.MemRefTransform:
257247 return mgpu .TransposeTransform (self .permutation )
258248
259249
260- @tree_util .register_pytree_node_class
250+ @tree_util .register_dataclass
261251@dataclasses .dataclass (frozen = True )
262252class TransposeRef (state_types .Transform ):
263253 permutation : tuple [int , ...]
@@ -287,14 +277,6 @@ def untransform_index(
287277 def undo_to_gpu_transform (self ) -> mgpu .MemRefTransform :
288278 return mgpu .TransposeTransform (_perm_inverse (self .permutation ))
289279
290- def tree_flatten (self ):
291- return (), (self .permutation ,)
292-
293- @classmethod
294- def tree_unflatten (cls , metadata , arrays ):
295- assert not arrays
296- return cls (* metadata )
297-
298280
299281def transpose_ref (
300282 ref : pallas_core .TransformedRef | Any ,
@@ -345,10 +327,10 @@ def __call__(self, aval: jax_core.ShapedArray) -> jax_core.ShapedArray:
345327 return aval
346328
347329
348- @tree_util .register_pytree_node_class
330+ @tree_util .register_dataclass
349331@dataclasses .dataclass (frozen = True )
350332class UnswizzleRef (state_types .Transform ):
351- swizzle : int
333+ swizzle : int = dataclasses . field ( metadata = dict ( static = True ))
352334
353335 def untransform_index (
354336 self , idxs : tuple [Index , ...]
@@ -369,14 +351,6 @@ def untransform_index(
369351 raise ValueError ("Swizzled dims cannot be sliced" )
370352 return idxs , self
371353
372- def tree_flatten (self ):
373- return (), (self .swizzle ,)
374-
375- @classmethod
376- def tree_unflatten (cls , metadata , arrays ):
377- assert not arrays
378- return cls (* metadata )
379-
380354
381355@dataclasses .dataclass
382356class GPUBlockSpec (pallas_core .BlockSpec ):
0 commit comments