Skip to content

Commit a138d9f

Browse files
Cristian Garcia authored and Flax Authors committed
fix transform_metadata
PiperOrigin-RevId: 887307013
1 parent 1572c84 commit a138d9f

File tree

4 files changed

+100
-16
lines changed

4 files changed

+100
-16
lines changed

flax/nnx/extract.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -517,7 +517,14 @@ def updates_and_snapshot(args: A) -> tuple[A, A]:
517517
for leaf in leaves:
518518
if isinstance(leaf, variablelib.Variable):
519519
updates_leaves.append(leaf)
520-
snapshot_leaves.append(leaf.copy())
520+
# don't snapshot hijax or ref Variables as their updates are automatically
521+
# masked out in mask_variable_updates. However, the leaf is kept in the
522+
# updates to check for aliasing. This avoids a copy operation which has
523+
# significance for ref Variables.
524+
if leaf.hijax or leaf.ref:
525+
snapshot_leaves.append(Mask())
526+
else:
527+
snapshot_leaves.append(leaf.copy())
521528
else:
522529
updates_leaves.append(Mask())
523530
snapshot_leaves.append(Mask())
@@ -597,6 +604,10 @@ def mask_variable_updates(
597604
keep_fn = lambda _, _pfx, cur, snap: variable_changed(cur, snap)
598605

599606
def _mask_updates(path, prefix_leaf, current, snapshot):
607+
if current is None:
608+
# None leaves should remain None, they only appear here because
609+
# is_leaf catches None values for the prefix
610+
return None
600611
if isinstance(current, variablelib.Variable):
601612
if current.hijax or current.ref:
602613
return Mask()
@@ -610,7 +621,7 @@ def _mask_updates(path, prefix_leaf, current, snapshot):
610621
current_tree, snapshot_tree, is_leaf=is_leaf,
611622
)
612623
return broadcast_prefix_map(
613-
_mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf,
624+
_mask_updates, prefix, current_tree, snapshot_tree, is_leaf=is_leaf
614625
)
615626

616627

flax/nnx/transforms/iteration.py

Lines changed: 7 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -69,21 +69,19 @@ def _apply_axis_fn(
6969
axis_fn: tp.Callable[..., tp.Any],
7070
) -> None:
7171
is_leaf = lambda x: x is None or isinstance(x, variablelib.Variable)
72-
_, per_leaf_axes = extract.broadcast_prefix2(axes, tree, is_leaf=is_leaf)
73-
leaves = jax.tree_util.tree_leaves(tree, is_leaf=is_leaf)
74-
for leaf, axis in zip(leaves, per_leaf_axes):
75-
if (axis is None or isinstance(axis, int)) and isinstance(
76-
leaf, variablelib.Variable
77-
):
72+
def apply_fn(path, axis, leaf):
73+
if isinstance(axis, int) and isinstance(leaf, variablelib.Variable):
7874
axis_fn(leaf, axis, metadata)
7975

76+
extract.broadcast_prefix_map(apply_fn, axes, tree, is_leaf=is_leaf)
77+
8078

8179
@tp.overload
8280
def transform_metadata(
8381
*,
8482
in_axes: tp.Any = 0,
8583
out_axes: tp.Any = 0,
86-
partition: str,
84+
partition: str | None,
8785
graph: bool | None = None,
8886
) -> tp.Callable[[F], F]:
8987
...
@@ -96,7 +94,7 @@ def transform_metadata(
9694
in_axes: tp.Any = 0,
9795
out_axes: tp.Any = 0,
9896
graph: bool | None = None,
99-
partition: str,
97+
partition: str | None,
10098
) -> F:
10199
...
102100

