Skip to content

Commit 1572c84

Browse files
Cristian Garcia authored and Flax Authors
committed
add nnx.map and nnx.abstract_with_sharding
PiperOrigin-RevId: 886959302
1 parent 3bcc4a9 commit 1572c84

File tree

5 files changed

+226
-10
lines changed

5 files changed

+226
-10
lines changed

flax/nnx/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
from .graphlib import iter_graph as iter_graph
6868
from .graphlib import recursive_map as recursive_map
6969
from .graphlib import find_duplicates as find_duplicates
70+
from .graphlib import map as map
7071
from .graphlib import call as call
7172
from .graphlib import set_metadata as set_metadata
7273
from .graphlib import SplitContext as SplitContext
@@ -151,6 +152,7 @@
151152
from .spmd import get_named_sharding as get_named_sharding
152153
from .spmd import with_partitioning as with_partitioning
153154
from .spmd import get_abstract_model as get_abstract_model
155+
from .spmd import abstract_with_sharding as abstract_with_sharding
154156
from .statelib import FlatState as FlatState
155157
from .statelib import State as State
156158
from .statelib import to_flat_state as to_flat_state

flax/nnx/graphlib.py

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import functools
2020
import threading
2121
import typing as tp
22+
import builtins
2223

2324
import jax.core
2425

