Skip to content

Commit a626c66

Browse files
yashk2810Flax Authors
authored andcommitted
Rename SDS(vma: frozenset = ...) to SDS(manual_type: jax.sharding.ManualAxisType = ...)
PiperOrigin-RevId: 885309371
1 parent 93c430d commit a626c66

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

flax/core/axes_scan.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,8 @@ def build_shaped_array(x, batch_dim: bool = False) -> core.ShapedArray:
5252
shape=shape,
5353
dtype=jnp.result_type(x),
5454
sharding=sharding,
55-
**{k: getattr(x, k) for k in ["weak_type", "vma"] if hasattr(x, k)},
55+
**{k: getattr(x, k) for k in ["weak_type", "manual_type"]
56+
if hasattr(x, k)},
5657
)
5758

5859

0 commit comments

Comments
 (0)