diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py index 33a415a38f7..38470ad4586 100644 --- a/keras/src/backend/jax/trainer.py +++ b/keras/src/backend/jax/trainer.py @@ -153,7 +153,7 @@ def train_step(self, state, data): metrics_variables, unscaled_loss, x, y, y_pred, sample_weight ) - state = self._enforce_jax_state_sharding( + state = ( trainable_variables, non_trainable_variables, optimizer_variables, @@ -185,17 +185,6 @@ def test_step(self, state, data): metrics_variables, unscaled_loss, x, y, y_pred, sample_weight ) - ( - trainable_variables, - non_trainable_variables, - _, - metrics_variables, - ) = self._enforce_jax_state_sharding( - trainable_variables=trainable_variables, - non_trainable_variables=non_trainable_variables, - optimizer_variables=None, - metrics_variables=metrics_variables, - ) state = ( trainable_variables, non_trainable_variables, @@ -213,17 +202,6 @@ def predict_step(self, state, data): outputs, non_trainable_variables = self.stateless_call( trainable_variables, non_trainable_variables, x, **kwargs ) - ( - _, - non_trainable_variables, - _, - _, - ) = self._enforce_jax_state_sharding( - trainable_variables=None, - non_trainable_variables=non_trainable_variables, - optimizer_variables=None, - metrics_variables=None, - ) return outputs, non_trainable_variables def _make_function(self, step_function, concatenate_outputs=False): @@ -281,11 +259,15 @@ def make_train_function(self, force=False): if self.train_function is not None and not force: return if not self.run_eagerly and self.jit_compile: - # Note that we mark the state to be donated to jax, - # so that jax will reuse the memory buffer for outputs. - # This will reduce the memory usage of the training function by - # half. - train_step = jit(self.train_step, donate_argnums=0) + out_shardings = None + if distribution_lib.distribution() is not None: + state_shardings = self._get_state_sharding_spec() + out_shardings = (None, state_shardings) + train_step = jit( + self.train_step, + donate_argnums=0, + out_shardings=out_shardings, + ) else: train_step = self.train_step @@ -297,12 +279,25 @@ def make_test_function(self, force=False): if self.test_function is not None and not force: return if not self.run_eagerly and self.jit_compile: - # Note that we mark the state to be donated to jax, - # so that jax will reuse the memory buffer for outputs. - # This will reduce the memory usage of the training function by - # half. 
- test_step = jit(self.test_step, donate_argnums=0) - + out_shardings = None + if distribution_lib.distribution() is not None: + ( + trainable_shardings, + non_trainable_shardings, + _, # optimizer_shardings + metrics_shardings, + ) = self._get_state_sharding_spec() + state_shardings = ( + trainable_shardings, + non_trainable_shardings, + metrics_shardings, + ) + out_shardings = (None, state_shardings) + test_step = jit( + self.test_step, + donate_argnums=0, + out_shardings=out_shardings, + ) else: test_step = self.test_step @@ -319,7 +314,24 @@ def predict_step(state, data): return outputs, (state[0], non_trainable_variables) if not self.run_eagerly and self.jit_compile: - predict_step = jit(predict_step, donate_argnums=0) + out_shardings = None + if distribution_lib.distribution() is not None: + ( + trainable_shardings, + non_trainable_shardings, + _, # optimizer_shardings + _, # metrics_shardings + ) = self._get_state_sharding_spec() + state_shardings = ( + trainable_shardings, + non_trainable_shardings, + ) + out_shardings = (None, state_shardings) + predict_step = jit( + predict_step, + donate_argnums=0, + out_shardings=out_shardings, + ) _step_function = self._make_function( predict_step, concatenate_outputs=True @@ -402,7 +414,6 @@ def fit( steps=epoch_iterator.num_batches, model=self, ) - self._record_training_state_sharding_spec() self.make_train_function() self.stop_training = False @@ -518,7 +529,6 @@ def fit( if training_finished: callbacks.on_train_end(logs=training_logs) self._jax_state = None - self._clear_jax_state_sharding() return self.history @traceback_utils.filter_traceback @@ -568,7 +578,6 @@ def evaluate( steps=epoch_iterator.num_batches, model=self, ) - self._record_training_state_sharding_spec() self.make_test_function() self.stop_evaluating = False @@ -620,9 +629,6 @@ def evaluate( logs = self._get_metrics_result_or_logs(logs) callbacks.on_test_end(logs) self._jax_state = None - if not use_cached_eval_dataset: - # Only clear sharding if evaluate is not called from `fit`. 
- self._clear_jax_state_sharding() if return_dict: return logs return self._flatten_metrics_in_order(logs) @@ -664,7 +670,6 @@ def predict( steps=epoch_iterator.num_batches, model=self, ) - self._record_training_state_sharding_spec() self.make_predict_function() self.stop_predicting = False @@ -723,7 +728,6 @@ def append_to_outputs(batch_outputs, outputs): self.jax_state_sync() callbacks.on_predict_end() self._jax_state = None - self._clear_jax_state_sharding() return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs) def train_on_batch( @@ -752,7 +756,6 @@ def data(): # Maybe build model self._symbolic_build(data_batch=next(data())) - self._record_training_state_sharding_spec() self.make_train_function() # Train step @@ -801,7 +804,6 @@ def data(): # Maybe build model self._symbolic_build(data_batch=next(data())) - self._record_training_state_sharding_spec() self.make_test_function() # Test step @@ -834,7 +836,6 @@ def predict_on_batch(self, x): # Build model with backend.StatelessScope(): self(x) - self._record_training_state_sharding_spec() self.make_predict_function() state = self._get_jax_state( @@ -884,75 +885,25 @@ def jax_state_sync(self): ref_v.assign(v) self._jax_state_synced = True - def _record_training_state_sharding_spec(self): - self._trainable_variable_shardings = [ + def _get_state_sharding_spec(self): + trainable_shardings = [ v.value.sharding for v in self.trainable_variables ] - self._non_trainable_variable_shardings = [ + non_trainable_shardings = [ v.value.sharding for v in self.non_trainable_variables ] if hasattr(self, "optimizer") and self.optimizer is not None: - self._optimizer_variable_shardings = [ + optimizer_shardings = [ v.value.sharding for v in self.optimizer.variables ] else: - self._optimizer_variable_shardings = [] - self._metrics_variable_shardings = [ - v.value.sharding for v in self.metrics_variables - ] - - def _clear_jax_state_sharding(self): - self._trainable_variable_shardings = None - self._non_trainable_variable_shardings = None - self._optimizer_variable_shardings = None - self._metrics_variable_shardings = None - - def _enforce_jax_state_sharding( - self, - trainable_variables=None, - non_trainable_variables=None, - optimizer_variables=None, - metrics_variables=None, - ): - """Enforce the sharding spec constraint for all the training state. - - Since the output of the train/eval step will be used as inputs to next - step, we need to ensure that they have the same sharding spec, so that - nnx.jit/jax.jit won't have to recompile the train/eval function. - - Note that this function will also rely on the recorded sharding spec - for each of states. - - This function is expected to be called within the jitted train/eval - function, especially around the end of the function. 
- """ - trainable_variables = trainable_variables or [] - non_trainable_variables = non_trainable_variables or [] - optimizer_variables = optimizer_variables or [] - metrics_variables = metrics_variables or [] - - for i in range(len(trainable_variables)): - trainable_variables[i] = jax.lax.with_sharding_constraint( - trainable_variables[i], self._trainable_variable_shardings[i] - ) - for i in range(len(non_trainable_variables)): - non_trainable_variables[i] = jax.lax.with_sharding_constraint( - non_trainable_variables[i], - self._non_trainable_variable_shardings[i], - ) - for i in range(len(optimizer_variables)): - optimizer_variables[i] = jax.lax.with_sharding_constraint( - optimizer_variables[i], self._optimizer_variable_shardings[i] - ) - for i in range(len(metrics_variables)): - metrics_variables[i] = jax.lax.with_sharding_constraint( - metrics_variables[i], self._metrics_variable_shardings[i] - ) + optimizer_shardings = [] + metrics_shardings = [v.value.sharding for v in self.metrics_variables] return ( - trainable_variables, - non_trainable_variables, - optimizer_variables, - metrics_variables, + trainable_shardings, + non_trainable_shardings, + optimizer_shardings, + metrics_shardings, ) def _purge_model_variables( diff --git a/keras/src/trainers/trainer_test.py b/keras/src/trainers/trainer_test.py index 21080dea2bf..05e910aa603 100644 --- a/keras/src/trainers/trainer_test.py +++ b/keras/src/trainers/trainer_test.py @@ -1,5 +1,6 @@ from unittest import mock +import jax import numpy as np import pytest from absl.testing import parameterized @@ -17,12 +18,17 @@ from keras.src.backend import config from keras.src.backend.common.symbolic_scope import in_symbolic_scope from keras.src.callbacks.callback import Callback +from keras.src.distribution.distribution_lib import DataParallel +from keras.src.distribution.distribution_lib import DeviceMesh from keras.src.optimizers.rmsprop import RMSprop +from keras.src.testing import test_case from keras.src.testing.test_utils import named_product from keras.src.trainers.data_adapters import py_dataset_adapter if backend.backend() == "jax": from keras.src.backend.jax.trainer import JAXTrainer as Trainer + from keras.src.distribution import DataParallel + from keras.src.distribution import DeviceMesh elif backend.backend() == "torch": from keras.src.backend.torch.trainer import TorchTrainer as Trainer elif backend.backend() == "tensorflow": @@ -2857,3 +2863,53 @@ def predict_step(self, *args): verbose=0, ) self.assertLessEqual(tracing_count[0], 2) + + +class JAXTrainerCorrectnessTest(test_case.TestCase, parameterized.TestCase): + @parameterized.named_parameters( + ("single_device", False), + ("distributed", True), + ) + def test_jit_fit_with_out_shardings_logic(self, distributed): + if keras.backend.backend() != "jax": + self.skipTest("This test requires the JAX backend.") + x = np.random.rand(64, 8).astype("float32") + y = np.random.rand(64, 1).astype("float32") + + distribution = None + if distributed: + if len(jax.local_devices()) < 2: + self.skipTest( + "Distributed test requires at least 2 JAX devices." 
+ ) + + devices = jax.local_devices() + mesh = DeviceMesh( + shape=(len(devices),), axis_names=("batch",), devices=devices + ) + distribution = DataParallel(mesh) + + scope = distribution.scope() if distribution else mock.MagicMock() + + with scope: + model = models.Sequential( + [ + layers.Dense(4, activation="relu", input_shape=(8,)), + layers.Dense(1), + ] + ) + model.compile(optimizer="adam", loss="mse", jit_compile=True) + + if distribution: + expected_shardings = [ + v.value.sharding for v in model.trainable_variables + ] + self.assertNotEqual(len(set(expected_shardings)), 1) + + model.fit(x, y, epochs=2, batch_size=32, verbose=0) + + if distribution: + actual_shardings = [ + v.value.sharding for v in model.trainable_variables + ] + self.assertListEqual(actual_shardings, expected_shardings)
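
For context, the core pattern this change relies on is passing the variables' existing shardings to `jax.jit` via `out_shardings`, rather than calling `jax.lax.with_sharding_constraint` on the state at the end of every step. The standalone sketch below illustrates that pattern; the mesh, the toy `train_step`, and the shapes are hypothetical and only mirror the structure of the trainer code, they are not part of this diff.

# Minimal sketch of jit(..., donate_argnums=0, out_shardings=...), with
# hypothetical names and shapes (not Keras code).
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# 1D mesh over all local devices, analogous to DataParallel's device mesh.
devices = np.array(jax.local_devices())
mesh = Mesh(devices, axis_names=("batch",))

# Fully replicated layout, which is what DataParallel assigns to weights.
replicated = NamedSharding(mesh, PartitionSpec())
w = jax.device_put(jnp.ones((8, 4)), replicated)

def train_step(state, x):
    # Toy update; stands in for the real train step that returns new state.
    return state - 0.01 * jnp.ones_like(state)

# Instead of constraining the outputs inside train_step, declare the desired
# output shardings at jit time. The donated input buffer can then be reused
# for an output with an identical sharding.
jit_step = jax.jit(
    train_step,
    donate_argnums=0,
    out_shardings=replicated,
)

new_w = jit_step(w, jnp.ones((16, 8)))
print(new_w.sharding)  # same replicated NamedSharding on every step

This gives the same guarantee the removed `_enforce_jax_state_sharding` helper provided: because each step's outputs are pinned to the same shardings as its inputs, feeding the returned state into the next step leaves the argument shardings unchanged and does not trigger a recompile.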