@@ -2526,6 +2527,48 @@ def state(
25262527
variables = state
25272528

25282529

2530+
def map(
  f: tp.Callable[[tuple, tp.Any], tp.Any],
  node: A,
  /,
  *,
  graph: bool | None = None,
) -> A:
  """Map a function over the state of a graph node.

  ``map`` extracts the state from ``node`` using :func:`split`, applies ``f``
  to every ``(path, value)`` pair using :func:`map_state`, and returns a
  new node with the mapped values merged back into the original structure.
  Note that the leaves in the state are :class:`Variable` objects, so ``f``
  should handle them accordingly.

  Example usage::

    >>> from flax import nnx
    >>> import jax.numpy as jnp

    >>> model = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
    >>> new_model = nnx.map(lambda path, v: v.replace(jnp.zeros_like(v)), model)
    >>> assert jnp.all(new_model.kernel[...] == 0)
    >>> assert jnp.all(new_model.bias[...] == 0)

  Args:
    f: A callable ``(path, value) -> new_value`` applied to each leaf in the
      state. ``path`` is a tuple of path parts and ``value`` is the
      corresponding leaf (typically a :class:`Variable`).
    node: A graph node object.
    graph: If ``True``, uses graph-mode which supports the full
      NNX feature set including shared references. If ``False``, uses
      tree-mode which treats Modules as regular JAX pytrees, avoiding
      the overhead of the graph protocol.
  Returns:
    A new node with the same structure as ``node`` whose state leaves have
    been replaced by the values returned by ``f``.
  """
  # NOTE: this function intentionally shadows the builtin `map` within this
  # module; module code must use `builtins.map` for the builtin.
  graphdef, state = split(node, graph=graph)
  state = statelib.map_state(f, state)
  # Re-attach the mapped state to the original structure; the result is a
  # node of the same type as `node`, not a State.
  return merge(graphdef, state)
2570+
2571+
25292572
def graphdef(
25302573
node: tp.Any, /, *, graph: bool | None = None,
25312574
) -> GraphDef[tp.Any]:
@@ -2695,7 +2738,7 @@ def _different_vars(path, x):
26952738
duplicates_strs = '\n ---'
26962739
for node_duplicates in all_duplicates:
26972740
for path in node_duplicates:
2698-
path_str = '/'.join(map(str, path))
2741+
path_str = '/'.join(builtins.map(str, path))
26992742
duplicates_strs += f'\n {path_str}'
27002743
duplicates_strs += '\n ---'
27012744
raise ValueError(f'Found duplicate at paths:{duplicates_strs}')
@@ -2976,10 +3019,10 @@ def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29763019
continue
29773020

29783021
if not is_pytree_node(current, check_graph_registry=False):
2979-
_check_valid_pytree(current, 'iter_graph', '/'.join(map(str, path)))
3022+
_check_valid_pytree(current, 'iter_graph', '/'.join(builtins.map(str, path)))
29803023
if isinstance(current, Variable) or variablelib.is_array_ref(current):
29813024
obj_id = id(current)
2982-
str_path = '/'.join(map(str, path))
3025+
str_path = '/'.join(builtins.map(str, path))
29833026
if obj_id in seen_refs:
29843027
raise ValueError(
29853028
f'Duplicate {current}\nfound at paths:\n\n'
@@ -2993,7 +3036,7 @@ def _iter_tree(node: tp.Any, /) -> tp.Iterator[tuple[PathParts, tp.Any]]:
29933036
continue
29943037

29953038
obj_id = id(current)
2996-
str_path = '/'.join(map(str, path))
3039+
str_path = '/'.join(builtins.map(str, path))
29973040
if obj_id in in_progress:
29983041
raise ValueError(
29993042
f'Cycle detected for {type(current).__name__}\nfound at paths:\n\n'
@@ -3146,7 +3189,7 @@ def _recursive_map_graph(
31463189
if node_id in visited:
31473190
if node_id in results:
31483191
return results[node_id]
3149-
path_str = '/'.join(map(str, path))
3192+
path_str = '/'.join(builtins.map(str, path))
31503193
raise ValueError(
31513194
f"Found cycle in the graph at path '{path_str}'. Node of type"
31523195
f' {type(node)} has already been visited but has not been returned yet.'
@@ -3184,10 +3227,10 @@ def _recursive_map_tree(
31843227

31853228
def _recurse(path: PathParts, current: tp.Any) -> tp.Any:
31863229
if not is_pytree_node(current, check_graph_registry=False):
3187-
_check_valid_pytree(current, 'recursive_map', '/'.join(map(str, path)))
3230+
_check_valid_pytree(current, 'recursive_map', '/'.join(builtins.map(str, path)))
31883231
if isinstance(current, Variable) or is_array_ref(current):
31893232
obj_id = id(current)
3190-
str_path = '/'.join(map(str, path))
3233+
str_path = '/'.join(builtins.map(str, path))
31913234
if obj_id in seen_refs:
31923235
raise ValueError(
31933236
f'Duplicate {current}\nfound at paths:\n\n'
@@ -3200,7 +3243,7 @@ def _recurse(path: PathParts, current: tp.Any) -> tp.Any:
32003243
return f(path, current)
32013244

32023245
obj_id = id(current)
3203-
str_path = '/'.join(map(str, path))
3246+
str_path = '/'.join(builtins.map(str, path))
32043247
if obj_id in in_progress:
32053248
raise ValueError(
32063249
f'Cycle detected for {type(current).__name__}\nfound at paths:\n\n'

flax/nnx/spmd.py

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -179,4 +179,62 @@ def get_abstract_model(init_fn, mesh, *, graph: bool | None = None):
179179
lambda a, s: jax.ShapeDtypeStruct(a.shape, a.dtype, sharding=s),
180180
abs_state, get_named_sharding(abs_state, mesh)
181181
)
182-
return gdef, abs_state
182+
return gdef, abs_state
183+
184+
185+
def abstract_with_sharding(
  tree: A, graph: bool | None = None
) -> A:
  """Attach sharding information to abstract Variables in ``tree``.

  Models built with :func:`eval_shape` hold abstract Variables (backed by
  ``jax.ShapeDtypeStruct``) which may lack a sharding, in particular under
  meshes with :attr:`jax.sharding.AxisType.Auto` axes. For each Variable
  that declares ``out_sharding`` metadata but has no sharding set, this
  function builds a :class:`jax.sharding.NamedSharding` from the Variable's
  partition spec and either its own ``mesh`` metadata or the current
  abstract mesh (``jax.sharding.get_abstract_mesh``), and replaces the
  Variable's value with a sharded ``ShapeDtypeStruct``.

  Example usage::

    from flax import nnx
    import jax

    mesh = jax.make_mesh((2, 2), ('a', 'b'),
                         axis_types=(jax.sharding.AxisType.Auto,) * 2)
    with jax.set_mesh(mesh):
      abs_model = nnx.eval_shape(
        lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0),
                           kernel_metadata={'out_sharding': ('a', 'b')}))
      abs_model = nnx.abstract_with_sharding(abs_model)
    assert abs_model.kernel.sharding.spec == jax.P('a', 'b')

  Args:
    tree: A graph node (e.g. an :class:`nnx.Module`) whose Variables should
      be annotated with sharding (via ``out_sharding`` metadata).
    graph: Forwarded to :func:`nnx.map`. If ``True``, uses graph-mode;
      if ``False``, uses tree-mode.
  Returns:
    A tree with sharding-annotated ShapeDtypeStruct values inside Variables.
  """

  def _annotate(_path, leaf):
    # Guard clauses: only touch array-like Variables that request a
    # sharding via 'out_sharding' metadata and do not already have one.
    if not isinstance(leaf, variablelib.Variable):
      return leaf
    value = leaf.get_value()
    if not (hasattr(value, 'shape') and hasattr(value, 'dtype')):
      return leaf
    if getattr(value, 'sharding', None) is not None:
      return leaf
    if not leaf.has_metadata('out_sharding'):
      return leaf
    # Per-variable 'mesh' metadata takes precedence over the ambient
    # abstract mesh.
    mesh = (
      leaf.get_metadata('mesh')
      if leaf.has_metadata('mesh')
      else jax.sharding.get_abstract_mesh()
    )
    named = jax.sharding.NamedSharding(mesh, get_var_pspec(leaf))
    return leaf.replace(
      jax.ShapeDtypeStruct(value.shape, value.dtype, sharding=named)
    )

  return graphlib.map(_annotate, tree, graph=graph)

tests/nnx/graph_utils_test.py

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1652,7 +1652,53 @@ def __init__(self):
16521652
):
16531653
nnx.recursive_map(lambda path, node: node, node, graph=False)
16541654

1655+
@parameterized.parameters(True, False)
def test_map(self, graph):
  """nnx.map rebuilds the node with mapped leaves in both modes."""
  module = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
  zeroed = nnx.map(lambda path, x: jnp.zeros_like(x), module, graph=graph)

  # The mapped node keeps the original attribute structure...
  self.assertTrue(hasattr(zeroed, 'kernel'))
  self.assertTrue(hasattr(zeroed, 'bias'))
  # ...and every leaf was replaced by the mapping function's output.
  np.testing.assert_array_equal(zeroed.kernel, jnp.zeros((2, 3)))
  np.testing.assert_array_equal(zeroed.bias, jnp.zeros((3,)))
1664+
1665+
def test_map_with_path(self):
  """The callback receives the path tuple for each leaf."""
  module = nnx.Linear(2, 3, rngs=nnx.Rngs(0))
  visited = []

  def identity_recording(path, value):
    visited.append(path)
    return value

  nnx.map(identity_recording, module)
  # One invocation per leaf: kernel and bias.
  self.assertLen(visited, 2)
  self.assertEqual(sorted(path[-1] for path in visited), ['bias', 'kernel'])
1677+
1678+
def test_map_nested(self):
  """nnx.map reaches leaves inside nested submodules."""

  class Outer(nnx.Module):
    def __init__(self, rngs):
      self.linear = nnx.Linear(2, 3, rngs=rngs)

  mapped = nnx.map(lambda path, x: jnp.ones_like(x), Outer(rngs=nnx.Rngs(0)))

  self.assertTrue(hasattr(mapped, 'linear'))
  np.testing.assert_array_equal(mapped.linear.kernel, jnp.ones((2, 3)))
  np.testing.assert_array_equal(mapped.linear.bias, jnp.ones((3,)))
1689+
1690+
def test_map_replace(self):
  """Variable.replace in the callback keeps the Variable wrapper."""
  module = nnx.Linear(2, 3, rngs=nnx.Rngs(0))

  def zero(path, variable):
    return variable.replace(jnp.zeros_like(variable))

  zeroed = nnx.map(zero, module)

  self.assertTrue(hasattr(zeroed, 'kernel'))
  self.assertTrue(hasattr(zeroed, 'bias'))
  # The leaf is still a Param, not a bare array.
  self.assertIsInstance(zeroed.kernel, nnx.Param)
  np.testing.assert_array_equal(zeroed.kernel[...], jnp.zeros((2, 3)))
  np.testing.assert_array_equal(zeroed.bias[...], jnp.zeros((3,)))
1701+
16551702

16561703
if __name__ == '__main__':
16571704
absltest.main()
1658-

tests/nnx/spmd_test.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,73 @@ def test_variable_out_sharding_types(self, axis_type_name):
457457
v_format = nnx.Variable(value, out_sharding=Format(Layout(major_to_minor=(1, 0)), ns))
458458
self.assertEqual(v_format.sharding, ns)
459459

460+
def test_get_abstract_with_abstract_mesh(self):
  """Sharding falls back to the ambient abstract mesh."""
  mesh = jax.make_mesh(
      (2, 2), ('a', 'b'), axis_types=(jax.sharding.AxisType.Auto,) * 2
  )
  # abstract_with_sharding reads the ambient abstract mesh, so it must run
  # inside the set_mesh context.
  with jax.set_mesh(mesh):
    abs_model = nnx.eval_shape(
        lambda: nnx.Linear(
            4,
            8,
            rngs=nnx.Rngs(0),
            kernel_metadata={'out_sharding': ('a', 'b')},
        )
    )
    abs_model = nnx.abstract_with_sharding(abs_model)

  self.assertIsInstance(abs_model.kernel, nnx.Param)
  self.assertEqual(abs_model.kernel.sharding.spec, P('a', 'b'))
  self.assertEqual(abs_model.kernel.sharding.mesh.axis_names, mesh.axis_names)
483+
484+
def test_get_abstract_with_per_variable_mesh(self):
  """A Variable's own 'mesh' metadata wins over the ambient mesh."""
  mesh1 = jax.make_mesh(
      (2, 2), ('a', 'b'), axis_types=(jax.sharding.AxisType.Auto,) * 2
  )
  mesh2 = jax.make_mesh(
      (1, 4), ('c', 'd'), axis_types=(jax.sharding.AxisType.Auto,) * 2
  )

  class Model(nnx.Module):
    def __init__(self):
      self.p1 = nnx.Linear(
          4,
          8,
          rngs=nnx.Rngs(0),
          kernel_metadata={'out_sharding': ('a', 'b'), 'mesh': mesh1},
      )
      self.p2 = nnx.Linear(
          4,
          8,
          rngs=nnx.Rngs(0),
          kernel_metadata={'out_sharding': ('c', 'd'), 'mesh': mesh2},
      )

  abs_model = nnx.abstract_with_sharding(nnx.eval_shape(lambda: Model()))

  # Each kernel picks up the mesh from its own metadata.
  self.assertEqual(abs_model.p1.kernel.sharding.spec, P('a', 'b'))
  self.assertEqual(abs_model.p1.kernel.sharding.mesh, mesh1)
  self.assertEqual(abs_model.p2.kernel.sharding.spec, P('c', 'd'))
  self.assertEqual(abs_model.p2.kernel.sharding.mesh, mesh2)
518+
519+
def test_get_abstract_no_sharding_metadata(self):
  """Variables without 'out_sharding' metadata are left untouched."""
  abs_model = nnx.abstract_with_sharding(
      nnx.eval_shape(lambda: nnx.Linear(4, 8, rngs=nnx.Rngs(0)))
  )

  self.assertIsInstance(abs_model.kernel, nnx.Param)
  # No sharding was attached because no metadata requested one.
  self.assertIsNone(getattr(abs_model.kernel.get_value(), 'sharding', None))
460527

461528
def has_sharding_spec(array):
462529
sharding = array.sharding

0 commit comments

Comments
 (0)