Skip to content

Commit 8e4c5d3

Browse files
yashk2810recml authors
authored andcommitted
Use jax_layout.Format in place of jax_layout.Layout
JAX is undergoing a rename of the contents of `jax.experimental.layouts` in preparation for its graduation from experimental: 1. "Formats" are replacing "layouts", and specifically `Layout` -> `Format` 2. "Layouts" are replacing "device local layouts", and specifically `DeviceLocalLayout` -> `Layout` This is an incremental update carrying out #1. PiperOrigin-RevId: 772114015
1 parent cbf9a8f commit 8e4c5d3

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

recml/layers/linen/sparsecore.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ class SparsecoreLayout(nn.Partitioned[A]):
381381

382382
def get_sharding(self, _):
383383
assert self.mesh is not None
384-
return layout.Layout(
384+
return layout.Format(
385385
layout.DeviceLocalLayout(major_to_minor=(0, 1), _tiling=((8,),)),
386386
jax.sharding.NamedSharding(self.mesh, self.get_partition_spec()),
387387
)

0 commit comments

Comments
 (0)