
Commit eda5176

Fix ModelParallel OOM issue during weight loading
- Modified load_own_variables() to use _direct_assign() for sharded variables
- Prevents loading full weight tensors on a single device before distribution
- Resolves RESOURCE_EXHAUSTED errors when loading large models with ModelParallel
- Maintains backward compatibility for non-sharded variables
- Enables loading of models like Gemma2 2B/7B without OOM errors
- Added EinsumDense layer testing to ModelParallel sharded variable loading
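
For orientation, a minimal sketch of the load path the commit message describes, assuming a hypothetical _layout attribute as the marker for sharded variables (illustrative only, not the actual Keras implementation): sharded variables hand the host value to _direct_assign(), which places shards on their devices, while non-sharded variables keep using the regular assign() path.

def load_own_variables(self, store):
    # Illustrative sketch only; attribute and key names are assumptions.
    for i, variable in enumerate(self.weights):
        value = store[str(i)]  # full weight, still in host memory
        if getattr(variable, "_layout", None) is not None:
            # Sharded: distribute directly instead of materializing the
            # whole tensor on a single device first.
            variable._direct_assign(value)
        else:
            # Non-sharded: unchanged, backward-compatible path.
            variable.assign(value)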
1 parent 0ecb55d commit eda5176

File tree

13 files changed, +921 -90 lines


keras/src/backend/jax/core.py

Lines changed: 381 additions & 52 deletions
Large diffs are not rendered by default.
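
The core.py diff is collapsed above. As rough orientation only, here is a minimal sketch of the wrapper behavior that the new tests in core_test.py below assert; the real implementation in this commit is far larger and may differ in detail.

class _ProtectedShardedArray:
    # Sketch inferred from the test assertions, not the committed code.
    # Wraps a jax.Array and refuses to free it while it is sharded, so
    # shards referenced by a distributed variable cannot be deleted out
    # from under it ("Array has been deleted" errors).
    def __init__(self, array):
        self._array = array
        self._is_sharded = bool(getattr(array, "addressable_shards", None))

    def delete(self):
        if not self._is_sharded:
            self._array.delete()  # only unsharded buffers are freed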

keras/src/backend/jax/core_test.py

Lines changed: 128 additions & 2 deletions
@@ -7,15 +7,141 @@

 import keras
 from keras.src import backend
+from keras.src import layers
+from keras.src import models
 from keras.src import testing
 from keras.src.backend.config import is_nnx_enabled
+from keras.src.backend.jax.core import JaxVariable
+from keras.src.backend.jax.core import _ProtectedShardedArray
+
+if is_nnx_enabled():
+    from keras.src.backend.jax.core import NnxVariable

 if is_nnx_enabled():
     from flax import nnx

     from keras.src.backend.jax.core import NnxVariable


+class JaxCoreTest(testing.TestCase):
+    def _require_min_devices(self, min_devices):
+        """Skip test if fewer than min_devices are available."""
+        if len(jax.devices()) < min_devices:
+            pytest.skip(
+                f"Test requires at least {min_devices} devices, "
+                f"but only {len(jax.devices())} available"
+            )
+
+    def test_protected_sharded_array_deletion(self):
+        """Test _ProtectedShardedArray prevents deletion of sharded arrays."""
+        # Create a mock sharded array
+        array = jax.numpy.ones((10, 10))
+        sharded_array = jax.device_put(array, jax.devices()[0])
+        sharded_array.addressable_shards = [
+            jax.device_put(array, d) for d in jax.devices()
+        ]
+
+        protected = _ProtectedShardedArray(sharded_array)
+
+        # Attempt deletion (should not delete sharded arrays)
+        protected.delete()
+
+        # Verify array is still accessible
+        self.assertIs(protected._array, sharded_array)
+        self.assertTrue(
+            hasattr(protected, "_is_sharded") and protected._is_sharded
+        )
+
+    def test_jax_variable_strong_references_and_logging(self):
+        """Test JaxVariable strong references and logging."""
+        self._require_min_devices(2)  # Requires multiple devices for sharding
+
+        # Create a sharded variable
+        var = JaxVariable(jax.numpy.ones((100, 100)))
+
+        # Check strong references
+        self.assertTrue(hasattr(var, "_shard_references"))
+        self.assertGreater(len(var._shard_references), 0)
+
+        # Access value multiple times to simulate inference
+        for _ in range(5):
+            value = var.value
+            self.assertIsNotNone(
+                value
+            )  # Ensure no "Array has been deleted" error
+
+        # Final check: Value should still be accessible
+        self.assertIsNotNone(var.value)
+
+    @pytest.mark.skipif(not is_nnx_enabled(), reason="NNX not enabled")
+    def test_nnx_variable_strong_references_and_logging(self):
+        """Test NnxVariable strong references and logging."""
+        self._require_min_devices(2)  # Requires multiple devices for sharding
+
+        # Create NNX variable with sharding
+        var = NnxVariable(jax.numpy.ones((50, 50)), layout=("model", None))
+
+        # Check strong references
+        self.assertTrue(hasattr(var, "_shard_references"))
+        self.assertGreater(len(var._shard_references), 0)
+
+        # Access value (simulates inference) and assert no deletion
+        value = var.value
+        self.assertIsNotNone(value)  # Ensure no "Array has been deleted" error
+
+        # Additional accesses to simulate repeated inference
+        for _ in range(5):
+            value = var.value
+            self.assertIsNotNone(value)
+
+    def test_variable_loading_with_sharding(self):
+        """Test variable loading with sharding support."""
+        self._require_min_devices(2)  # Requires multiple devices for sharding
+
+        # Create test data
+        test_data = jax.numpy.ones((10, 10))
+
+        # Create variable with sharding
+        var = JaxVariable(jax.numpy.zeros((10, 10)))
+        # Load data into it
+        var._direct_assign(test_data)
+
+        # Verify it's a JaxVariable with sharding
+        self.assertIsInstance(var, JaxVariable)
+        self.assertTrue(hasattr(var, "_shard_references"))
+        self.assertGreater(len(var._shard_references), 0)
+
+        # Access value to ensure no deletion
+        self.assertIsNotNone(var.value)
+
+    def test_inference_simulation_no_array_deletion(self):
+        """Test inference simulation for no 'Array has been deleted' errors."""
+        self._require_min_devices(2)  # Requires multiple devices for sharding
+
+        # Create a simple model with sharding
+        inputs = layers.Input(shape=(10,))
+        x = layers.Dense(50, name="dense")(inputs)
+        model = models.Model(inputs, x)
+
+        # Build and access weights (triggers sharding and protection)
+        model.build((None, 10))
+        for var in model.weights:
+            value = var.value  # Access to trigger protection
+            self.assertIsNotNone(value)  # Ensure initial access succeeds
+
+        # Simulate inference (multiple accesses) and assert no deletion
+        test_input = np.random.randn(1, 10)
+        for _ in range(10):
+            output = model(test_input)
+            self.assertIsNotNone(
+                output
+            )  # Ensure inference succeeds without errors
+
+        # Final check: Weights should still be accessible
+        for var in model.weights:
+            self.assertIsNotNone(var.value)
+
+
 @pytest.mark.skipif(
     backend.backend() != "jax",
     reason="JAX backend specific test for core Variable integration with NNX.",
@@ -25,8 +151,8 @@
     reason="Test requires NNX backend to be enabled by default for setup.",
 )
 class NnxVariableTest(testing.TestCase):
-    def setup(self):
-        super().setup()
+    def setUp(self):
+        super().setUp()

         class NNXModel(nnx.Module):
             def __init__(self, rngs):

keras/src/backend/jax/distribution_lib_test.py

Lines changed: 21 additions & 0 deletions
@@ -33,6 +33,14 @@
     reason="Backend specific test",
 )
 class JaxDistributionLibTest(testing.TestCase):
+    def _require_min_devices(self, min_devices):
+        """Skip test if fewer than min_devices are available."""
+        if len(jax.devices()) < min_devices:
+            pytest.skip(
+                f"Test requires at least {min_devices} devices, "
+                f"but only {len(jax.devices())} available"
+            )
+
     def _create_jax_layout(self, sharding):
         # Use jax_layout.Format or jax_layout.Layout if available.
         if hasattr(jax_layout, "Format"):
@@ -43,6 +51,7 @@ def _create_jax_layout(self, sharding):
         return sharding

     def test_list_devices(self):
+        self._require_min_devices(8)
         self.assertEqual(len(distribution_lib.list_devices()), 8)
         self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
         self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
@@ -77,6 +86,7 @@ def test_initialize_with_coordinator_address(self, mock_jax_initialize):
         )

     def test_distribute_tensor(self):
+        self._require_min_devices(8)
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
         )
@@ -101,6 +111,7 @@ def test_function(inputs, target_layout):
         self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

     def test_distribute_variable(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -118,6 +129,7 @@ def test_distribute_variable(self):
         self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

     def test_distribute_input_data(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         # The multi-process test lives in g3.
         jax_mesh = jax.sharding.Mesh(
@@ -136,6 +148,7 @@ def test_distribute_input_data(self):
         self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

     def test_distribute_tensor_with_jax_layout(self):
+        self._require_min_devices(8)
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
         )
@@ -166,6 +179,7 @@ def test_function(inputs, target_layout):
         )

     def test_distribute_variable_with_jax_layout(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -187,6 +201,7 @@ def test_distribute_variable_with_jax_layout(self):
         )

     def test_distribute_input_data_with_jax_layout(self):
+        self._require_min_devices(8)
         # This test only verify the single worker/process behavior.
         jax_mesh = jax.sharding.Mesh(
             np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -212,6 +227,7 @@ def test_processes(self):
         self.assertEqual(backend_dlib.num_processes(), 1)

     def test_to_backend_mesh(self):
+        self._require_min_devices(8)
         devices = [f"cpu:{i}" for i in range(8)]
         shape = (4, 2)
         axis_names = ["batch", "model"]
@@ -224,6 +240,7 @@ def test_to_backend_mesh(self):
         self.assertEqual(jax_mesh.axis_names, ("batch", "model"))

     def test_to_backend_layout(self):
+        self._require_min_devices(8)
         axes = ["data", None]
         mesh = distribution_lib.DeviceMesh(
             (4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)]
@@ -248,6 +265,7 @@ def test_validation_for_device_mesh(self):
             backend_dlib._to_backend_layout(layout)

     def test_variable_assignment_reuse_layout(self):
+        self._require_min_devices(8)
         shape = (4, 2)
         axis_names = ["batch", "model"]
         device_mesh = distribution_lib.DeviceMesh(
@@ -310,6 +328,7 @@ def test_e2e_data_parallel_model(self):
         model.fit(inputs, labels)

     def test_e2e_model_parallel_model(self):
+        self._require_min_devices(8)
         shape = (4, 2)
         axis_names = ["batch", "model"]
         device_mesh = distribution_lib.DeviceMesh(
@@ -349,6 +368,7 @@ def test_e2e_model_parallel_model(self):
         model.fit(inputs, labels)

     def test_e2e_model_parallel_with_output_sharding(self):
+        self._require_min_devices(8)
         shape = (4, 2)
         axis_names = ["batch", "model"]
         device_mesh = distribution_lib.DeviceMesh(
@@ -405,6 +425,7 @@ def test_e2e_model_parallel_with_output_sharding(self):
         )

     def test_distribute_data_input(self):
+        self._require_min_devices(4)
         per_process_batch = jax.numpy.arange(24).reshape(
             6, 4
         )  # Example input array
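
Most of these tests now skip themselves unless 8 (or 4) local devices are visible. On a CPU-only machine, one common way to satisfy that requirement is to ask XLA for virtual host devices before JAX initializes; this is a general JAX technique, not something added by this commit.

import os
# Must be set before jax is imported / initialized.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax
assert len(jax.devices()) == 8  # eight virtual CPU devices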

keras/src/backend/torch/core.py

Lines changed: 1 addition & 0 deletions
@@ -110,6 +110,7 @@ def _initialize(self, value):
         ).to(get_device())

     def _direct_assign(self, value):
+        value = convert_to_tensor(value, dtype=self._dtype)
         with torch.no_grad():
             self.value.copy_(value)

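
A brief, hypothetical illustration of what the added conversion covers: when _direct_assign() is handed a host-side NumPy array, as happens during weight loading, the value is now converted to a tensor of the variable's dtype before copy_() runs.

import numpy as np
import keras  # assumes KERAS_BACKEND="torch"

v = keras.Variable(np.zeros((2, 2), dtype="float32"))
# Hypothetical direct call for illustration; the checkpoint loader hands
# host arrays to the variable in a similar way.
v._direct_assign(np.ones((2, 2), dtype="float64"))  # cast to float32 tensor first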