@@ -106,8 +104,8 @@ def transform_metadata(
106104
*,
107105
in_axes: tp.Any = 0,
108106
out_axes: tp.Any = 0,
107+
partition: str | None,
109108
graph: bool | None = None,
110-
partition: str,
111109
) -> F | tp.Callable[[F], F]:
112110
if f is Missing:
113111
return functools.partial(

tests/nnx/spmd_test.py

Lines changed: 24 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -219,6 +219,30 @@ def f(v):
219219
self.assertEqual(v2[...], 10)
220220

221221

222+
def test_transform_metadata_decorator_none_partition(self):
223+
v = nnx.Param(
224+
jnp.array(1),
225+
out_sharding=(None, 'dout'),
226+
eager_sharding=False,
227+
)
228+
229+
@nnx.transform_metadata(in_axes=0, out_axes=1, partition=None)
230+
def f(v):
231+
v[...] += 1
232+
self.assertEqual(v.out_sharding, ('dout',))
233+
v2 = nnx.Param(
234+
jnp.array(10),
235+
out_sharding=('dmid', 'dout'),
236+
eager_sharding=False,
237+
)
238+
return v2
239+
240+
v2 = f(v)
241+
self.assertEqual(v.out_sharding, (None, 'dout'))
242+
self.assertEqual(v[...], 2)
243+
self.assertEqual(v2.out_sharding, ('dmid', None, 'dout'))
244+
self.assertEqual(v2[...], 10)
245+
222246
@parameterized.product(use_eager_sharding=[True, False])
223247
def test_eager_sharding_context(self, use_eager_sharding):
224248
rngs = nnx.Rngs(0)

tests/nnx/transforms_test.py

Lines changed: 56 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -5537,12 +5537,9 @@ def dict_scan(carry, x):
55375537

55385538
def test_no_carry_all_scanned(self):
55395539
def double(x):
5540-
return (x * 2,)
5540+
return x * 2
55415541

5542-
(ys,) = pure_jax_fancy_scan(
5543-
double, jnp.arange(5.0),
5544-
in_axes=(0,), out_axes=(0,),
5545-
)
5542+
ys = pure_jax_fancy_scan(double, jnp.arange(5.0), in_axes=0, out_axes=0)
55465543
np.testing.assert_allclose(ys, jnp.arange(5.0) * 2)
55475544

55485545
def test_reverse(self):
@@ -5651,6 +5648,60 @@ def f(x):
56515648
)
56525649
np.testing.assert_allclose(ys, jnp.arange(5.0) * 2)
56535650

5651+
def test_scan_axis_1(self):
5652+
def cumsum(carry, x):
5653+
carry = carry + x
5654+
return carry, carry
5655+
5656+
x = jnp.arange(10.0).reshape((2, 5))
5657+
final_carry, ys = pure_jax_fancy_scan(
5658+
cumsum, jnp.zeros(2), x,
5659+
in_axes=(nnx.Carry, 1), out_axes=(nnx.Carry, 1),
5660+
)
5661+
np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0]))
5662+
expected_ys = jnp.array([
5663+
[0., 1., 3., 6., 10.],
5664+
[5., 11., 18., 26., 35.]
5665+
])
5666+
np.testing.assert_allclose(ys, expected_ys)
5667+
5668+
def test_scan_axis_negative_1(self):
5669+
def cumsum(carry, x):
5670+
carry = carry + x
5671+
return carry, carry
5672+
5673+
x = jnp.arange(10.0).reshape((2, 5))
5674+
final_carry, ys = pure_jax_fancy_scan(
5675+
cumsum, jnp.zeros(2), x,
5676+
in_axes=(nnx.Carry, -1), out_axes=(nnx.Carry, -1),
5677+
)
5678+
np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0]))
5679+
expected_ys = jnp.array([
5680+
[0., 1., 3., 6., 10.],
5681+
[5., 11., 18., 26., 35.]
5682+
])
5683+
np.testing.assert_allclose(ys, expected_ys)
5684+
5685+
def test_scan_different_in_out_axes(self):
5686+
def cumsum(carry, x):
5687+
carry = carry + x
5688+
return carry, carry
5689+
5690+
x = jnp.arange(10.0).reshape((2, 5))
5691+
final_carry, ys = pure_jax_fancy_scan(
5692+
cumsum, jnp.zeros(2), x,
5693+
in_axes=(nnx.Carry, 1), out_axes=(nnx.Carry, 0),
5694+
)
5695+
np.testing.assert_allclose(final_carry, jnp.array([10.0, 35.0]))
5696+
expected_ys = jnp.array([
5697+
[0., 5.],
5698+
[1., 11.],
5699+
[3., 18.],
5700+
[6., 26.],
5701+
[10., 35.]
5702+
])
5703+
np.testing.assert_allclose(ys, expected_ys)
5704+
56545705

56555706
if __name__ == '__main__':
56565707
absltest.main()

0 commit comments

Comments
 (0)