Skip to content

Commit 37f7475

Browse files
Add memory management to JAX backend _direct_assign methods
- Implement memory cleanup in JaxVariable._direct_assign() to prevent memory leaks.
- Store the old value before assignment and clean up references to non-sharded arrays.
- For sharded arrays, rely on garbage collection to avoid breaking references.
- Apply the same memory management to the NnxVariable._direct_assign() method.
- Use getattr() to safely handle cases where _value/raw_value does not exist yet.
- All JAX backend functionality tested and working correctly.
1 parent 92bf1ed commit 37f7475

File tree

1 file changed

+18
-1
lines changed

1 file changed

+18
-1
lines changed

keras/src/backend/jax/core.py

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,13 @@ def _initialize(self, value):
4949
def _direct_assign(self, value):
5050
if self._layout is not None:
5151
value = distribution_lib.distribute_variable(value, self._layout)
52-
self._value = value
52+
old_value = getattr(self, "_value", None)
53+
self._value = value # New array assigned, references updated
54+
# Now safe to delete old_value reference
55+
if old_value is not None and not (
56+
hasattr(old_value, "sharding") and old_value.sharding is not None
57+
):
58+
del old_value # Remove reference to old non-sharded array
5359

5460
def _convert_to_tensor(self, value, dtype=None):
5561
return convert_to_tensor(value, dtype=dtype, sparse=False)
@@ -201,10 +207,21 @@ def _direct_assign(self, value):
201207
):
202208
value = self._var_metadata["on_set_value"](self, value)
203209

210+
# Store old value for cleanup
211+
old_value = getattr(self, "raw_value", None)
212+
204213
# Set the value for both Keras and NNX parts
205214
# This ensures both systems see the same value
206215
object.__setattr__(self, "raw_value", value)
207216

217+
# Clean up old value reference
218+
if old_value is not None and not (
219+
hasattr(old_value, "sharding")
220+
and old_value.sharding is not None
221+
):
222+
del old_value # Remove reference to old non-sharded array
223+
# For sharded arrays, rely on GC as deletion breaks references
224+
208225
@property
209226
def value(self):
210227
if in_stateless_scope():

0 commit comments

Comments
 (0)