433 changes: 381 additions & 52 deletions keras/src/backend/jax/core.py

Large diffs are not rendered by default.
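Since the core.py diff itself is collapsed above, the following is only an orientation sketch of the kind of wrapper the new tests below exercise. The class and attribute names (_ProtectedShardedArray, _array, _is_sharded, delete) come from core_test.py; everything else is an assumption, not the PR's actual implementation.

class _ProtectedShardedArray:
    """Illustrative sketch: guards a jax.Array against deletion while sharded."""

    def __init__(self, array):
        self._array = array
        # Assumption: anything exposing addressable shards is treated as
        # sharded and therefore protected from deletion.
        self._is_sharded = bool(getattr(array, "addressable_shards", None))

    def delete(self):
        if self._is_sharded:
            # Keep sharded buffers alive; releasing them here would later
            # surface as "Array has been deleted" errors during inference.
            return
        if hasattr(self._array, "delete"):
            self._array.delete()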

130 changes: 128 additions & 2 deletions keras/src/backend/jax/core_test.py
@@ -7,15 +7,141 @@

import keras
from keras.src import backend
from keras.src import layers
from keras.src import models
from keras.src import testing
from keras.src.backend.config import is_nnx_enabled
from keras.src.backend.jax.core import JaxVariable
from keras.src.backend.jax.core import _ProtectedShardedArray

if is_nnx_enabled():
    from flax import nnx

    from keras.src.backend.jax.core import NnxVariable


class JaxCoreTest(testing.TestCase):
def _require_min_devices(self, min_devices):
"""Skip test if fewer than min_devices are available."""
if len(jax.devices()) < min_devices:
pytest.skip(
f"Test requires at least {min_devices} devices, "
f"but only {len(jax.devices())} available"
)

def test_protected_sharded_array_deletion(self):
"""Test _ProtectedShardedArray prevents deletion of sharded arrays."""
        # jax.Array objects reject attribute assignment, so build a small
        # stand-in that mimics a sharded array with one shard per device.
        array = jax.numpy.ones((10, 10))

        class _FakeShardedArray:
            def __init__(self, value):
                self.value = value
                self.addressable_shards = [
                    jax.device_put(value, d) for d in jax.devices()
                ]
                self.deleted = False

            def delete(self):
                self.deleted = True

        sharded_array = _FakeShardedArray(array)

        protected = _ProtectedShardedArray(sharded_array)

# Attempt deletion (should not delete sharded arrays)
protected.delete()

        # Verify the array was neither released nor dropped from the wrapper
        self.assertFalse(sharded_array.deleted)
        self.assertIs(protected._array, sharded_array)
self.assertTrue(
hasattr(protected, "_is_sharded") and protected._is_sharded
)

def test_jax_variable_strong_references_and_logging(self):
"""Test JaxVariable strong references and logging."""
self._require_min_devices(2) # Requires multiple devices for sharding

# Create a sharded variable
var = JaxVariable(jax.numpy.ones((100, 100)))

# Check strong references
self.assertTrue(hasattr(var, "_shard_references"))
self.assertGreater(len(var._shard_references), 0)

# Access value multiple times to simulate inference
for _ in range(5):
value = var.value
self.assertIsNotNone(
value
) # Ensure no "Array has been deleted" error

# Final check: Value should still be accessible
self.assertIsNotNone(var.value)

@pytest.mark.skipif(not is_nnx_enabled(), reason="NNX not enabled")
def test_nnx_variable_strong_references_and_logging(self):
"""Test NnxVariable strong references and logging."""
self._require_min_devices(2) # Requires multiple devices for sharding

# Create NNX variable with sharding
var = NnxVariable(jax.numpy.ones((50, 50)), layout=("model", None))

# Check strong references
self.assertTrue(hasattr(var, "_shard_references"))
self.assertGreater(len(var._shard_references), 0)

# Access value (simulates inference) and assert no deletion
value = var.value
self.assertIsNotNone(value) # Ensure no "Array has been deleted" error

# Additional accesses to simulate repeated inference
for _ in range(5):
value = var.value
self.assertIsNotNone(value)

def test_variable_loading_with_sharding(self):
"""Test variable loading with sharding support."""
self._require_min_devices(2) # Requires multiple devices for sharding

# Create test data
test_data = jax.numpy.ones((10, 10))

# Create variable with sharding
var = JaxVariable(jax.numpy.zeros((10, 10)))
# Load data into it
var._direct_assign(test_data)

# Verify it's a JaxVariable with sharding
self.assertIsInstance(var, JaxVariable)
self.assertTrue(hasattr(var, "_shard_references"))
self.assertGreater(len(var._shard_references), 0)

# Access value to ensure no deletion
self.assertIsNotNone(var.value)

def test_inference_simulation_no_array_deletion(self):
"""Test inference simulation for no 'Array has been deleted' errors."""
self._require_min_devices(2) # Requires multiple devices for sharding

# Create a simple model with sharding
inputs = layers.Input(shape=(10,))
x = layers.Dense(50, name="dense")(inputs)
model = models.Model(inputs, x)

# Build and access weights (triggers sharding and protection)
model.build((None, 10))
for var in model.weights:
value = var.value # Access to trigger protection
self.assertIsNotNone(value) # Ensure initial access succeeds

# Simulate inference (multiple accesses) and assert no deletion
test_input = np.random.randn(1, 10)
for _ in range(10):
output = model(test_input)
self.assertIsNotNone(
output
) # Ensure inference succeeds without errors

# Final check: Weights should still be accessible
for var in model.weights:
self.assertIsNotNone(var.value)


@pytest.mark.skipif(
backend.backend() != "jax",
reason="JAX backend specific test for core Variable integration with NNX.",
@@ -25,8 +151,8 @@
reason="Test requires NNX backend to be enabled by default for setup.",
)
class NnxVariableTest(testing.TestCase):
def setup(self):
super().setup()
def setUp(self):
super().setUp()

class NNXModel(nnx.Module):
def __init__(self, rngs):
21 changes: 21 additions & 0 deletions keras/src/backend/jax/distribution_lib_test.py
@@ -33,6 +33,14 @@
reason="Backend specific test",
)
class JaxDistributionLibTest(testing.TestCase):
def _require_min_devices(self, min_devices):
"""Skip test if fewer than min_devices are available."""
if len(jax.devices()) < min_devices:
pytest.skip(
f"Test requires at least {min_devices} devices, "
f"but only {len(jax.devices())} available"
)

def _create_jax_layout(self, sharding):
# Use jax_layout.Format or jax_layout.Layout if available.
if hasattr(jax_layout, "Format"):
@@ -43,6 +51,7 @@ def _create_jax_layout(self, sharding):
return sharding

def test_list_devices(self):
self._require_min_devices(8)
self.assertEqual(len(distribution_lib.list_devices()), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
@@ -77,6 +86,7 @@ def test_initialize_with_coordinator_address(self, mock_jax_initialize):
)

def test_distribute_tensor(self):
self._require_min_devices(8)
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
@@ -101,6 +111,7 @@ def test_function(inputs, target_layout):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_variable(self):
self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -118,6 +129,7 @@ def test_distribute_variable(self):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_input_data(self):
self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
# The multi-process test lives in g3.
jax_mesh = jax.sharding.Mesh(
@@ -136,6 +148,7 @@ def test_distribute_input_data(self):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_tensor_with_jax_layout(self):
self._require_min_devices(8)
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
@@ -166,6 +179,7 @@ def test_function(inputs, target_layout):
)

def test_distribute_variable_with_jax_layout(self):
self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -187,6 +201,7 @@ def test_distribute_variable_with_jax_layout(self):
)

def test_distribute_input_data_with_jax_layout(self):
self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -212,6 +227,7 @@ def test_processes(self):
self.assertEqual(backend_dlib.num_processes(), 1)

def test_to_backend_mesh(self):
self._require_min_devices(8)
devices = [f"cpu:{i}" for i in range(8)]
shape = (4, 2)
axis_names = ["batch", "model"]
@@ -224,6 +240,7 @@ def test_to_backend_mesh(self):
self.assertEqual(jax_mesh.axis_names, ("batch", "model"))

def test_to_backend_layout(self):
self._require_min_devices(8)
axes = ["data", None]
mesh = distribution_lib.DeviceMesh(
(4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)]
@@ -248,6 +265,7 @@ def test_validation_for_device_mesh(self):
backend_dlib._to_backend_layout(layout)

def test_variable_assignment_reuse_layout(self):
self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
@@ -310,6 +328,7 @@ def test_e2e_data_parallel_model(self):
model.fit(inputs, labels)

def test_e2e_model_parallel_model(self):
self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
@@ -349,6 +368,7 @@ def test_e2e_model_parallel_model(self):
model.fit(inputs, labels)

def test_e2e_model_parallel_with_output_sharding(self):
self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
@@ -405,6 +425,7 @@ def test_e2e_model_parallel_with_output_sharding(self):
)

def test_distribute_data_input(self):
self._require_min_devices(4)
per_process_batch = jax.numpy.arange(24).reshape(
6, 4
) # Example input array
1 change: 1 addition & 0 deletions keras/src/backend/torch/core.py
@@ -110,6 +110,7 @@ def _initialize(self, value):
).to(get_device())

def _direct_assign(self, value):
value = convert_to_tensor(value, dtype=self._dtype)
with torch.no_grad():
self.value.copy_(value)

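For context on the torch/core.py change above: torch.Tensor.copy_ only accepts another tensor, so converting the incoming value first presumably lets _direct_assign receive NumPy arrays (or tensors of a different dtype) and coerces them to the variable's dtype before the in-place copy. A minimal standalone sketch of that behavior, using plain torch rather than the Keras helpers:

import numpy as np
import torch

target = torch.zeros(3, dtype=torch.float32)  # stands in for self.value
incoming = np.ones(3, dtype=np.float64)       # e.g. weights loaded from disk

# target.copy_(incoming) would raise a TypeError: copy_ expects a Tensor,
# so convert (and cast to the target dtype) before the in-place copy.
with torch.no_grad():
    target.copy_(torch.as_tensor(incoming, dtype=target.dtype))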