
Commit 0387d30

Refactor(JAX): Use jax.jit's out_shardings instead of _enforce_jax_state_sharding (#21559)
* Refactor JAXTrainer sharding to use out_shardings
* Added tests for the JAX out sharding
* Update keras/src/trainers/trainer_test.py for nit changes

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>

* Added a helper to reduce code duplication
* Updated the test and the out sharding logic
* Removed _clear_jax_state_sharding
* Reworked based on review comments
* Reworked based on review comments
* Reworked based on review comments
* Made minor changes to function names
* Modified get_state_sharding_spec and added a unit test to ensure correctness
* Modified the test to skip if the backend is not JAX

---------

Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent ce0d278 commit 0387d30
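For context, the core of the change is that the jitted train/eval/predict steps no longer call jax.lax.with_sharding_constraint on their outputs (the old _enforce_jax_state_sharding helper); instead, the recorded shardings are handed to jax.jit up front via out_shardings. Below is a minimal, self-contained sketch of that pattern. It is not Keras code: the mesh, state dict, and step function are invented for illustration.

# Illustrative sketch only (names are invented): jit with out_shardings
# replaces per-output jax.lax.with_sharding_constraint calls.
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
replicated = NamedSharding(mesh, PartitionSpec())  # empty spec = replicate

def train_step(state, data):
    # Toy update; the real Keras step also returns (logs, state).
    new_state = {k: v - 0.1 * jnp.mean(data) for k, v in state.items()}
    logs = {"loss": jnp.mean(data)}
    return logs, new_state

# None leaves the placement of logs up to the compiler; `replicated` is
# applied to every leaf of new_state via pytree-prefix matching, so the
# state comes back with a known, stable sharding on every call.
train_step = jax.jit(
    train_step, donate_argnums=0, out_shardings=(None, replicated)
)

state = {"w": jnp.ones((4,)), "b": jnp.zeros((4,))}
logs, state = train_step(state, jnp.ones((8, 4)))
print(state["w"].sharding)  # NamedSharding with an empty PartitionSpec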

2 files changed: 113 additions and 106 deletions


keras/src/backend/jax/trainer.py

Lines changed: 57 additions & 106 deletions
@@ -153,7 +153,7 @@ def train_step(self, state, data):
             metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
         )
 
-        state = self._enforce_jax_state_sharding(
+        state = (
             trainable_variables,
             non_trainable_variables,
             optimizer_variables,
@@ -185,17 +185,6 @@ def test_step(self, state, data):
             metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
         )
 
-        (
-            trainable_variables,
-            non_trainable_variables,
-            _,
-            metrics_variables,
-        ) = self._enforce_jax_state_sharding(
-            trainable_variables=trainable_variables,
-            non_trainable_variables=non_trainable_variables,
-            optimizer_variables=None,
-            metrics_variables=metrics_variables,
-        )
         state = (
             trainable_variables,
             non_trainable_variables,
@@ -213,17 +202,6 @@ def predict_step(self, state, data):
         outputs, non_trainable_variables = self.stateless_call(
             trainable_variables, non_trainable_variables, x, **kwargs
         )
-        (
-            _,
-            non_trainable_variables,
-            _,
-            _,
-        ) = self._enforce_jax_state_sharding(
-            trainable_variables=None,
-            non_trainable_variables=non_trainable_variables,
-            optimizer_variables=None,
-            metrics_variables=None,
-        )
         return outputs, non_trainable_variables
 
     def _make_function(self, step_function, concatenate_outputs=False):
@@ -281,11 +259,15 @@ def make_train_function(self, force=False):
         if self.train_function is not None and not force:
             return
         if not self.run_eagerly and self.jit_compile:
-            # Note that we mark the state to be donated to jax,
-            # so that jax will reuse the memory buffer for outputs.
-            # This will reduce the memory usage of the training function by
-            # half.
-            train_step = jit(self.train_step, donate_argnums=0)
+            out_shardings = None
+            if distribution_lib.distribution() is not None:
+                state_shardings = self._get_state_sharding_spec()
+                out_shardings = (None, state_shardings)
+            train_step = jit(
+                self.train_step,
+                donate_argnums=0,
+                out_shardings=out_shardings,
+            )
         else:
             train_step = self.train_step
 
@@ -297,12 +279,25 @@ def make_test_function(self, force=False):
         if self.test_function is not None and not force:
             return
         if not self.run_eagerly and self.jit_compile:
-            # Note that we mark the state to be donated to jax,
-            # so that jax will reuse the memory buffer for outputs.
-            # This will reduce the memory usage of the training function by
-            # half.
-            test_step = jit(self.test_step, donate_argnums=0)
-
+            out_shardings = None
+            if distribution_lib.distribution() is not None:
+                (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                    _,  # optimizer_shardings
+                    metrics_shardings,
+                ) = self._get_state_sharding_spec()
+                state_shardings = (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                    metrics_shardings,
+                )
+                out_shardings = (None, state_shardings)
+            test_step = jit(
+                self.test_step,
+                donate_argnums=0,
+                out_shardings=out_shardings,
+            )
         else:
             test_step = self.test_step
 
@@ -319,7 +314,24 @@ def predict_step(state, data):
             return outputs, (state[0], non_trainable_variables)
 
         if not self.run_eagerly and self.jit_compile:
-            predict_step = jit(predict_step, donate_argnums=0)
+            out_shardings = None
+            if distribution_lib.distribution() is not None:
+                (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                    _,  # optimizer_shardings
+                    _,  # metrics_shardings
+                ) = self._get_state_sharding_spec()
+                state_shardings = (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                )
+                out_shardings = (None, state_shardings)
+            predict_step = jit(
+                predict_step,
+                donate_argnums=0,
+                out_shardings=out_shardings,
+            )
 
         _step_function = self._make_function(
             predict_step, concatenate_outputs=True
@@ -402,7 +414,6 @@ def fit(
                 steps=epoch_iterator.num_batches,
                 model=self,
             )
-        self._record_training_state_sharding_spec()
 
         self.make_train_function()
         self.stop_training = False
@@ -518,7 +529,6 @@ def fit(
         if training_finished:
             callbacks.on_train_end(logs=training_logs)
         self._jax_state = None
-        self._clear_jax_state_sharding()
         return self.history
 
     @traceback_utils.filter_traceback
@@ -568,7 +578,6 @@ def evaluate(
                 steps=epoch_iterator.num_batches,
                 model=self,
             )
-        self._record_training_state_sharding_spec()
 
         self.make_test_function()
         self.stop_evaluating = False
@@ -620,9 +629,6 @@ def evaluate(
         logs = self._get_metrics_result_or_logs(logs)
         callbacks.on_test_end(logs)
         self._jax_state = None
-        if not use_cached_eval_dataset:
-            # Only clear sharding if evaluate is not called from `fit`.
-            self._clear_jax_state_sharding()
         if return_dict:
             return logs
         return self._flatten_metrics_in_order(logs)
@@ -664,7 +670,6 @@ def predict(
                 steps=epoch_iterator.num_batches,
                 model=self,
            )
-        self._record_training_state_sharding_spec()
 
         self.make_predict_function()
         self.stop_predicting = False
@@ -723,7 +728,6 @@ def append_to_outputs(batch_outputs, outputs):
         self.jax_state_sync()
         callbacks.on_predict_end()
         self._jax_state = None
-        self._clear_jax_state_sharding()
         return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
 
     def train_on_batch(
@@ -752,7 +756,6 @@ def data():
 
         # Maybe build model
        self._symbolic_build(data_batch=next(data()))
-        self._record_training_state_sharding_spec()
         self.make_train_function()
 
         # Train step
@@ -801,7 +804,6 @@ def data():
 
         # Maybe build model
         self._symbolic_build(data_batch=next(data()))
-        self._record_training_state_sharding_spec()
         self.make_test_function()
 
         # Test step
@@ -834,7 +836,6 @@ def predict_on_batch(self, x):
             # Build model
             with backend.StatelessScope():
                 self(x)
-            self._record_training_state_sharding_spec()
         self.make_predict_function()
 
         state = self._get_jax_state(
@@ -884,75 +885,25 @@ def jax_state_sync(self):
                 ref_v.assign(v)
         self._jax_state_synced = True
 
-    def _record_training_state_sharding_spec(self):
-        self._trainable_variable_shardings = [
+    def _get_state_sharding_spec(self):
+        trainable_shardings = [
             v.value.sharding for v in self.trainable_variables
         ]
-        self._non_trainable_variable_shardings = [
+        non_trainable_shardings = [
             v.value.sharding for v in self.non_trainable_variables
         ]
         if hasattr(self, "optimizer") and self.optimizer is not None:
-            self._optimizer_variable_shardings = [
+            optimizer_shardings = [
                 v.value.sharding for v in self.optimizer.variables
            ]
         else:
-            self._optimizer_variable_shardings = []
-        self._metrics_variable_shardings = [
-            v.value.sharding for v in self.metrics_variables
-        ]
-
-    def _clear_jax_state_sharding(self):
-        self._trainable_variable_shardings = None
-        self._non_trainable_variable_shardings = None
-        self._optimizer_variable_shardings = None
-        self._metrics_variable_shardings = None
-
-    def _enforce_jax_state_sharding(
-        self,
-        trainable_variables=None,
-        non_trainable_variables=None,
-        optimizer_variables=None,
-        metrics_variables=None,
-    ):
-        """Enforce the sharding spec constraint for all the training state.
-
-        Since the output of the train/eval step will be used as inputs to next
-        step, we need to ensure that they have the same sharding spec, so that
-        nnx.jit/jax.jit won't have to recompile the train/eval function.
-
-        Note that this function will also rely on the recorded sharding spec
-        for each of states.
-
-        This function is expected to be called within the jitted train/eval
-        function, especially around the end of the function.
-        """
-        trainable_variables = trainable_variables or []
-        non_trainable_variables = non_trainable_variables or []
-        optimizer_variables = optimizer_variables or []
-        metrics_variables = metrics_variables or []
-
-        for i in range(len(trainable_variables)):
-            trainable_variables[i] = jax.lax.with_sharding_constraint(
-                trainable_variables[i], self._trainable_variable_shardings[i]
-            )
-        for i in range(len(non_trainable_variables)):
-            non_trainable_variables[i] = jax.lax.with_sharding_constraint(
-                non_trainable_variables[i],
-                self._non_trainable_variable_shardings[i],
-            )
-        for i in range(len(optimizer_variables)):
-            optimizer_variables[i] = jax.lax.with_sharding_constraint(
-                optimizer_variables[i], self._optimizer_variable_shardings[i]
-            )
-        for i in range(len(metrics_variables)):
-            metrics_variables[i] = jax.lax.with_sharding_constraint(
-                metrics_variables[i], self._metrics_variable_shardings[i]
-            )
+            optimizer_shardings = []
+        metrics_shardings = [v.value.sharding for v in self.metrics_variables]
         return (
-            trainable_variables,
-            non_trainable_variables,
-            optimizer_variables,
-            metrics_variables,
+            trainable_shardings,
+            non_trainable_shardings,
+            optimizer_shardings,
+            metrics_shardings,
         )
 
     def _purge_model_variables(
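One structural point worth noting: jax.jit matches out_shardings against the pytree that the step actually returns. test_step returns no optimizer variables and predict_step returns neither optimizer nor metrics variables, which is why the 4-tuple from _get_state_sharding_spec is sliced down before being passed to jit above. A hedged, self-contained sketch of that structural matching follows; the names and shapes are invented and this is not Keras code.

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
rep = NamedSharding(mesh, PartitionSpec())

def test_step(state, data):
    trainable, non_trainable, metrics = state
    logs = {"loss": jnp.mean(data)}
    # The returned state has three groups, so out_shardings below must
    # also have exactly three groups (no optimizer entry).
    return logs, (trainable, non_trainable, metrics)

# One sharding per variable, per group, mirroring the returned structure.
state_shardings = ([rep], [rep], [rep])
jitted_test_step = jax.jit(
    test_step, donate_argnums=0, out_shardings=(None, state_shardings)
)

state = ([jnp.ones(2)], [jnp.zeros(2)], [jnp.zeros(())])
logs, state = jitted_test_step(state, jnp.ones((4, 2)))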

keras/src/trainers/trainer_test.py

Lines changed: 56 additions & 0 deletions
@@ -1,5 +1,6 @@
 from unittest import mock
 
+import jax
 import numpy as np
 import pytest
 from absl.testing import parameterized
@@ -17,12 +18,17 @@
 from keras.src.backend import config
 from keras.src.backend.common.symbolic_scope import in_symbolic_scope
 from keras.src.callbacks.callback import Callback
+from keras.src.distribution.distribution_lib import DataParallel
+from keras.src.distribution.distribution_lib import DeviceMesh
 from keras.src.optimizers.rmsprop import RMSprop
+from keras.src.testing import test_case
 from keras.src.testing.test_utils import named_product
 from keras.src.trainers.data_adapters import py_dataset_adapter
 
 if backend.backend() == "jax":
     from keras.src.backend.jax.trainer import JAXTrainer as Trainer
+    from keras.src.distribution import DataParallel
+    from keras.src.distribution import DeviceMesh
 elif backend.backend() == "torch":
     from keras.src.backend.torch.trainer import TorchTrainer as Trainer
 elif backend.backend() == "tensorflow":
@@ -2857,3 +2863,53 @@ def predict_step(self, *args):
             verbose=0,
         )
         self.assertLessEqual(tracing_count[0], 2)
+
+
+class JAXTrainerCorrectnessTest(test_case.TestCase, parameterized.TestCase):
+    @parameterized.named_parameters(
+        ("single_device", False),
+        ("distributed", True),
+    )
+    def test_jit_fit_with_out_shardings_logic(self, distributed):
+        if keras.backend.backend() != "jax":
+            self.skipTest("This test requires the JAX backend.")
+        x = np.random.rand(64, 8).astype("float32")
+        y = np.random.rand(64, 1).astype("float32")
+
+        distribution = None
+        if distributed:
+            if len(jax.local_devices()) < 2:
+                self.skipTest(
+                    "Distributed test requires at least 2 JAX devices."
+                )
+
+            devices = jax.local_devices()
+            mesh = DeviceMesh(
+                shape=(len(devices),), axis_names=("batch",), devices=devices
+            )
+            distribution = DataParallel(mesh)
+
+        scope = distribution.scope() if distribution else mock.MagicMock()
+
+        with scope:
+            model = models.Sequential(
+                [
+                    layers.Dense(4, activation="relu", input_shape=(8,)),
+                    layers.Dense(1),
+                ]
+            )
+            model.compile(optimizer="adam", loss="mse", jit_compile=True)
+
+            if distribution:
+                expected_shardings = [
+                    v.value.sharding for v in model.trainable_variables
+                ]
+                self.assertNotEqual(len(set(expected_shardings)), 1)
+
+            model.fit(x, y, epochs=2, batch_size=32, verbose=0)
+
+            if distribution:
+                actual_shardings = [
+                    v.value.sharding for v in model.trainable_variables
+                ]
+                self.assertListEqual(actual_shardings, expected_shardings)
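A rough standalone equivalent of what the new test asserts, usable outside the test harness: under a DataParallel distribution, the variable shardings recorded before fit() should be unchanged afterwards, because out_shardings pins the placement of the jitted step's outputs. This sketch assumes the JAX backend; with a single device the check is trivially true, and keras.distribution.set_distribution is used here in place of the test's distribution.scope().

import os
os.environ["KERAS_BACKEND"] = "jax"  # must be set before importing keras

import jax
import numpy as np
import keras
from keras import distribution

devices = jax.local_devices()
mesh = distribution.DeviceMesh(
    shape=(len(devices),), axis_names=("batch",), devices=devices
)
distribution.set_distribution(distribution.DataParallel(mesh))

model = keras.Sequential(
    [keras.layers.Dense(4, activation="relu"), keras.layers.Dense(1)]
)
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.build((None, 8))

# The distribution lays the variables out at build time; fit() should keep
# that layout now that train_step is jitted with out_shardings.
before = [v.value.sharding for v in model.trainable_variables]
model.fit(
    np.ones((32, 8), "float32"),
    np.ones((32, 1), "float32"),
    epochs=1,
    batch_size=8,
    verbose=0,
)
after = [v.value.sharding for v in model.trainable_variables]
assert after == before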
