14 changes: 12 additions & 2 deletions keras/src/backend/jax/core_test.py
@@ -16,6 +16,16 @@
from keras.src.backend.jax.core import NnxVariable


+class JaxCoreTest(testing.TestCase):
+    def _require_min_devices(self, min_devices):
+        """Skip test if fewer than min_devices are available."""
+        if len(jax.devices()) < min_devices:
+            pytest.skip(
+                f"Test requires at least {min_devices} devices, "
+                f"but only {len(jax.devices())} available"
+            )
+
+
@pytest.mark.skipif(
backend.backend() != "jax",
reason="JAX backend specific test for core Variable integration with NNX.",
@@ -25,8 +35,8 @@
reason="Test requires NNX backend to be enabled by default for setup.",
)
class NnxVariableTest(testing.TestCase):
-    def setup(self):
-        super().setup()
+    def setUp(self):
+        super().setUp()

class NNXModel(nnx.Module):
def __init__(self, rngs):
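A note on the setup -> setUp rename above: testing.TestCase ultimately derives from unittest.TestCase, and the unittest protocol only invokes the camel-case setUp hook before each test, so a lowercase setup method is never run automatically. A minimal, self-contained sketch of that behaviour (the class and attribute names are illustrative, not taken from this PR):

import unittest


class SetUpDemo(unittest.TestCase):
    def setUp(self):
        # unittest calls this before every test method.
        self.ready = True

    def test_hook_ran(self):
        # Passes only because setUp (capital U) actually executed;
        # a method named "setup" would never have set the attribute.
        self.assertTrue(getattr(self, "ready", False))


if __name__ == "__main__":
    unittest.main()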
21 changes: 21 additions & 0 deletions keras/src/backend/jax/distribution_lib_test.py
@@ -33,6 +33,14 @@
reason="Backend specific test",
)
class JaxDistributionLibTest(testing.TestCase):
+    def _require_min_devices(self, min_devices):
+        """Skip test if fewer than min_devices are available."""
+        if len(jax.devices()) < min_devices:
+            pytest.skip(
+                f"Test requires at least {min_devices} devices, "
+                f"but only {len(jax.devices())} available"
+            )
+
def _create_jax_layout(self, sharding):
# Use jax_layout.Format or jax_layout.Layout if available.
if hasattr(jax_layout, "Format"):
@@ -43,6 +51,7 @@ def _create_jax_layout(self, sharding):
return sharding

def test_list_devices(self):
+        self._require_min_devices(8)
self.assertEqual(len(distribution_lib.list_devices()), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
self.assertEqual(len(distribution_lib.list_devices("cpu")), 8)
@@ -77,6 +86,7 @@ def test_initialize_with_coordinator_address(self, mock_jax_initialize):
)

def test_distribute_tensor(self):
+        self._require_min_devices(8)
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
@@ -101,6 +111,7 @@ def test_function(inputs, target_layout):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_variable(self):
+        self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -118,6 +129,7 @@ def test_distribute_variable(self):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_input_data(self):
+        self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
# The multi-process test lives in g3.
jax_mesh = jax.sharding.Mesh(
@@ -136,6 +148,7 @@ def test_distribute_input_data(self):
self.assertTrue(result.sharding.is_equivalent_to(target_layout, ndim=2))

def test_distribute_tensor_with_jax_layout(self):
+        self._require_min_devices(8)
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
)
@@ -166,6 +179,7 @@ def test_function(inputs, target_layout):
)

def test_distribute_variable_with_jax_layout(self):
+        self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -187,6 +201,7 @@ def test_distribute_variable_with_jax_layout(self):
)

def test_distribute_input_data_with_jax_layout(self):
+        self._require_min_devices(8)
        # This test only verifies the single worker/process behavior.
jax_mesh = jax.sharding.Mesh(
np.array(jax.devices()).reshape(2, 4), ("batch", "model")
@@ -212,6 +227,7 @@ def test_processes(self):
self.assertEqual(backend_dlib.num_processes(), 1)

def test_to_backend_mesh(self):
+        self._require_min_devices(8)
devices = [f"cpu:{i}" for i in range(8)]
shape = (4, 2)
axis_names = ["batch", "model"]
@@ -224,6 +240,7 @@ def test_to_backend_mesh(self):
self.assertEqual(jax_mesh.axis_names, ("batch", "model"))

def test_to_backend_layout(self):
+        self._require_min_devices(8)
axes = ["data", None]
mesh = distribution_lib.DeviceMesh(
(4, 2), ["data", "model"], [f"cpu:{i}" for i in range(8)]
@@ -248,6 +265,7 @@ def test_validation_for_device_mesh(self):
backend_dlib._to_backend_layout(layout)

def test_variable_assignment_reuse_layout(self):
+        self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
@@ -310,6 +328,7 @@ def test_e2e_data_parallel_model(self):
model.fit(inputs, labels)

def test_e2e_model_parallel_model(self):
+        self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
@@ -349,6 +368,7 @@ def test_e2e_model_parallel_model(self):
model.fit(inputs, labels)

def test_e2e_model_parallel_with_output_sharding(self):
+        self._require_min_devices(8)
shape = (4, 2)
axis_names = ["batch", "model"]
device_mesh = distribution_lib.DeviceMesh(
@@ -405,6 +425,7 @@ def test_e2e_model_parallel_with_output_sharding(self):
)

def test_distribute_data_input(self):
+        self._require_min_devices(4)
per_process_batch = jax.numpy.arange(24).reshape(
6, 4
) # Example input array
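The guards above assume 8 (or 4) visible JAX devices. On a CPU-only machine that count is typically obtained by asking XLA for virtual host devices before jax is first imported; a sketch of one such setup (the conftest.py placement and the device count of 8 are assumptions, not something this PR adds):

# conftest.py (hypothetical) -- must run before any test module imports jax.
import os

# Expose 8 virtual CPU devices so the 8-device guards above do not skip.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=8"

import jax  # noqa: E402

print(jax.devices("cpu"))  # expect 8 CPU devices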
1 change: 1 addition & 0 deletions keras/src/backend/torch/core.py
@@ -110,6 +110,7 @@ def _initialize(self, value):
).to(get_device())

def _direct_assign(self, value):
+        value = convert_to_tensor(value, dtype=self._dtype)
with torch.no_grad():
self.value.copy_(value)

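For context on the added convert_to_tensor line: torch.Tensor.copy_ expects a tensor source, so passing a NumPy array or Python list straight into the in-place copy raises a TypeError, and the explicit dtype keeps the assigned value consistent with the variable's dtype. A minimal pure-PyTorch sketch of the same pattern (illustrative only, not the Keras implementation):

import numpy as np
import torch

param = torch.zeros(3, dtype=torch.float32)  # stands in for self.value
new_value = np.array([1, 2, 3])              # array-like, not a torch.Tensor

# Convert to a tensor of the parameter's dtype before the in-place copy,
# mirroring the line added in _direct_assign above.
with torch.no_grad():
    param.copy_(torch.as_tensor(new_value, dtype=param.dtype))

print(param)  # tensor([1., 2., 3.])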