Skip to content

Commit 04bc802

Browse files
Fix nnx object state (#21565)
* Update operation.py * Update actions.yml * Update operation.py * Update actions.yml * Update operation.py * Update operation.py * Update operation.py * fix test * code reformat
1 parent ea62750 commit 04bc802

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

.github/workflows/actions.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ jobs:
5757
run: |
5858
pip install -r requirements.txt --progress-bar off --upgrade
5959
if [ "${{ matrix.nnx_enabled }}" == "true" ]; then
60-
pip install --upgrade flax>=0.11.0
60+
pip install --upgrade flax>=0.11.1
6161
fi
6262
pip uninstall -y keras keras-nightly
6363
pip install -e "." --progress-bar off --upgrade

keras/src/ops/operation.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,11 @@ def __new__(cls, *args, **kwargs):
123123
if backend.backend() == "jax" and is_nnx_enabled():
124124
from flax import nnx
125125

126-
vars(instance)["_object__state"] = nnx.object.ObjectState()
126+
try:
127+
vars(instance)["_pytree__state"] = nnx.pytreelib.PytreeState()
128+
except AttributeError:
129+
vars(instance)["_object__state"] = nnx.object.ObjectState()
130+
127131
# Generate a config to be returned by default by `get_config()`.
128132
arg_names = inspect.getfullargspec(cls.__init__).args
129133
kwargs.update(dict(zip(arg_names[1 : len(args) + 1], args)))
@@ -206,10 +210,9 @@ def __init__(self, arg1, arg2, **kwargs):
206210
207211
def get_config(self):
208212
config = super().get_config()
209-
config.update({{
210-
"arg1": self.arg1,
213+
config.update({"arg1": self.arg1,
211214
"arg2": self.arg2,
212-
}})
215+
})
213216
return config"""
214217
)
215218
)

0 commit comments

Comments
 (0)