3838)
3939from flax import config
4040from flax .nnx .variablelib import Variable
41+ from flax .nnx import filterlib
4142from flax .typing import MISSING , Missing , SizeBytes
4243
4344BUILDING_DOCS = 'FLAX_DOC_BUILD' in os .environ
@@ -112,6 +113,18 @@ def __init__(self):
112113 metadata ['static' ] = False
113114 return dataclasses .field (** kwargs , metadata = metadata ) # type: ignore[return-value]
114115
116+ def prefix (pytree , filter_map ):
117+ preds = [(filterlib .to_predicate (k ), v ) for k , v in filter_map .items ()]
118+
119+ def lookup (path , value ):
120+ for (pred ,obj ) in preds :
121+ if pred (graphlib .jax_to_nnx_path (path ), value ):
122+ return obj
123+ return None
124+
125+ return jax .tree .map_with_path (lookup , pytree ,
126+ is_leaf = lambda p , v : isinstance (v , variablelib .Variable ) or lookup (p , v ) is not None ,
127+ is_leaf_takes_path = True )
115128
116129def register_data_type (type_ : T , / ) -> T :
117130 """Registers a type as pytree data type recognized by Object.
@@ -1061,4 +1074,4 @@ def _maybe_int(x):
10611074 return x
10621075
10631076def _get_str (x ):
1064- return x if isinstance (x , str ) else str (x )
1077+ return x if isinstance (x , str ) else str (x )
0 commit comments