Skip to content

Commit c860bee

Browse files
Cristian GarciaFlax Authors
authored andcommitted
rename graph.py -> graphlib.py
PiperOrigin-RevId: 877411273
1 parent babce88 commit c860bee

28 files changed

+4022
-3906
lines changed

flax/nnx/__init__.py

Lines changed: 31 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -28,9 +28,9 @@
2828
from .filterlib import Not as Not
2929
from .filterlib import Everything as Everything
3030
from .filterlib import Nothing as Nothing
31-
from .graph import GraphDef as GraphDef
32-
from .graph import GraphState as GraphState
33-
from .graph import PureState as PureState
31+
from .graphlib import GraphDef as GraphDef
32+
from .graphlib import GraphState as GraphState
33+
from .graphlib import PureState as PureState
3434
from . import pytreelib as object
3535
from .pytreelib import Pytree as Pytree
3636
from .pytreelib import Object as Object
@@ -52,31 +52,31 @@
5252
from .module import view as view
5353
from .module import view_info as view_info
5454
from .module import iter_children as iter_children, iter_modules as iter_modules
55-
from .graph import merge as merge
56-
from .graph import UpdateContext as UpdateContext
57-
from .graph import update_context as update_context
58-
from .graph import current_update_context as current_update_context
59-
from .graph import split as split
60-
from .graph import update as update
61-
from .graph import clone as clone
62-
from .graph import pop as pop
63-
from .graph import state as state
64-
from .graph import graphdef as graphdef
65-
from .graph import iter_graph as iter_graph
66-
from .graph import recursive_map as recursive_map
67-
from .graph import find_duplicates as find_duplicates
68-
from .graph import call as call
69-
from .graph import set_metadata as set_metadata
70-
from .graph import SplitContext as SplitContext
71-
from .graph import split_context as split_context
72-
from .graph import MergeContext as MergeContext
73-
from .graph import merge_context as merge_context
74-
from .graph import variables as variables
75-
from .graph import vars_as as vars_as
76-
from .graph import pure as pure
77-
from .graph import cached_partial as cached_partial
78-
from .graph import flatten as flatten
79-
from .graph import unflatten as unflatten
55+
from .graphlib import merge as merge
56+
from .graphlib import UpdateContext as UpdateContext
57+
from .graphlib import update_context as update_context
58+
from .graphlib import current_update_context as current_update_context
59+
from .graphlib import split as split
60+
from .graphlib import update as update
61+
from .graphlib import clone as clone
62+
from .graphlib import pop as pop
63+
from .graphlib import state as state
64+
from .graphlib import graphdef as graphdef
65+
from .graphlib import iter_graph as iter_graph
66+
from .graphlib import recursive_map as recursive_map
67+
from .graphlib import find_duplicates as find_duplicates
68+
from .graphlib import call as call
69+
from .graphlib import set_metadata as set_metadata
70+
from .graphlib import SplitContext as SplitContext
71+
from .graphlib import split_context as split_context
72+
from .graphlib import MergeContext as MergeContext
73+
from .graphlib import merge_context as merge_context
74+
from .graphlib import variables as variables
75+
from .graphlib import vars_as as vars_as
76+
from .graphlib import pure as pure
77+
from .graphlib import cached_partial as cached_partial
78+
from .graphlib import flatten as flatten
79+
from .graphlib import unflatten as unflatten
8080
from .nn import initializers as initializers
8181
from .nn.activations import celu as celu
8282
from .nn.activations import elu as elu
@@ -210,7 +210,9 @@
210210
from .extract import NodeStates as NodeStates
211211
from .summary import tabulate as tabulate
212212
from . import traversals as traversals
213-
213+
from . import graphlib as graphlib
214+
# import last to prevent potential import cycles
215+
from . import graph as graph
214216

215217
import typing as _tp
216218

flax/nnx/bridge/interop.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import typing as tp
1616

