
Refactor(JAX): Use jax.jit's out_shardings instead of _enforce_jax_state_sharding #21559

Status: Merged. 14 commits merged on Aug 14, 2025. This diff shows changes from 7 commits.

178 changes: 80 additions & 98 deletions keras/src/backend/jax/trainer.py
@@ -19,6 +19,7 @@
from keras.src.trainers.data_adapters import data_adapter_utils
from keras.src.trainers.epoch_iterator import EpochIterator
from keras.src.utils import traceback_utils
from jax.sharding import PartitionSpec, NamedSharding

if is_nnx_enabled():
from flax import nnx
@@ -153,7 +154,7 @@ def train_step(self, state, data):
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
)

state = self._enforce_jax_state_sharding(
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
@@ -185,17 +186,6 @@ def test_step(self, state, data):
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
)

(
trainable_variables,
non_trainable_variables,
_,
metrics_variables,
) = self._enforce_jax_state_sharding(
trainable_variables=trainable_variables,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=metrics_variables,
)
state = (
trainable_variables,
non_trainable_variables,
@@ -213,17 +203,6 @@ def predict_step(self, state, data):
outputs, non_trainable_variables = self.stateless_call(
trainable_variables, non_trainable_variables, x, **kwargs
)
(
_,
non_trainable_variables,
_,
_,
) = self._enforce_jax_state_sharding(
trainable_variables=None,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=None,
)
return outputs, non_trainable_variables

def _make_function(self, step_function, concatenate_outputs=False):
@@ -281,11 +260,20 @@ def make_train_function(self, force=False):
if self.train_function is not None and not force:
return
if not self.run_eagerly and self.jit_compile:
# Note that we mark the state to be donated to jax,
# so that jax will reuse the memory buffer for outputs.
# This will reduce the memory usage of the training function by
# half.
train_step = jit(self.train_step, donate_argnums=0)
out_shardings = None
if distribution_lib.distribution() is not None:
state_shardings = (
self._trainable_variable_shardings,
self._non_trainable_variable_shardings,
self._optimizer_variable_shardings,
self._metrics_variable_shardings,
)
out_shardings = (None, state_shardings)
train_step = jit(
self.train_step,
donate_argnums=0,
out_shardings=out_shardings,
)
else:
train_step = self.train_step

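A minimal standalone sketch of the pattern make_train_function now relies on: jax.jit with donate_argnums=0 reuses the incoming state buffers, and out_shardings pins the returned state to the recorded shardings, so the next call sees inputs with the expected placement and the compiled step is reused. The toy step, shapes, and 1-D "batch" mesh below are illustrative assumptions, not Keras code.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), axis_names=("batch",))
replicated = NamedSharding(mesh, PartitionSpec())  # same value on every device

def toy_step(params, grads):
    # Stand-in for train_step: returns a scalar "log" and the updated state.
    new_params = [p - 0.1 * g for p, g in zip(params, grads)]
    loss = sum(jnp.sum(p) for p in new_params)
    return loss, new_params

step = jax.jit(
    toy_step,
    donate_argnums=0,  # reuse the old state's buffers for the new state
    out_shardings=(None, [replicated, replicated]),  # logs free, state pinned
)

params = [jnp.ones((4, 4)), jnp.zeros((4,))]
grads = [jnp.full((4, 4), 0.5), jnp.full((4,), 0.5)]
loss, params = step(params, grads)
print(loss, params[0].sharding)
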
@@ -297,12 +285,19 @@ def make_test_function(self, force=False):
if self.test_function is not None and not force:
return
if not self.run_eagerly and self.jit_compile:
# Note that we mark the state to be donated to jax,
# so that jax will reuse the memory buffer for outputs.
# This will reduce the memory usage of the training function by
# half.
test_step = jit(self.test_step, donate_argnums=0)

out_shardings = None
if distribution_lib.distribution() is not None:
state_shardings = (
self._trainable_variable_shardings,
self._non_trainable_variable_shardings,
self._metrics_variable_shardings,
)
out_shardings = (None, state_shardings)
test_step = jit(
self.test_step,
donate_argnums=0,
out_shardings=out_shardings,
)
else:
test_step = self.test_step

@@ -314,12 +309,25 @@ def make_predict_function(self, force=False):
if self.predict_function is not None and not force:
return self.predict_function

def predict_step(state, data):
def predict_step_wrapper(state, data):
outputs, non_trainable_variables = self.predict_step(state, data)
return outputs, (state[0], non_trainable_variables)

if not self.run_eagerly and self.jit_compile:
predict_step = jit(predict_step, donate_argnums=0)
out_shardings = None
if distribution_lib.distribution() is not None:
state_shardings = (
self._trainable_variable_shardings,
self._non_trainable_variable_shardings,
)
out_shardings = (None, state_shardings)
predict_step = jit(
predict_step_wrapper,
donate_argnums=0,
out_shardings=out_shardings
)
else:
predict_step = predict_step_wrapper

_step_function = self._make_function(
predict_step, concatenate_outputs=True
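
The (None, state_shardings) value works because out_shardings is matched structurally against what predict_step_wrapper returns: None leaves the model outputs' placement up to the compiler, while the nested state tuple is pinned entry by entry. A small sketch under the same assumptions as above, with a hypothetical two-group state:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), axis_names=("batch",))
rep = NamedSharding(mesh, PartitionSpec())

def wrapped_predict(state, x):
    trainable, non_trainable = state
    outputs = x @ trainable[0]
    # Thread the (unchanged) state back out, mirroring predict_step_wrapper.
    return outputs, (trainable, non_trainable)

state_shardings = ((rep,), (rep,))  # same tree structure as the state
predict = jax.jit(
    wrapped_predict,
    donate_argnums=0,
    out_shardings=(None, state_shardings),  # outputs unconstrained, state pinned
)

state = ((jnp.ones((4, 2)),), (jnp.zeros((1,)),))
out, state = predict(state, jnp.ones((3, 4)))
print(out.shape, state[0][0].sharding)
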
@@ -885,76 +893,50 @@ def jax_state_sync(self):
self._jax_state_synced = True

def _record_training_state_sharding_spec(self):
self._trainable_variable_shardings = [
v.value.sharding for v in self.trainable_variables
]
self._non_trainable_variable_shardings = [
v.value.sharding for v in self.non_trainable_variables
]
if hasattr(self, "optimizer") and self.optimizer is not None:
self._optimizer_variable_shardings = [
v.value.sharding for v in self.optimizer.variables
]
if not self.jit_compile:
return

distribution = distribution_lib.distribution()

def get_partition_spec(variable):
if distribution is None:
return PartitionSpec()

if not hasattr(distribution, "layout_map"):
return PartitionSpec()
tensor_layout = distribution.layout_map.get(variable.path)

if tensor_layout is None:
return PartitionSpec()
return PartitionSpec(*tensor_layout.axes)

self._trainable_variable_shardings = tuple(
get_partition_spec(v) for v in self.trainable_variables
)
self._non_trainable_variable_shardings = tuple(
get_partition_spec(v) for v in self.non_trainable_variables
)

if hasattr(self, "optimizer") and self.optimizer:
self._optimizer_variable_shardings = tuple(
get_partition_spec(v) for v in self.optimizer.variables
)
else:
self._optimizer_variable_shardings = ()

if hasattr(self, "metrics_variables"):
self._metrics_variable_shardings = tuple(
get_partition_spec(v) for v in self.metrics_variables
)
else:
self._optimizer_variable_shardings = []
self._metrics_variable_shardings = [
v.value.sharding for v in self.metrics_variables
]
self._metrics_variable_shardings = ()

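A toy illustration of what get_partition_spec ends up recording, assuming a made-up layout map keyed by variable path; the paths and axis names are hypothetical, not taken from Keras:

from jax.sharding import PartitionSpec

toy_layout_map = {
    "dense/kernel": ("model", None),  # shard dim 0 over the "model" mesh axis
    "dense/bias": (None,),            # single axis, replicated
}

def toy_partition_spec(path):
    axes = toy_layout_map.get(path)
    # Variables without a layout entry fall back to fully replicated.
    return PartitionSpec() if axes is None else PartitionSpec(*axes)

print(toy_partition_spec("dense/kernel"))     # PartitionSpec('model', None)
print(toy_partition_spec("embedding/table"))  # PartitionSpec()

When an explicit Sharding object is needed instead of a bare spec, as the new test below constructs, a recorded spec can be bound to the device mesh with NamedSharding(mesh, spec).
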
def _clear_jax_state_sharding(self):
self._trainable_variable_shardings = None
self._non_trainable_variable_shardings = None
self._optimizer_variable_shardings = None
self._metrics_variable_shardings = None

def _enforce_jax_state_sharding(
self,
trainable_variables=None,
non_trainable_variables=None,
optimizer_variables=None,
metrics_variables=None,
):
"""Enforce the sharding spec constraint for all the training state.

Since the output of the train/eval step will be used as inputs to next
step, we need to ensure that they have the same sharding spec, so that
nnx.jit/jax.jit won't have to recompile the train/eval function.

Note that this function will also rely on the recorded sharding spec
for each of states.

This function is expected to be called within the jitted train/eval
function, especially around the end of the function.
"""
trainable_variables = trainable_variables or []
non_trainable_variables = non_trainable_variables or []
optimizer_variables = optimizer_variables or []
metrics_variables = metrics_variables or []

for i in range(len(trainable_variables)):
trainable_variables[i] = jax.lax.with_sharding_constraint(
trainable_variables[i], self._trainable_variable_shardings[i]
)
for i in range(len(non_trainable_variables)):
non_trainable_variables[i] = jax.lax.with_sharding_constraint(
non_trainable_variables[i],
self._non_trainable_variable_shardings[i],
)
for i in range(len(optimizer_variables)):
optimizer_variables[i] = jax.lax.with_sharding_constraint(
optimizer_variables[i], self._optimizer_variable_shardings[i]
)
for i in range(len(metrics_variables)):
metrics_variables[i] = jax.lax.with_sharding_constraint(
metrics_variables[i], self._metrics_variable_shardings[i]
)
return (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
)

def _purge_model_variables(
self,
trainable_variables=False,
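
The removed _enforce_jax_state_sharding pinned shardings from inside the traced step with jax.lax.with_sharding_constraint; the refactor states the same intent once, at the jax.jit boundary. A toy comparison of the two styles, with the mesh, shapes, and step function as illustrative assumptions:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), axis_names=("batch",))
spec = NamedSharding(mesh, PartitionSpec())

def step_old_style(x):
    y = x * 2.0
    # Old approach: constrain the result inside the traced function.
    return jax.lax.with_sharding_constraint(y, spec)

def step(x):
    return x * 2.0

jitted_old = jax.jit(step_old_style)
# New approach: leave the body alone and constrain at the jit boundary.
jitted_new = jax.jit(step, out_shardings=spec)

x = jnp.ones((8, 4))
print(jitted_old(x).sharding)
print(jitted_new(x).sharding)
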
112 changes: 112 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -5,6 +5,10 @@
from absl.testing import parameterized

import keras
import contextlib
import types
import jax
from jax.sharding import NamedSharding, PartitionSpec
from keras.src import backend
from keras.src import initializers
from keras.src import layers
@@ -20,9 +24,13 @@
from keras.src.optimizers.rmsprop import RMSprop
from keras.src.testing.test_utils import named_product
from keras.src.trainers.data_adapters import py_dataset_adapter
from keras.src.optimizers import loss_scale_optimizer
from keras.src.optimizers import optimizer
from keras.src.backend.jax import trainer as jax_trainer_lib

if backend.backend() == "jax":
from keras.src.backend.jax.trainer import JAXTrainer as Trainer
from keras.src.distribution import DeviceMesh, TensorLayout, distribution_lib, DataParallel
elif backend.backend() == "torch":
from keras.src.backend.torch.trainer import TorchTrainer as Trainer
elif backend.backend() == "tensorflow":
@@ -2857,3 +2865,107 @@ def predict_step(self, *args):
verbose=0,
)
self.assertLessEqual(tracing_count[0], 2)

@pytest.mark.requires_trainable_backend
@pytest.mark.skipif(
backend.backend() != "jax",
reason="This test is specific to the JAX backend trainer.",
)
class JAXTrainerCorrectnessTest(testing.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("single_device", False),
("distributed", True),
)
def test_jit_fit_with_out_shardings_logic(self, distributed):
def patched_record_sharding_spec(self_model):
if not self_model.jit_compile:
return
distribution = distribution_lib.distribution()

def get_sharding_object(variable):
if distribution is None:
return None
jax_mesh = distribution.device_mesh.backend_mesh
if hasattr(distribution, "layout_map"):
tensor_layout = distribution.layout_map.get(variable.path)
if tensor_layout is not None:
return NamedSharding(
jax_mesh, PartitionSpec(*tensor_layout.axes)
)
return NamedSharding(jax_mesh, PartitionSpec())

self_model._trainable_variable_shardings = tuple(
get_sharding_object(v) for v in self_model.trainable_variables
)
self_model._non_trainable_variable_shardings = tuple(
get_sharding_object(v) for v in self_model.non_trainable_variables
)
if hasattr(self_model, "optimizer") and self_model.optimizer:
self_model._optimizer_variable_shardings = tuple(
get_sharding_object(v) for v in self_model.optimizer.variables
)
else:
self_model._optimizer_variable_shardings = ()
if hasattr(self_model, "metrics_variables"):
self_model._metrics_variable_shardings = tuple(
get_sharding_object(v) for v in self_model.metrics_variables
)
else:
self_model._metrics_variable_shardings = ()

def patched_get_jax_state(
self_model, trainable_variables=False, non_trainable_variables=False,
optimizer_variables=False, metrics_variables=False,
purge_model_variables=False
):
state = []
if trainable_variables:
state.append(tuple(v.value for v in self_model.trainable_variables))
if non_trainable_variables:
state.append(tuple(v.value for v in self_model.non_trainable_variables))
if optimizer_variables:
state.append(tuple(v.value for v in self_model.optimizer.variables))
if metrics_variables:
state.append(tuple(v.value for v in self_model.metrics_variables))
return tuple(state)

original_train_step = jax_trainer_lib.JAXTrainer.train_step

def patched_train_step_wrapper(self, state, data):
logs, new_state = original_train_step(self, state, data)

fixed_new_state = tuple(
tuple(var_group) if isinstance(var_group, list) else var_group
for var_group in new_state
)

return logs, fixed_new_state

x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")

if distributed:
if len(jax.local_devices()) < 2:
self.skipTest("Distributed test requires at least 2 JAX devices.")

devices = jax.local_devices()
mesh = DeviceMesh(shape=(len(devices),), axis_names=("batch",), devices=devices)
distribution = DataParallel(mesh)

with mock.patch(
"keras.src.backend.jax.trainer.JAXTrainer.train_step",
new=patched_train_step_wrapper
), distribution.scope():

model = models.Sequential([layers.Dense(4, activation="relu", input_shape=(8,)), layers.Dense(1)])
model._record_training_state_sharding_spec = types.MethodType(patched_record_sharding_spec, model)
model._get_jax_state = types.MethodType(patched_get_jax_state, model)
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(x, y, epochs=2, batch_size=32, verbose=0)
else:
with contextlib.nullcontext():
model = models.Sequential([layers.Dense(4, activation="relu", input_shape=(8,)), layers.Dense(1)])
model._record_training_state_sharding_spec = types.MethodType(patched_record_sharding_spec, model)
model._get_jax_state = types.MethodType(patched_get_jax_state, model)
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(x, y, epochs=2, batch_size=32, verbose=0)
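
A usage note for the distributed parameterization above: on a machine with one physical device, JAX can be asked to expose several virtual CPU devices so the DataParallel branch actually exercises multi-device sharding. The device count of 8 is an arbitrary choice for illustration, and the flag must be set before JAX is first imported.

import os

os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=8"
)

import jax  # import after setting XLA_FLAGS so the flag takes effect

print(len(jax.local_devices()))  # expect 8 CPU devices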