From 468005e1b4cc991b01da5299e7de0cf52ab3efc2 Mon Sep 17 00:00:00 2001 From: suhana Date: Mon, 4 Aug 2025 10:04:15 +0530 Subject: [PATCH 01/12] Refactor JAXTrainer sharding to use out_shardings --- keras/src/backend/jax/trainer.py | 129 +++++++++++-------------------- 1 file changed, 46 insertions(+), 83 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 33a415a38f71..d6840fccd059 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -153,7 +153,7 @@ def train_step(self, state, data): metrics_variables, unscaled_loss, x, y, y_pred, sample_weight ) - state = self._enforce_jax_state_sharding( + state = ( trainable_variables, non_trainable_variables, optimizer_variables, @@ -185,17 +185,6 @@ def test_step(self, state, data): metrics_variables, unscaled_loss, x, y, y_pred, sample_weight ) - ( - trainable_variables, - non_trainable_variables, - _, - metrics_variables, - ) = self._enforce_jax_state_sharding( - trainable_variables=trainable_variables, - non_trainable_variables=non_trainable_variables, - optimizer_variables=None, - metrics_variables=metrics_variables, - ) state = ( trainable_variables, non_trainable_variables, @@ -213,17 +202,6 @@ def predict_step(self, state, data): outputs, non_trainable_variables = self.stateless_call( trainable_variables, non_trainable_variables, x, **kwargs ) - ( - _, - non_trainable_variables, - _, - _, - ) = self._enforce_jax_state_sharding( - trainable_variables=None, - non_trainable_variables=non_trainable_variables, - optimizer_variables=None, - metrics_variables=None, - ) return outputs, non_trainable_variables def _make_function(self, step_function, concatenate_outputs=False): @@ -281,11 +259,22 @@ def make_train_function(self, force=False): if self.train_function is not None and not force: return if not self.run_eagerly and self.jit_compile: - # Note that we mark the state to be donated to jax, - # so that jax will reuse the memory buffer for outputs. - # This will reduce the memory usage of the training function by - # half. - train_step = jit(self.train_step, donate_argnums=0) + out_shardings = None + if distribution_lib.distribution() is not None: + out_shardings = ( + tree.map_structure(lambda _: None, self._metrics_result_structure), + ( + self._trainable_variable_shardings, + self._non_trainable_variable_shardings, + self._optimizer_variable_shardings, + self._metrics_variable_shardings, + ), + ) + train_step = jit( + self.train_step, + donate_argnums=0, + out_shardings=out_shardings + ) else: train_step = self.train_step @@ -297,12 +286,21 @@ def make_test_function(self, force=False): if self.test_function is not None and not force: return if not self.run_eagerly and self.jit_compile: - # Note that we mark the state to be donated to jax, - # so that jax will reuse the memory buffer for outputs. - # This will reduce the memory usage of the training function by - # half. 
- test_step = jit(self.test_step, donate_argnums=0) - + out_shardings = None + if distribution_lib.distribution() is not None: + out_shardings = ( + tree.map_structure(lambda _: None, self._metrics_result_structure), + ( + self._trainable_variable_shardings, + self._non_trainable_variable_shardings, + self._metrics_variable_shardings, + ), + ) + test_step = jit( + self.test_step, + donate_argnums=0, + out_shardings=out_shardings + ) else: test_step = self.test_step @@ -319,7 +317,20 @@ def predict_step(state, data): return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: - predict_step = jit(predict_step, donate_argnums=0) + out_shardings = None + if distribution_lib.distribution() is not None: + out_shardings = ( + None, + ( + self._trainable_variable_shardings, + self._non_trainable_variable_shardings + ), + ) + predict_step = jit( + predict_step, + donate_argnums=0, + out_shardings=out_shardings + ) _step_function = self._make_function( predict_step, concatenate_outputs=True @@ -907,54 +918,6 @@ def _clear_jax_state_sharding(self): self._optimizer_variable_shardings = None self._metrics_variable_shardings = None - def _enforce_jax_state_sharding( - self, - trainable_variables=None, - non_trainable_variables=None, - optimizer_variables=None, - metrics_variables=None, - ): - """Enforce the sharding spec constraint for all the training state. - - Since the output of the train/eval step will be used as inputs to next - step, we need to ensure that they have the same sharding spec, so that - nnx.jit/jax.jit won't have to recompile the train/eval function. - - Note that this function will also rely on the recorded sharding spec - for each of states. - - This function is expected to be called within the jitted train/eval - function, especially around the end of the function. 
- """ - trainable_variables = trainable_variables or [] - non_trainable_variables = non_trainable_variables or [] - optimizer_variables = optimizer_variables or [] - metrics_variables = metrics_variables or [] - - for i in range(len(trainable_variables)): - trainable_variables[i] = jax.lax.with_sharding_constraint( - trainable_variables[i], self._trainable_variable_shardings[i] - ) - for i in range(len(non_trainable_variables)): - non_trainable_variables[i] = jax.lax.with_sharding_constraint( - non_trainable_variables[i], - self._non_trainable_variable_shardings[i], - ) - for i in range(len(optimizer_variables)): - optimizer_variables[i] = jax.lax.with_sharding_constraint( - optimizer_variables[i], self._optimizer_variable_shardings[i] - ) - for i in range(len(metrics_variables)): - metrics_variables[i] = jax.lax.with_sharding_constraint( - metrics_variables[i], self._metrics_variable_shardings[i] - ) - return ( - trainable_variables, - non_trainable_variables, - optimizer_variables, - metrics_variables, - ) - def _purge_model_variables( self, trainable_variables=False, From 1ac4ff9edf9660c79672f424205c7f8662ba48c6 Mon Sep 17 00:00:00 2001 From: suhana Date: Wed, 6 Aug 2025 15:44:05 +0530 Subject: [PATCH 02/12] Added tests for the Jax out sharding --- keras/src/backend/jax/trainer.py | 28 ++++++------ keras/src/trainers/trainer_test.py | 72 ++++++++++++++++++++++++++++++ 2 files changed, 86 insertions(+), 14 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index d6840fccd059..32fcb232e209 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -260,7 +260,9 @@ def make_train_function(self, force=False): return if not self.run_eagerly and self.jit_compile: out_shardings = None - if distribution_lib.distribution() is not None: + if distribution_lib.distribution() is not None and hasattr( + self, "_metrics_result_structure" + ): out_shardings = ( tree.map_structure(lambda _: None, self._metrics_result_structure), ( @@ -271,9 +273,7 @@ def make_train_function(self, force=False): ), ) train_step = jit( - self.train_step, - donate_argnums=0, - out_shardings=out_shardings + self.train_step, donate_argnums=0, out_shardings=out_shardings ) else: train_step = self.train_step @@ -287,7 +287,9 @@ def make_test_function(self, force=False): return if not self.run_eagerly and self.jit_compile: out_shardings = None - if distribution_lib.distribution() is not None: + if distribution_lib.distribution() is not None and hasattr( + self, "_metrics_result_structure" + ): out_shardings = ( tree.map_structure(lambda _: None, self._metrics_result_structure), ( @@ -297,15 +299,12 @@ def make_test_function(self, force=False): ), ) test_step = jit( - self.test_step, - donate_argnums=0, - out_shardings=out_shardings + self.test_step, donate_argnums=0, out_shardings=out_shardings ) else: test_step = self.test_step step_function = self._make_function(test_step) - self.test_function = step_function def make_predict_function(self, force=False): @@ -318,18 +317,19 @@ def predict_step(state, data): if not self.run_eagerly and self.jit_compile: out_shardings = None - if distribution_lib.distribution() is not None: + # FIX: Check if the model has been built before accessing sharding attrs + if distribution_lib.distribution() is not None and hasattr( + self, "_trainable_variable_shardings" + ): out_shardings = ( None, ( self._trainable_variable_shardings, - self._non_trainable_variable_shardings + self._non_trainable_variable_shardings, ), ) 
predict_step = jit( - predict_step, - donate_argnums=0, - out_shardings=out_shardings + predict_step, donate_argnums=0, out_shardings=out_shardings ) _step_function = self._make_function( diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index 21080dea2bff..b656d001702b 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -23,6 +23,7 @@ if backend.backend() == "jax": from keras.src.backend.jax.trainer import JAXTrainer as Trainer + from keras.src.distribution import DeviceMesh, TensorLayout, distribution_lib elif backend.backend() == "torch": from keras.src.backend.torch.trainer import TorchTrainer as Trainer elif backend.backend() == "tensorflow": @@ -2857,3 +2858,74 @@ def predict_step(self, *args): verbose=0, ) self.assertLessEqual(tracing_count[0], 2) + +class ExampleModelForJAXTrainerShardingTest(models.Model): + def __init__(self, units=3, **kwargs): + super().__init__(**kwargs) + self.dense1 = layers.Dense(4, activation="relu", kernel_initializer="ones") + self.dense2 = layers.Dense(units, activation="softmax", kernel_initializer="ones") + + def call(self, x): + return self.dense2(self.dense1(x)) + +@pytest.mark.skipif( + backend.backend() != "jax", + reason="This is a JAX-specific distribution test.", +) +class JAXTrainerShardingTest(testing.TestCase): + + def setUp(self): + super().setUp() + import jax + + if jax.device_count() < 2: + self.skipTest( + "Cannot test sharding with less than 2 devices. " + f"Found {jax.device_count()} devices." + ) + + devices = np.array(jax.devices()) + device_mesh = DeviceMesh( + shape=(jax.device_count(),), + axis_names=("batch",), + devices=devices.flatten(), + ) + data_layout_2d = TensorLayout(axes=("batch", None), device_mesh=device_mesh) + data_layout_1d = TensorLayout(axes=("batch",), device_mesh=device_mesh) + variable_layout = TensorLayout(axes=(None, None), device_mesh=device_mesh) + + def get_layout_for_data(shape): + if not hasattr(shape, '__len__'): + return variable_layout + if len(shape) == 2: + return data_layout_2d + elif len(shape) == 1: + return data_layout_1d + return variable_layout + + mock_dist = mock.MagicMock() + mock_dist.get_data_layout.side_effect = get_layout_for_data + mock_dist.get_tensor_layout.return_value = variable_layout + mock_dist.auto_shard_dataset = False + self.distribution_mock = mock_dist + + @pytest.mark.requires_trainable_backend + def test_fit_with_sharding(self): + with mock.patch.object( + distribution_lib, "distribution", return_value=self.distribution_mock + ): + model = ExampleModelForJAXTrainerShardingTest(units=3) + model.compile( + optimizer="sgd", + loss="mse", + jit_compile=True + ) + + x = np.ones((16, 5), dtype="float32") + y = np.zeros((16, 3), dtype="float32") + sw = np.ones((16,), dtype="float32") + + history = model.fit(x, y, sample_weight=sw, batch_size=4, epochs=2) + + self.assertIn("loss", history.history) + self.assertEqual(len(history.history["loss"]), 2) \ No newline at end of file From a07bf639532a325c5290121bb463af7635f7e6ed Mon Sep 17 00:00:00 2001 From: Suhana Date: Fri, 8 Aug 2025 10:49:06 +0530 Subject: [PATCH 03/12] Update keras/src/trainers/trainer_test.py for nit changes Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> --- keras/src/trainers/trainer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index b656d001702b..abaa2525e4e0 100644 --- 
a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -2918,7 +2918,7 @@ def test_fit_with_sharding(self): model.compile( optimizer="sgd", loss="mse", - jit_compile=True + jit_compile=True ) x = np.ones((16, 5), dtype="float32") From 5eb6d0c0b2e91934d373886657a5b17522f5477d Mon Sep 17 00:00:00 2001 From: suhana Date: Fri, 8 Aug 2025 10:59:37 +0530 Subject: [PATCH 04/12] Added a helper to reduce code duplication --- keras/src/backend/jax/trainer.py | 60 ++++++++++++++++---------------- 1 file changed, 30 insertions(+), 30 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 32fcb232e209..99d4e7d90720 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -255,57 +255,57 @@ def iterator_step(state, iterator): return iterator_step + def _get_out_shardings_for_step(self, state_shardings): + """Helper to create the out_shardings for a jitted step function.""" + if distribution_lib.distribution() is not None and hasattr( + self, "_metrics_result_structure" + ): + return ( + tree.map_structure(lambda _: None, self._metrics_result_structure), + state_shardings, + ) + return None + def make_train_function(self, force=False): if self.train_function is not None and not force: return + + state_shardings = ( + self._trainable_variable_shardings, + self._non_trainable_variable_shardings, + self._optimizer_variable_shardings, + self._metrics_variable_shardings, + ) + out_shardings = self._get_out_shardings_for_step(state_shardings) + if not self.run_eagerly and self.jit_compile: - out_shardings = None - if distribution_lib.distribution() is not None and hasattr( - self, "_metrics_result_structure" - ): - out_shardings = ( - tree.map_structure(lambda _: None, self._metrics_result_structure), - ( - self._trainable_variable_shardings, - self._non_trainable_variable_shardings, - self._optimizer_variable_shardings, - self._metrics_variable_shardings, - ), - ) train_step = jit( self.train_step, donate_argnums=0, out_shardings=out_shardings ) else: train_step = self.train_step - step_function = self._make_function(train_step) - - self.train_function = step_function + self.train_function = self._make_function(train_step) def make_test_function(self, force=False): if self.test_function is not None and not force: return + + state_shardings = ( + self._trainable_variable_shardings, + self._non_trainable_variable_shardings, + self._metrics_variable_shardings, + ) + out_shardings = self._get_out_shardings_for_step(state_shardings) + if not self.run_eagerly and self.jit_compile: - out_shardings = None - if distribution_lib.distribution() is not None and hasattr( - self, "_metrics_result_structure" - ): - out_shardings = ( - tree.map_structure(lambda _: None, self._metrics_result_structure), - ( - self._trainable_variable_shardings, - self._non_trainable_variable_shardings, - self._metrics_variable_shardings, - ), - ) test_step = jit( self.test_step, donate_argnums=0, out_shardings=out_shardings ) else: test_step = self.test_step - step_function = self._make_function(test_step) - self.test_function = step_function + self.test_function = self._make_function(test_step) def make_predict_function(self, force=False): if self.predict_function is not None and not force: From f3420674265655ab4657b052c55014aa8a13214e Mon Sep 17 00:00:00 2001 From: suhana Date: Sat, 9 Aug 2025 20:48:49 +0530 Subject: [PATCH 05/12] Updated the test and the out sharding logic --- keras/src/backend/jax/trainer.py | 135 ++++++++++++++---------- 
keras/src/trainers/trainer_test.py | 164 ++++++++++++++++++----------- 2 files changed, 179 insertions(+), 120 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 99d4e7d90720..cfdd026d9919 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -19,6 +19,7 @@ from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.epoch_iterator import EpochIterator from keras.src.utils import traceback_utils +from jax.sharding import PartitionSpec, NamedSharding if is_nnx_enabled(): from flax import nnx @@ -255,82 +256,78 @@ def iterator_step(state, iterator): return iterator_step - def _get_out_shardings_for_step(self, state_shardings): - """Helper to create the out_shardings for a jitted step function.""" - if distribution_lib.distribution() is not None and hasattr( - self, "_metrics_result_structure" - ): - return ( - tree.map_structure(lambda _: None, self._metrics_result_structure), - state_shardings, - ) - return None - def make_train_function(self, force=False): if self.train_function is not None and not force: return - - state_shardings = ( - self._trainable_variable_shardings, - self._non_trainable_variable_shardings, - self._optimizer_variable_shardings, - self._metrics_variable_shardings, - ) - out_shardings = self._get_out_shardings_for_step(state_shardings) - if not self.run_eagerly and self.jit_compile: + out_shardings = None + if distribution_lib.distribution() is not None: + state_shardings = ( + self._trainable_variable_shardings, + self._non_trainable_variable_shardings, + self._optimizer_variable_shardings, + self._metrics_variable_shardings, + ) + out_shardings = (None, state_shardings) train_step = jit( - self.train_step, donate_argnums=0, out_shardings=out_shardings + self.train_step, + donate_argnums=0, + out_shardings=out_shardings, ) else: train_step = self.train_step - self.train_function = self._make_function(train_step) + step_function = self._make_function(train_step) + + self.train_function = step_function def make_test_function(self, force=False): if self.test_function is not None and not force: return - - state_shardings = ( - self._trainable_variable_shardings, - self._non_trainable_variable_shardings, - self._metrics_variable_shardings, - ) - out_shardings = self._get_out_shardings_for_step(state_shardings) - if not self.run_eagerly and self.jit_compile: + out_shardings = None + if distribution_lib.distribution() is not None: + state_shardings = ( + self._trainable_variable_shardings, + self._non_trainable_variable_shardings, + self._metrics_variable_shardings, + ) + out_shardings = (None, state_shardings) test_step = jit( - self.test_step, donate_argnums=0, out_shardings=out_shardings + self.test_step, + donate_argnums=0, + out_shardings=out_shardings, ) else: test_step = self.test_step - self.test_function = self._make_function(test_step) + step_function = self._make_function(test_step) + + self.test_function = step_function def make_predict_function(self, force=False): if self.predict_function is not None and not force: return self.predict_function - def predict_step(state, data): + def predict_step_wrapper(state, data): outputs, non_trainable_variables = self.predict_step(state, data) return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: out_shardings = None - # FIX: Check if the model has been built before accessing sharding attrs - if distribution_lib.distribution() is not None and hasattr( - self, 
"_trainable_variable_shardings" - ): - out_shardings = ( - None, - ( - self._trainable_variable_shardings, - self._non_trainable_variable_shardings, - ), + if distribution_lib.distribution() is not None: + state_shardings = ( + self._trainable_variable_shardings, + self._non_trainable_variable_shardings, ) + out_shardings = (None, state_shardings) predict_step = jit( - predict_step, donate_argnums=0, out_shardings=out_shardings + predict_step_wrapper, + donate_argnums=0, + out_shardings=out_shardings ) + else: + predict_step = predict_step_wrapper _step_function = self._make_function( predict_step, concatenate_outputs=True @@ -896,21 +893,43 @@ def jax_state_sync(self): self._jax_state_synced = True def _record_training_state_sharding_spec(self): - self._trainable_variable_shardings = [ - v.value.sharding for v in self.trainable_variables - ] - self._non_trainable_variable_shardings = [ - v.value.sharding for v in self.non_trainable_variables - ] - if hasattr(self, "optimizer") and self.optimizer is not None: - self._optimizer_variable_shardings = [ - v.value.sharding for v in self.optimizer.variables - ] + if not self.jit_compile: + return + + distribution = distribution_lib.distribution() + + def get_partition_spec(variable): + if distribution is None: + return PartitionSpec() + + if not hasattr(distribution, "layout_map"): + return PartitionSpec() + tensor_layout = distribution.layout_map.get(variable.path) + + if tensor_layout is None: + return PartitionSpec() + return PartitionSpec(*tensor_layout.axes) + + self._trainable_variable_shardings = tuple( + get_partition_spec(v) for v in self.trainable_variables + ) + self._non_trainable_variable_shardings = tuple( + get_partition_spec(v) for v in self.non_trainable_variables + ) + + if hasattr(self, "optimizer") and self.optimizer: + self._optimizer_variable_shardings = tuple( + get_partition_spec(v) for v in self.optimizer.variables + ) + else: + self._optimizer_variable_shardings = () + + if hasattr(self, "metrics_variables"): + self._metrics_variable_shardings = tuple( + get_partition_spec(v) for v in self.metrics_variables + ) else: - self._optimizer_variable_shardings = [] - self._metrics_variable_shardings = [ - v.value.sharding for v in self.metrics_variables - ] + self._metrics_variable_shardings = () def _clear_jax_state_sharding(self): self._trainable_variable_shardings = None diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index abaa2525e4e0..2e30d9eb0628 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -5,6 +5,10 @@ from absl.testing import parameterized import keras +import contextlib +import types +import jax +from jax.sharding import NamedSharding, PartitionSpec from keras.src import backend from keras.src import initializers from keras.src import layers @@ -20,10 +24,13 @@ from keras.src.optimizers.rmsprop import RMSprop from keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters import py_dataset_adapter +from keras.src.optimizers import loss_scale_optimizer +from keras.src.optimizers import optimizer +from keras.src.backend.jax import trainer as jax_trainer_lib if backend.backend() == "jax": from keras.src.backend.jax.trainer import JAXTrainer as Trainer - from keras.src.distribution import DeviceMesh, TensorLayout, distribution_lib + from keras.src.distribution import DeviceMesh, TensorLayout, distribution_lib, DataParallel elif backend.backend() == "torch": from keras.src.backend.torch.trainer import 
TorchTrainer as Trainer elif backend.backend() == "tensorflow": @@ -2859,73 +2866,106 @@ def predict_step(self, *args): ) self.assertLessEqual(tracing_count[0], 2) -class ExampleModelForJAXTrainerShardingTest(models.Model): - def __init__(self, units=3, **kwargs): - super().__init__(**kwargs) - self.dense1 = layers.Dense(4, activation="relu", kernel_initializer="ones") - self.dense2 = layers.Dense(units, activation="softmax", kernel_initializer="ones") - - def call(self, x): - return self.dense2(self.dense1(x)) - +@pytest.mark.requires_trainable_backend @pytest.mark.skipif( backend.backend() != "jax", - reason="This is a JAX-specific distribution test.", + reason="This test is specific to the JAX backend trainer.", ) -class JAXTrainerShardingTest(testing.TestCase): - - def setUp(self): - super().setUp() - import jax +class JAXTrainerCorrectnessTest(testing.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + ("single_device", False), + ("distributed", True), + ) + def test_jit_fit_with_out_shardings_logic(self, distributed): + def patched_record_sharding_spec(self_model): + if not self_model.jit_compile: + return + distribution = distribution_lib.distribution() + + def get_sharding_object(variable): + if distribution is None: + return None + jax_mesh = distribution.device_mesh.backend_mesh + if hasattr(distribution, "layout_map"): + tensor_layout = distribution.layout_map.get(variable.path) + if tensor_layout is not None: + return NamedSharding( + jax_mesh, PartitionSpec(*tensor_layout.axes) + ) + return NamedSharding(jax_mesh, PartitionSpec()) - if jax.device_count() < 2: - self.skipTest( - "Cannot test sharding with less than 2 devices. " - f"Found {jax.device_count()} devices." + self_model._trainable_variable_shardings = tuple( + get_sharding_object(v) for v in self_model.trainable_variables ) + self_model._non_trainable_variable_shardings = tuple( + get_sharding_object(v) for v in self_model.non_trainable_variables + ) + if hasattr(self_model, "optimizer") and self_model.optimizer: + self_model._optimizer_variable_shardings = tuple( + get_sharding_object(v) for v in self_model.optimizer.variables + ) + else: + self_model._optimizer_variable_shardings = () + if hasattr(self_model, "metrics_variables"): + self_model._metrics_variable_shardings = tuple( + get_sharding_object(v) for v in self_model.metrics_variables + ) + else: + self_model._metrics_variable_shardings = () - devices = np.array(jax.devices()) - device_mesh = DeviceMesh( - shape=(jax.device_count(),), - axis_names=("batch",), - devices=devices.flatten(), - ) - data_layout_2d = TensorLayout(axes=("batch", None), device_mesh=device_mesh) - data_layout_1d = TensorLayout(axes=("batch",), device_mesh=device_mesh) - variable_layout = TensorLayout(axes=(None, None), device_mesh=device_mesh) - - def get_layout_for_data(shape): - if not hasattr(shape, '__len__'): - return variable_layout - if len(shape) == 2: - return data_layout_2d - elif len(shape) == 1: - return data_layout_1d - return variable_layout - - mock_dist = mock.MagicMock() - mock_dist.get_data_layout.side_effect = get_layout_for_data - mock_dist.get_tensor_layout.return_value = variable_layout - mock_dist.auto_shard_dataset = False - self.distribution_mock = mock_dist - - @pytest.mark.requires_trainable_backend - def test_fit_with_sharding(self): - with mock.patch.object( - distribution_lib, "distribution", return_value=self.distribution_mock + def patched_get_jax_state( + self_model, trainable_variables=False, non_trainable_variables=False, + 
optimizer_variables=False, metrics_variables=False, + purge_model_variables=False ): - model = ExampleModelForJAXTrainerShardingTest(units=3) - model.compile( - optimizer="sgd", - loss="mse", - jit_compile=True + state = [] + if trainable_variables: + state.append(tuple(v.value for v in self_model.trainable_variables)) + if non_trainable_variables: + state.append(tuple(v.value for v in self_model.non_trainable_variables)) + if optimizer_variables: + state.append(tuple(v.value for v in self_model.optimizer.variables)) + if metrics_variables: + state.append(tuple(v.value for v in self_model.metrics_variables)) + return tuple(state) + + original_train_step = jax_trainer_lib.JAXTrainer.train_step + + def patched_train_step_wrapper(self, state, data): + logs, new_state = original_train_step(self, state, data) + + fixed_new_state = tuple( + tuple(var_group) if isinstance(var_group, list) else var_group + for var_group in new_state ) - x = np.ones((16, 5), dtype="float32") - y = np.zeros((16, 3), dtype="float32") - sw = np.ones((16,), dtype="float32") - - history = model.fit(x, y, sample_weight=sw, batch_size=4, epochs=2) - - self.assertIn("loss", history.history) - self.assertEqual(len(history.history["loss"]), 2) \ No newline at end of file + return logs, fixed_new_state + + x = np.random.rand(64, 8).astype("float32") + y = np.random.rand(64, 1).astype("float32") + + if distributed: + if len(jax.local_devices()) < 2: + self.skipTest("Distributed test requires at least 2 JAX devices.") + + devices = jax.local_devices() + mesh = DeviceMesh(shape=(len(devices),), axis_names=("batch",), devices=devices) + distribution = DataParallel(mesh) + + with mock.patch( + "keras.src.backend.jax.trainer.JAXTrainer.train_step", + new=patched_train_step_wrapper + ), distribution.scope(): + + model = models.Sequential([layers.Dense(4, activation="relu", input_shape=(8,)), layers.Dense(1)]) + model._record_training_state_sharding_spec = types.MethodType(patched_record_sharding_spec, model) + model._get_jax_state = types.MethodType(patched_get_jax_state, model) + model.compile(optimizer="adam", loss="mse", jit_compile=True) + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + else: + with contextlib.nullcontext(): + model = models.Sequential([layers.Dense(4, activation="relu", input_shape=(8,)), layers.Dense(1)]) + model._record_training_state_sharding_spec = types.MethodType(patched_record_sharding_spec, model) + model._get_jax_state = types.MethodType(patched_get_jax_state, model) + model.compile(optimizer="adam", loss="mse", jit_compile=True) + model.fit(x, y, epochs=2, batch_size=32, verbose=0) From 18e213af4231a04df9ed4284c1739d912ff4dc42 Mon Sep 17 00:00:00 2001 From: suhana Date: Mon, 11 Aug 2025 09:57:10 +0530 Subject: [PATCH 06/12] Removing clear jax state sharding --- keras/src/backend/jax/trainer.py | 109 ++++++++++++++----------------- 1 file changed, 49 insertions(+), 60 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index cfdd026d9919..18cc1d5f04eb 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -262,12 +262,7 @@ def make_train_function(self, force=False): if not self.run_eagerly and self.jit_compile: out_shardings = None if distribution_lib.distribution() is not None: - state_shardings = ( - self._trainable_variable_shardings, - self._non_trainable_variable_shardings, - self._optimizer_variable_shardings, - self._metrics_variable_shardings, - ) + state_shardings = self._get_training_state_shardings() 
out_shardings = (None, state_shardings) train_step = jit( self.train_step, @@ -287,10 +282,16 @@ def make_test_function(self, force=False): if not self.run_eagerly and self.jit_compile: out_shardings = None if distribution_lib.distribution() is not None: + ( + trainable_shardings, + non_trainable_shardings, + _, # optimizer_shardings + metrics_shardings, + ) = self._get_training_state_shardings() state_shardings = ( - self._trainable_variable_shardings, - self._non_trainable_variable_shardings, - self._metrics_variable_shardings, + trainable_shardings, + non_trainable_shardings, + metrics_shardings, ) out_shardings = (None, state_shardings) test_step = jit( @@ -309,25 +310,29 @@ def make_predict_function(self, force=False): if self.predict_function is not None and not force: return self.predict_function - def predict_step_wrapper(state, data): + def predict_step(state, data): outputs, non_trainable_variables = self.predict_step(state, data) return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: out_shardings = None if distribution_lib.distribution() is not None: + ( + trainable_shardings, + non_trainable_shardings, + _, # optimizer_shardings + _, # metrics_shardings + ) = self._get_training_state_shardings() state_shardings = ( - self._trainable_variable_shardings, - self._non_trainable_variable_shardings, + trainable_shardings, + non_trainable_shardings, ) out_shardings = (None, state_shardings) predict_step = jit( - predict_step_wrapper, + predict_step, donate_argnums=0, - out_shardings=out_shardings + out_shardings=out_shardings, ) - else: - predict_step = predict_step_wrapper _step_function = self._make_function( predict_step, concatenate_outputs=True @@ -410,7 +415,6 @@ def fit( steps=epoch_iterator.num_batches, model=self, ) - self._record_training_state_sharding_spec() self.make_train_function() self.stop_training = False @@ -526,7 +530,6 @@ def fit( if training_finished: callbacks.on_train_end(logs=training_logs) self._jax_state = None - self._clear_jax_state_sharding() return self.history @traceback_utils.filter_traceback @@ -576,7 +579,6 @@ def evaluate( steps=epoch_iterator.num_batches, model=self, ) - self._record_training_state_sharding_spec() self.make_test_function() self.stop_evaluating = False @@ -628,9 +630,6 @@ def evaluate( logs = self._get_metrics_result_or_logs(logs) callbacks.on_test_end(logs) self._jax_state = None - if not use_cached_eval_dataset: - # Only clear sharding if evaluate is not called from `fit`. 
- self._clear_jax_state_sharding() if return_dict: return logs return self._flatten_metrics_in_order(logs) @@ -672,7 +671,6 @@ def predict( steps=epoch_iterator.num_batches, model=self, ) - self._record_training_state_sharding_spec() self.make_predict_function() self.stop_predicting = False @@ -731,7 +729,6 @@ def append_to_outputs(batch_outputs, outputs): self.jax_state_sync() callbacks.on_predict_end() self._jax_state = None - self._clear_jax_state_sharding() return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) def train_on_batch( @@ -760,7 +757,6 @@ def data(): # Maybe build model self._symbolic_build(data_batch=next(data())) - self._record_training_state_sharding_spec() self.make_train_function() # Train step @@ -809,7 +805,6 @@ def data(): # Maybe build model self._symbolic_build(data_batch=next(data())) - self._record_training_state_sharding_spec() self.make_test_function() # Test step @@ -842,7 +837,6 @@ def predict_on_batch(self, x): # Build model with backend.StatelessScope(): self(x) - self._record_training_state_sharding_spec() self.make_predict_function() state = self._get_jax_state( @@ -892,50 +886,45 @@ def jax_state_sync(self): ref_v.assign(v) self._jax_state_synced = True - def _record_training_state_sharding_spec(self): - if not self.jit_compile: - return - + def _get_training_state_shardings(self): distribution = distribution_lib.distribution() - - def get_partition_spec(variable): - if distribution is None: - return PartitionSpec() - - if not hasattr(distribution, "layout_map"): - return PartitionSpec() - tensor_layout = distribution.layout_map.get(variable.path) - - if tensor_layout is None: - return PartitionSpec() - return PartitionSpec(*tensor_layout.axes) - - self._trainable_variable_shardings = tuple( - get_partition_spec(v) for v in self.trainable_variables + mesh = distribution.device_mesh.backend_mesh + + def get_sharding(variable): + partition_spec = PartitionSpec() + if hasattr(distribution, "layout_map"): + tensor_layout = distribution.layout_map.get(variable.path) + if tensor_layout is not None: + partition_spec = PartitionSpec(*tensor_layout.axes) + return NamedSharding(mesh, partition_spec) + + trainable_shardings = tuple( + get_sharding(v) for v in self.trainable_variables ) - self._non_trainable_variable_shardings = tuple( - get_partition_spec(v) for v in self.non_trainable_variables + non_trainable_shardings = tuple( + get_sharding(v) for v in self.non_trainable_variables ) - + if hasattr(self, "optimizer") and self.optimizer: - self._optimizer_variable_shardings = tuple( - get_partition_spec(v) for v in self.optimizer.variables + optimizer_shardings = tuple( + get_sharding(v) for v in self.optimizer.variables ) else: - self._optimizer_variable_shardings = () + optimizer_shardings = () if hasattr(self, "metrics_variables"): - self._metrics_variable_shardings = tuple( - get_partition_spec(v) for v in self.metrics_variables + metrics_shardings = tuple( + get_sharding(v) for v in self.metrics_variables ) else: - self._metrics_variable_shardings = () + metrics_shardings = () - def _clear_jax_state_sharding(self): - self._trainable_variable_shardings = None - self._non_trainable_variable_shardings = None - self._optimizer_variable_shardings = None - self._metrics_variable_shardings = None + return ( + trainable_shardings, + non_trainable_shardings, + optimizer_shardings, + metrics_shardings, + ) def _purge_model_variables( self, From ed6ae345e9efcc9472cd1375371bf8d2245b768a Mon Sep 17 00:00:00 2001 From: suhana Date: Tue, 12 Aug 2025 
11:19:20 +0530 Subject: [PATCH 07/12] Reworked on comments --- keras/src/backend/jax/trainer.py | 51 ++++++------ keras/src/trainers/trainer_test.py | 123 ++++++++++------------------- 2 files changed, 67 insertions(+), 107 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 18cc1d5f04eb..cc0125dac4e1 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -19,7 +19,6 @@ from keras.src.trainers.data_adapters import data_adapter_utils from keras.src.trainers.epoch_iterator import EpochIterator from keras.src.utils import traceback_utils -from jax.sharding import PartitionSpec, NamedSharding if is_nnx_enabled(): from flax import nnx @@ -887,37 +886,39 @@ def jax_state_sync(self): self._jax_state_synced = True def _get_training_state_shardings(self): + """Retrieves the sharding specifications for all training-related state. + + This method reads the pre-computed sharding specification directly from + each variable's `.value.sharding` attribute. The returned structure is a + tuple of lists to exactly match the PyTree structure of the state data + itself. + """ distribution = distribution_lib.distribution() - mesh = distribution.device_mesh.backend_mesh - - def get_sharding(variable): - partition_spec = PartitionSpec() - if hasattr(distribution, "layout_map"): - tensor_layout = distribution.layout_map.get(variable.path) - if tensor_layout is not None: - partition_spec = PartitionSpec(*tensor_layout.axes) - return NamedSharding(mesh, partition_spec) - - trainable_shardings = tuple( - get_sharding(v) for v in self.trainable_variables - ) - non_trainable_shardings = tuple( - get_sharding(v) for v in self.non_trainable_variables - ) + if distribution is None: + return None + + # Change the inner comprehensions from tuple() to [] to match the + # data structure, which is a list of variables. 
+ trainable_shardings = [ + v.value.sharding for v in self.trainable_variables + ] + non_trainable_shardings = [ + v.value.sharding for v in self.non_trainable_variables + ] if hasattr(self, "optimizer") and self.optimizer: - optimizer_shardings = tuple( - get_sharding(v) for v in self.optimizer.variables - ) + optimizer_shardings = [ + v.value.sharding for v in self.optimizer.variables + ] else: - optimizer_shardings = () + optimizer_shardings = [] if hasattr(self, "metrics_variables"): - metrics_shardings = tuple( - get_sharding(v) for v in self.metrics_variables - ) + metrics_shardings = [ + v.value.sharding for v in self.metrics_variables + ] else: - metrics_shardings = () + metrics_shardings = [] return ( trainable_shardings, diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index 2e30d9eb0628..32acb319df13 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -1,14 +1,12 @@ +import types from unittest import mock +import jax import numpy as np import pytest from absl.testing import parameterized import keras -import contextlib -import types -import jax -from jax.sharding import NamedSharding, PartitionSpec from keras.src import backend from keras.src import initializers from keras.src import layers @@ -20,17 +18,16 @@ from keras.src import testing from keras.src.backend import config from keras.src.backend.common.symbolic_scope import in_symbolic_scope +from keras.src.backend.jax import trainer as jax_trainer_lib from keras.src.callbacks.callback import Callback from keras.src.optimizers.rmsprop import RMSprop from keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters import py_dataset_adapter -from keras.src.optimizers import loss_scale_optimizer -from keras.src.optimizers import optimizer -from keras.src.backend.jax import trainer as jax_trainer_lib if backend.backend() == "jax": from keras.src.backend.jax.trainer import JAXTrainer as Trainer - from keras.src.distribution import DeviceMesh, TensorLayout, distribution_lib, DataParallel + from keras.src.distribution import DataParallel + from keras.src.distribution import DeviceMesh elif backend.backend() == "torch": from keras.src.backend.torch.trainer import TorchTrainer as Trainer elif backend.backend() == "tensorflow": @@ -2866,106 +2863,68 @@ def predict_step(self, *args): ) self.assertLessEqual(tracing_count[0], 2) -@pytest.mark.requires_trainable_backend -@pytest.mark.skipif( - backend.backend() != "jax", - reason="This test is specific to the JAX backend trainer.", -) -class JAXTrainerCorrectnessTest(testing.TestCase, parameterized.TestCase): @parameterized.named_parameters( ("single_device", False), ("distributed", True), ) def test_jit_fit_with_out_shardings_logic(self, distributed): - def patched_record_sharding_spec(self_model): - if not self_model.jit_compile: - return - distribution = distribution_lib.distribution() - - def get_sharding_object(variable): - if distribution is None: - return None - jax_mesh = distribution.device_mesh.backend_mesh - if hasattr(distribution, "layout_map"): - tensor_layout = distribution.layout_map.get(variable.path) - if tensor_layout is not None: - return NamedSharding( - jax_mesh, PartitionSpec(*tensor_layout.axes) - ) - return NamedSharding(jax_mesh, PartitionSpec()) - - self_model._trainable_variable_shardings = tuple( - get_sharding_object(v) for v in self_model.trainable_variables - ) - self_model._non_trainable_variable_shardings = tuple( - get_sharding_object(v) for v in 
self_model.non_trainable_variables - ) - if hasattr(self_model, "optimizer") and self_model.optimizer: - self_model._optimizer_variable_shardings = tuple( - get_sharding_object(v) for v in self_model.optimizer.variables - ) - else: - self_model._optimizer_variable_shardings = () - if hasattr(self_model, "metrics_variables"): - self_model._metrics_variable_shardings = tuple( - get_sharding_object(v) for v in self_model.metrics_variables - ) - else: - self_model._metrics_variable_shardings = () - def patched_get_jax_state( - self_model, trainable_variables=False, non_trainable_variables=False, - optimizer_variables=False, metrics_variables=False, - purge_model_variables=False + self_model, + trainable_variables=False, + non_trainable_variables=False, + optimizer_variables=False, + metrics_variables=False, + purge_model_variables=False, ): state = [] if trainable_variables: - state.append(tuple(v.value for v in self_model.trainable_variables)) + state.append([v.value for v in self_model.trainable_variables]) if non_trainable_variables: - state.append(tuple(v.value for v in self_model.non_trainable_variables)) + state.append( + [v.value for v in self_model.non_trainable_variables] + ) if optimizer_variables: - state.append(tuple(v.value for v in self_model.optimizer.variables)) + state.append([v.value for v in self_model.optimizer.variables]) if metrics_variables: - state.append(tuple(v.value for v in self_model.metrics_variables)) + state.append([v.value for v in self_model.metrics_variables]) return tuple(state) original_train_step = jax_trainer_lib.JAXTrainer.train_step def patched_train_step_wrapper(self, state, data): logs, new_state = original_train_step(self, state, data) + return logs, new_state - fixed_new_state = tuple( - tuple(var_group) if isinstance(var_group, list) else var_group - for var_group in new_state - ) - - return logs, fixed_new_state - x = np.random.rand(64, 8).astype("float32") y = np.random.rand(64, 1).astype("float32") if distributed: if len(jax.local_devices()) < 2: - self.skipTest("Distributed test requires at least 2 JAX devices.") - + self.skipTest( + "Distributed test requires at least 2 JAX devices." 
+ ) + devices = jax.local_devices() - mesh = DeviceMesh(shape=(len(devices),), axis_names=("batch",), devices=devices) + mesh = DeviceMesh( + shape=(len(devices),), axis_names=("batch",), devices=devices + ) distribution = DataParallel(mesh) - with mock.patch( - "keras.src.backend.jax.trainer.JAXTrainer.train_step", - new=patched_train_step_wrapper - ), distribution.scope(): - - model = models.Sequential([layers.Dense(4, activation="relu", input_shape=(8,)), layers.Dense(1)]) - model._record_training_state_sharding_spec = types.MethodType(patched_record_sharding_spec, model) - model._get_jax_state = types.MethodType(patched_get_jax_state, model) - model.compile(optimizer="adam", loss="mse", jit_compile=True) - model.fit(x, y, epochs=2, batch_size=32, verbose=0) - else: - with contextlib.nullcontext(): - model = models.Sequential([layers.Dense(4, activation="relu", input_shape=(8,)), layers.Dense(1)]) - model._record_training_state_sharding_spec = types.MethodType(patched_record_sharding_spec, model) - model._get_jax_state = types.MethodType(patched_get_jax_state, model) + with ( + mock.patch( + "keras.src.backend.jax.trainer.JAXTrainer.train_step", + new=patched_train_step_wrapper, + ), + distribution.scope(), + ): + model = models.Sequential( + [ + layers.Dense(4, activation="relu", input_shape=(8,)), + layers.Dense(1), + ] + ) + model._get_jax_state = types.MethodType( + patched_get_jax_state, model + ) model.compile(optimizer="adam", loss="mse", jit_compile=True) model.fit(x, y, epochs=2, batch_size=32, verbose=0) From 9f6e147c4a2e6553d82f5a73efb8d6dd3cb4855b Mon Sep 17 00:00:00 2001 From: suhana Date: Tue, 12 Aug 2025 11:44:56 +0530 Subject: [PATCH 08/12] Reworked on comments --- keras/src/backend/jax/trainer.py | 12 ------------ 1 file changed, 12 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index cc0125dac4e1..284771c364f7 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -886,26 +886,15 @@ def jax_state_sync(self): self._jax_state_synced = True def _get_training_state_shardings(self): - """Retrieves the sharding specifications for all training-related state. - - This method reads the pre-computed sharding specification directly from - each variable's `.value.sharding` attribute. The returned structure is a - tuple of lists to exactly match the PyTree structure of the state data - itself. - """ distribution = distribution_lib.distribution() if distribution is None: return None - - # Change the inner comprehensions from tuple() to [] to match the - # data structure, which is a list of variables. 
trainable_shardings = [ v.value.sharding for v in self.trainable_variables ] non_trainable_shardings = [ v.value.sharding for v in self.non_trainable_variables ] - if hasattr(self, "optimizer") and self.optimizer: optimizer_shardings = [ v.value.sharding for v in self.optimizer.variables @@ -919,7 +908,6 @@ def _get_training_state_shardings(self): ] else: metrics_shardings = [] - return ( trainable_shardings, non_trainable_shardings, From 88694aa3aeea75bec961e98898b1f1e88a679b1c Mon Sep 17 00:00:00 2001 From: suhana Date: Tue, 12 Aug 2025 11:59:07 +0530 Subject: [PATCH 09/12] Reworked on comments --- keras/src/backend/jax/trainer.py | 29 ++++++++--------------------- 1 file changed, 8 insertions(+), 21 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 284771c364f7..dc36a45f6f46 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -886,34 +886,21 @@ def jax_state_sync(self): self._jax_state_synced = True def _get_training_state_shardings(self): - distribution = distribution_lib.distribution() - if distribution is None: - return None - trainable_shardings = [ + self._trainable_variable_shardings = [ v.value.sharding for v in self.trainable_variables ] - non_trainable_shardings = [ + self._non_trainable_variable_shardings = [ v.value.sharding for v in self.non_trainable_variables ] - if hasattr(self, "optimizer") and self.optimizer: - optimizer_shardings = [ + if hasattr(self, "optimizer") and self.optimizer is not None: + self._optimizer_variable_shardings = [ v.value.sharding for v in self.optimizer.variables ] else: - optimizer_shardings = [] - - if hasattr(self, "metrics_variables"): - metrics_shardings = [ - v.value.sharding for v in self.metrics_variables - ] - else: - metrics_shardings = [] - return ( - trainable_shardings, - non_trainable_shardings, - optimizer_shardings, - metrics_shardings, - ) + self._optimizer_variable_shardings = [] + self._metrics_variable_shardings = [ + v.value.sharding for v in self.metrics_variables + ] def _purge_model_variables( self, From d534ec34c38870c82c90bef57f9c0d5c448ca506 Mon Sep 17 00:00:00 2001 From: suhana Date: Tue, 12 Aug 2025 12:05:39 +0530 Subject: [PATCH 10/12] Made minor changes to function names --- keras/src/backend/jax/trainer.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index dc36a45f6f46..da7be0a61158 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -261,7 +261,7 @@ def make_train_function(self, force=False): if not self.run_eagerly and self.jit_compile: out_shardings = None if distribution_lib.distribution() is not None: - state_shardings = self._get_training_state_shardings() + state_shardings = self._record_training_state_sharding_spec() out_shardings = (None, state_shardings) train_step = jit( self.train_step, @@ -286,7 +286,7 @@ def make_test_function(self, force=False): non_trainable_shardings, _, # optimizer_shardings metrics_shardings, - ) = self._get_training_state_shardings() + ) = self._record_training_state_sharding_spec() state_shardings = ( trainable_shardings, non_trainable_shardings, @@ -321,7 +321,7 @@ def predict_step(state, data): non_trainable_shardings, _, # optimizer_shardings _, # metrics_shardings - ) = self._get_training_state_shardings() + ) = self._record_training_state_sharding_spec() state_shardings = ( trainable_shardings, non_trainable_shardings, @@ -885,7 +885,7 @@ def 
jax_state_sync(self): ref_v.assign(v) self._jax_state_synced = True - def _get_training_state_shardings(self): + def _record_training_state_sharding_spec(self): self._trainable_variable_shardings = [ v.value.sharding for v in self.trainable_variables ] From b3b26f463080161146836cba982ad76b95662fbf Mon Sep 17 00:00:00 2001 From: suhana Date: Wed, 13 Aug 2025 20:12:14 +0530 Subject: [PATCH 11/12] Modifying get_state_sharding_spec and adding unit test for ensuring corectness --- keras/src/backend/jax/trainer.py | 26 +++++----- keras/src/trainers/trainer_test.py | 77 ++++++++++++------------------ 2 files changed, 45 insertions(+), 58 deletions(-) diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index da7be0a61158..38470ad45869 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -261,7 +261,7 @@ def make_train_function(self, force=False): if not self.run_eagerly and self.jit_compile: out_shardings = None if distribution_lib.distribution() is not None: - state_shardings = self._record_training_state_sharding_spec() + state_shardings = self._get_state_sharding_spec() out_shardings = (None, state_shardings) train_step = jit( self.train_step, @@ -286,7 +286,7 @@ def make_test_function(self, force=False): non_trainable_shardings, _, # optimizer_shardings metrics_shardings, - ) = self._record_training_state_sharding_spec() + ) = self._get_state_sharding_spec() state_shardings = ( trainable_shardings, non_trainable_shardings, @@ -321,7 +321,7 @@ def predict_step(state, data): non_trainable_shardings, _, # optimizer_shardings _, # metrics_shardings - ) = self._record_training_state_sharding_spec() + ) = self._get_state_sharding_spec() state_shardings = ( trainable_shardings, non_trainable_shardings, @@ -885,22 +885,26 @@ def jax_state_sync(self): ref_v.assign(v) self._jax_state_synced = True - def _record_training_state_sharding_spec(self): - self._trainable_variable_shardings = [ + def _get_state_sharding_spec(self): + trainable_shardings = [ v.value.sharding for v in self.trainable_variables ] - self._non_trainable_variable_shardings = [ + non_trainable_shardings = [ v.value.sharding for v in self.non_trainable_variables ] if hasattr(self, "optimizer") and self.optimizer is not None: - self._optimizer_variable_shardings = [ + optimizer_shardings = [ v.value.sharding for v in self.optimizer.variables ] else: - self._optimizer_variable_shardings = [] - self._metrics_variable_shardings = [ - v.value.sharding for v in self.metrics_variables - ] + optimizer_shardings = [] + metrics_shardings = [v.value.sharding for v in self.metrics_variables] + return ( + trainable_shardings, + non_trainable_shardings, + optimizer_shardings, + metrics_shardings, + ) def _purge_model_variables( self, diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index 32acb319df13..090eabcaf5fa 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -1,4 +1,3 @@ -import types from unittest import mock import jax @@ -18,9 +17,11 @@ from keras.src import testing from keras.src.backend import config from keras.src.backend.common.symbolic_scope import in_symbolic_scope -from keras.src.backend.jax import trainer as jax_trainer_lib from keras.src.callbacks.callback import Callback +from keras.src.distribution.distribution_lib import DataParallel +from keras.src.distribution.distribution_lib import DeviceMesh from keras.src.optimizers.rmsprop import RMSprop +from keras.src.testing import test_case from 
keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters import py_dataset_adapter @@ -2863,41 +2864,17 @@ def predict_step(self, *args): ) self.assertLessEqual(tracing_count[0], 2) + +class JAXTrainerCorrectnessTest(test_case.TestCase, parameterized.TestCase): @parameterized.named_parameters( ("single_device", False), ("distributed", True), ) def test_jit_fit_with_out_shardings_logic(self, distributed): - def patched_get_jax_state( - self_model, - trainable_variables=False, - non_trainable_variables=False, - optimizer_variables=False, - metrics_variables=False, - purge_model_variables=False, - ): - state = [] - if trainable_variables: - state.append([v.value for v in self_model.trainable_variables]) - if non_trainable_variables: - state.append( - [v.value for v in self_model.non_trainable_variables] - ) - if optimizer_variables: - state.append([v.value for v in self_model.optimizer.variables]) - if metrics_variables: - state.append([v.value for v in self_model.metrics_variables]) - return tuple(state) - - original_train_step = jax_trainer_lib.JAXTrainer.train_step - - def patched_train_step_wrapper(self, state, data): - logs, new_state = original_train_step(self, state, data) - return logs, new_state - x = np.random.rand(64, 8).astype("float32") y = np.random.rand(64, 1).astype("float32") + distribution = None if distributed: if len(jax.local_devices()) < 2: self.skipTest( @@ -2910,21 +2887,27 @@ def patched_train_step_wrapper(self, state, data): ) distribution = DataParallel(mesh) - with ( - mock.patch( - "keras.src.backend.jax.trainer.JAXTrainer.train_step", - new=patched_train_step_wrapper, - ), - distribution.scope(), - ): - model = models.Sequential( - [ - layers.Dense(4, activation="relu", input_shape=(8,)), - layers.Dense(1), - ] - ) - model._get_jax_state = types.MethodType( - patched_get_jax_state, model - ) - model.compile(optimizer="adam", loss="mse", jit_compile=True) - model.fit(x, y, epochs=2, batch_size=32, verbose=0) + scope = distribution.scope() if distribution else mock.MagicMock() + + with scope: + model = models.Sequential( + [ + layers.Dense(4, activation="relu", input_shape=(8,)), + layers.Dense(1), + ] + ) + model.compile(optimizer="adam", loss="mse", jit_compile=True) + + if distribution: + expected_shardings = [ + v.value.sharding for v in model.trainable_variables + ] + self.assertNotEqual(len(set(expected_shardings)), 1) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + + if distribution: + actual_shardings = [ + v.value.sharding for v in model.trainable_variables + ] + self.assertListEqual(actual_shardings, expected_shardings) From bd8e0a6f8cc1768aa4666aa9f669ee6a6d964a38 Mon Sep 17 00:00:00 2001 From: suhana Date: Wed, 13 Aug 2025 20:24:47 +0530 Subject: [PATCH 12/12] Modying test to skip if backend is not jax --- keras/src/trainers/trainer_test.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index 090eabcaf5fa..05e910aa6038 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -2871,6 +2871,8 @@ class JAXTrainerCorrectnessTest(test_case.TestCase, parameterized.TestCase): ("distributed", True), ) def test_jit_fit_with_out_shardings_logic(self, distributed): + if keras.backend.backend() != "jax": + self.skipTest("This test requires the JAX backend.") x = np.random.rand(64, 8).astype("float32") y = np.random.rand(64, 1).astype("float32")
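
The series above settles on a single mechanism: record the sharding that each piece of training state already carries, then hand that structure to `jax.jit` through `out_shardings` (together with `donate_argnums=0`) instead of calling `jax.lax.with_sharding_constraint` on every variable inside the step. The sketch below is a minimal standalone illustration of that jit contract, not the Keras trainer code; the toy linear model and the names `train_step`, `params`, and `batch` are assumptions made for the example, and it simply shows that state placed with a given layout comes back from the jitted step with the same layout.

import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

devices = np.array(jax.devices())  # 1-D device array; a single CPU device also works
mesh = Mesh(devices, axis_names=("batch",))
replicated = NamedSharding(mesh, PartitionSpec())            # weights: replicated
batch_sharded = NamedSharding(mesh, PartitionSpec("batch"))  # data: split along batch

# Toy linear-regression state and data, placed with explicit shardings.
params = (
    jax.device_put(jnp.zeros((4, 1)), replicated),   # w
    jax.device_put(jnp.zeros((1,)), replicated),     # b
)
batch = (
    jax.device_put(jnp.ones((8, 4)), batch_sharded),  # x; batch dim must divide across devices
    jax.device_put(jnp.ones((8, 1)), batch_sharded),  # y
)

def train_step(state, data):
    w, b = state
    x, y = data

    def loss_fn(w, b):
        return jnp.mean((x @ w + b - y) ** 2)

    loss, (gw, gb) = jax.value_and_grad(loss_fn, argnums=(0, 1))(w, b)
    return loss, (w - 0.1 * gw, b - 0.1 * gb)

# Record the current sharding of every state array (the analogue of reading
# `variable.value.sharding` in the patches) and pin the jitted outputs to it.
# `donate_argnums=0` donates the input state buffers so XLA can reuse them,
# and `None` in out_shardings leaves the loss placement up to the compiler.
state_shardings = tuple(p.sharding for p in params)
jitted_step = jax.jit(
    train_step,
    donate_argnums=0,
    out_shardings=(None, state_shardings),
)

loss, params = jitted_step(params, batch)
print(float(loss), params[0].sharding)  # updated state keeps the recorded sharding

Pinning the output layout this way makes every call return state with an identical sharding, which is what allows the trainer to drop the per-variable `_enforce_jax_state_sharding` constraints without forcing recompilation between steps.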