1717
from flax.linen import module as nn_module
18-
from flax.nnx import graph, rnglib
18+
from flax.nnx import graphlib, rnglib
1919
from flax.nnx.bridge import wrappers
2020
from flax.nnx.bridge import module as bdg_module
2121
import flax.nnx.module as nnx_module
@@ -47,15 +47,15 @@ def nnx_in_bridge_mdl(factory: tp.Callable[[rnglib.Rngs], nnx_module.Module],
4747
module = factory(parent.scope.rngs)
4848
else:
4949
rngs = parent.scope.rngs if parent.scope.rngs else rnglib.Rngs(7) # dummy
50-
module = nnx_eval_shape(factory, rngs)
50+
module = nnx_eval_shape(factory, rngs, graph=True)
5151

5252
@nnx_jit
5353
def rng_state(rngs):
54-
return graph.state(factory(rngs), rnglib.RngState)
54+
return graphlib.state(factory(rngs), rnglib.RngState, graph=True)
5555

5656
# Make sure the internal rng state is not abstract - other vars shall be
5757
if parent.scope.rngs:
58-
graph.update(module, rng_state(parent.scope.rngs))
58+
graphlib.update(module, rng_state(parent.scope.rngs))
5959

6060
# Automatically set the attribute if compact. If setup, user is responsible
6161
# for setting the attribute of the superlayer.

flax/nnx/bridge/module.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from flax.core import meta
2929
from flax.core.scope import CollectionFilter
3030
from flax.core.frozen_dict import FrozenDict
31-
from flax.nnx import graph, rnglib, statelib, traversals
31+
from flax.nnx import graphlib, rnglib, statelib, traversals
3232
from flax.nnx import variablelib
3333
import flax.nnx.module as nnx_module
3434
from flax.nnx.pytreelib import Pytree
@@ -104,7 +104,7 @@ def _maybe_call_setup(module: Module):
104104
def _bind_module(parent: Module, module: Module) -> Module:
105105
assert parent.scope is not None
106106

107-
for _, value in reversed(list(graph.iter_graph(module, graph=True))):
107+
for _, value in reversed(list(graphlib.iter_graph(module, graph=True))):
108108
if isinstance(value, Module):
109109
if module.scope is None:
110110
value.scope = parent.scope.copy() # type: ignore[attribute-error]
@@ -245,8 +245,8 @@ def _setattr(self, name: str, value: tp.Any) -> None:
245245
if name in vars(self) and isinstance(
246246
state := vars(self)[name], ModuleState
247247
):
248-
graph.update(value, state)
249-
for leaf in jax.tree.leaves(value, is_leaf=graph.is_graph_node):
248+
graphlib.update(value, state)
249+
for leaf in jax.tree.leaves(value, is_leaf=graphlib.is_graph_node):
250250
if isinstance(leaf, Module):
251251
leaf._pytree__state._initializing = self.is_initializing()
252252
_bind_module(self, leaf)
@@ -374,7 +374,7 @@ def variable( # type: ignore[invalid-annotation]
374374
return variable
375375

376376
def _get_variables(self) -> tp.Mapping:
377-
state = graph.state(self)
377+
state = graphlib.state(self, graph=True)
378378
_variables: dict = {}
379379

380380
variable: variablelib.Variable
@@ -422,7 +422,7 @@ def apply(
422422
_initialize: bool = False,
423423
**kwargs,
424424
) -> tp.Any:
425-
module = graph.clone(self)
425+
module = graphlib.clone(self, graph=True)
426426

427427
# create variables
428428
real_variables = dict(variables)
@@ -440,7 +440,7 @@ def to_variable(value):
440440

441441
states = ({},) if not real_variables else real_variables.values()
442442
state = statelib.merge_state(*states, cls=ModuleState)
443-
graph.update(module, state)
443+
graphlib.update(module, state)
444444

445445
if rngs is None:
446446
rngs = rnglib.Rngs()
@@ -471,7 +471,7 @@ def to_variable(value):
471471
_method = _get_unbound_fn(_method)
472472

473473
# set temporary state
474-
for _, value in graph.iter_graph(module, graph=True):
474+
for _, value in graphlib.iter_graph(module, graph=True):
475475
if isinstance(value, Pytree):
476476
value._pytree__state._initializing = _initialize
477477
if isinstance(value, Module):
@@ -486,7 +486,7 @@ def to_variable(value):
486486
finally:
487487
MODULE_CONTEXT.module_stack.pop()
488488
# reset temporary state
489-
for _, value in graph.iter_graph(module, graph=True):
489+
for _, value in graphlib.iter_graph(module, graph=True):
490490
if isinstance(value, Pytree):
491491
value._pytree__state._initializing = False
492492
if isinstance(value, Module):

flax/nnx/bridge/wrappers.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from flax import nnx
2424
from flax.core import FrozenDict
2525
from flax.core import meta
26-
from flax.nnx import graph
26+
from flax.nnx import graphlib
2727
from flax.nnx import variablelib
2828
from flax.nnx.bridge import variables as bv
2929
from flax.nnx.bridge import module as bdg_module
@@ -40,7 +40,7 @@
4040
@dataclasses.dataclass
4141
class Functional(tp.Generic[M]):
4242
module_type: tp.Type[M]
43-
graphdef: tp.Optional[graph.GraphDef[M]]
43+
graphdef: tp.Optional[graphlib.GraphDef[M]]
4444
args: tuple[tp.Any, ...]
4545
kwargs: dict[str, tp.Any]
4646

@@ -66,7 +66,7 @@ def _functional_constructor(*args: tp.Any, **kwargs: tp.Any) -> Functional[M]:
6666

6767

6868
def _set_initializing(module: Module, initializing: bool):
69-
for _, value in graph.iter_graph(module, graph=True):
69+
for _, value in graphlib.iter_graph(module, graph=True):
7070
if isinstance(value, Pytree):
7171
value._pytree__state._initializing = initializing
7272

0 commit comments

Comments
 (0)