Refactor(JAX): Use jax.jit's out_shardings instead of _enforce_jax_state_sharding #21559

Open · wants to merge 14 commits into base: master
163 changes: 57 additions & 106 deletions keras/src/backend/jax/trainer.py
@@ -153,7 +153,7 @@ def train_step(self, state, data):
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
)

state = self._enforce_jax_state_sharding(
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
@@ -185,17 +185,6 @@ def test_step(self, state, data):
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
)

(
trainable_variables,
non_trainable_variables,
_,
metrics_variables,
) = self._enforce_jax_state_sharding(
trainable_variables=trainable_variables,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=metrics_variables,
)
state = (
trainable_variables,
non_trainable_variables,
@@ -213,17 +202,6 @@ def predict_step(self, state, data):
outputs, non_trainable_variables = self.stateless_call(
trainable_variables, non_trainable_variables, x, **kwargs
)
(
_,
non_trainable_variables,
_,
_,
) = self._enforce_jax_state_sharding(
trainable_variables=None,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=None,
)
return outputs, non_trainable_variables

def _make_function(self, step_function, concatenate_outputs=False):
@@ -281,11 +259,15 @@ def make_train_function(self, force=False):
if self.train_function is not None and not force:
return
if not self.run_eagerly and self.jit_compile:
# Note that we mark the state to be donated to jax,
# so that jax will reuse the memory buffer for outputs.
# This will reduce the memory usage of the training function by
# half.
train_step = jit(self.train_step, donate_argnums=0)
out_shardings = None
if distribution_lib.distribution() is not None:
state_shardings = self._get_state_sharding_spec()
out_shardings = (None, state_shardings)
train_step = jit(
self.train_step,
donate_argnums=0,
out_shardings=out_shardings,
)
else:
train_step = self.train_step

@@ -297,12 +279,25 @@ def make_test_function(self, force=False):
if self.test_function is not None and not force:
return
if not self.run_eagerly and self.jit_compile:
# Note that we mark the state to be donated to jax,
# so that jax will reuse the memory buffer for outputs.
# This will reduce the memory usage of the training function by
# half.
test_step = jit(self.test_step, donate_argnums=0)

out_shardings = None
if distribution_lib.distribution() is not None:
(
trainable_shardings,
non_trainable_shardings,
_, # optimizer_shardings
metrics_shardings,
) = self._get_state_sharding_spec()
state_shardings = (
trainable_shardings,
non_trainable_shardings,
metrics_shardings,
)
out_shardings = (None, state_shardings)
test_step = jit(
self.test_step,
donate_argnums=0,
out_shardings=out_shardings,
)
else:
test_step = self.test_step

@@ -319,7 +314,24 @@ def predict_step(state, data):
return outputs, (state[0], non_trainable_variables)

if not self.run_eagerly and self.jit_compile:
predict_step = jit(predict_step, donate_argnums=0)
out_shardings = None
if distribution_lib.distribution() is not None:
(
trainable_shardings,
non_trainable_shardings,
_, # optimizer_shardings
_, # metrics_shardings
) = self._get_state_sharding_spec()
state_shardings = (
trainable_shardings,
non_trainable_shardings,
)
out_shardings = (None, state_shardings)
predict_step = jit(
predict_step,
donate_argnums=0,
out_shardings=out_shardings,
)

_step_function = self._make_function(
predict_step, concatenate_outputs=True
@@ -402,7 +414,6 @@ def fit(
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_train_function()
self.stop_training = False
@@ -518,7 +529,6 @@ def fit(
if training_finished:
callbacks.on_train_end(logs=training_logs)
self._jax_state = None
self._clear_jax_state_sharding()
return self.history

@traceback_utils.filter_traceback
@@ -568,7 +578,6 @@ def evaluate(
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_test_function()
self.stop_evaluating = False
@@ -620,9 +629,6 @@ def evaluate(
logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)
self._jax_state = None
if not use_cached_eval_dataset:
# Only clear sharding if evaluate is not called from `fit`.
self._clear_jax_state_sharding()
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
@@ -664,7 +670,6 @@ def predict(
steps=epoch_iterator.num_batches,
model=self,
)
self._record_training_state_sharding_spec()

self.make_predict_function()
self.stop_predicting = False
@@ -723,7 +728,6 @@ def append_to_outputs(batch_outputs, outputs):
self.jax_state_sync()
callbacks.on_predict_end()
self._jax_state = None
self._clear_jax_state_sharding()
return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)

def train_on_batch(
@@ -752,7 +756,6 @@ def data():

# Maybe build model
self._symbolic_build(data_batch=next(data()))
self._record_training_state_sharding_spec()
self.make_train_function()

# Train step
@@ -801,7 +804,6 @@ def data():

# Maybe build model
self._symbolic_build(data_batch=next(data()))
self._record_training_state_sharding_spec()
self.make_test_function()

# Test step
@@ -834,7 +836,6 @@ def predict_on_batch(self, x):
# Build model
with backend.StatelessScope():
self(x)
self._record_training_state_sharding_spec()
self.make_predict_function()

state = self._get_jax_state(
@@ -884,75 +885,25 @@ def jax_state_sync(self):
ref_v.assign(v)
self._jax_state_synced = True

def _record_training_state_sharding_spec(self):
self._trainable_variable_shardings = [
def _get_state_sharding_spec(self):
trainable_shardings = [
v.value.sharding for v in self.trainable_variables
]
self._non_trainable_variable_shardings = [
non_trainable_shardings = [
v.value.sharding for v in self.non_trainable_variables
]
if hasattr(self, "optimizer") and self.optimizer is not None:
self._optimizer_variable_shardings = [
optimizer_shardings = [
v.value.sharding for v in self.optimizer.variables
]
else:
self._optimizer_variable_shardings = []
self._metrics_variable_shardings = [
v.value.sharding for v in self.metrics_variables
]

def _clear_jax_state_sharding(self):
self._trainable_variable_shardings = None
self._non_trainable_variable_shardings = None
self._optimizer_variable_shardings = None
self._metrics_variable_shardings = None

def _enforce_jax_state_sharding(
self,
trainable_variables=None,
non_trainable_variables=None,
optimizer_variables=None,
metrics_variables=None,
):
"""Enforce the sharding spec constraint for all the training state.

Since the output of the train/eval step will be used as inputs to next
step, we need to ensure that they have the same sharding spec, so that
nnx.jit/jax.jit won't have to recompile the train/eval function.

Note that this function will also rely on the recorded sharding spec
for each of states.

This function is expected to be called within the jitted train/eval
function, especially around the end of the function.
"""
trainable_variables = trainable_variables or []
non_trainable_variables = non_trainable_variables or []
optimizer_variables = optimizer_variables or []
metrics_variables = metrics_variables or []

for i in range(len(trainable_variables)):
trainable_variables[i] = jax.lax.with_sharding_constraint(
trainable_variables[i], self._trainable_variable_shardings[i]
)
for i in range(len(non_trainable_variables)):
non_trainable_variables[i] = jax.lax.with_sharding_constraint(
non_trainable_variables[i],
self._non_trainable_variable_shardings[i],
)
for i in range(len(optimizer_variables)):
optimizer_variables[i] = jax.lax.with_sharding_constraint(
optimizer_variables[i], self._optimizer_variable_shardings[i]
)
for i in range(len(metrics_variables)):
metrics_variables[i] = jax.lax.with_sharding_constraint(
metrics_variables[i], self._metrics_variable_shardings[i]
)
optimizer_shardings = []
metrics_shardings = [v.value.sharding for v in self.metrics_variables]
return (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
trainable_shardings,
non_trainable_shardings,
optimizer_shardings,
metrics_shardings,
)

def _purge_model_variables(
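The core of the refactor, shown in isolation: instead of re-applying `jax.lax.with_sharding_constraint` to every piece of state inside the jitted step (the removed `_enforce_jax_state_sharding`), the desired shardings for the returned state are handed to `jax.jit` via `out_shardings`, while `donate_argnums=0` still lets JAX reuse the donated input buffers for the outputs. A minimal standalone sketch of that pattern, not code from this PR; the toy `train_step`, mesh, and state below are illustrative:

```python
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Toy mesh over the available devices; Keras derives this from the
# active Distribution instead.
devices = jax.devices()
mesh = Mesh(devices, axis_names=("batch",))
replicated = NamedSharding(mesh, PartitionSpec())  # fully replicated spec


def train_step(state, data):
    # Stand-in for the real train step: returns (logs, new_state).
    new_state = jax.tree_util.tree_map(lambda v: v - 0.1 * data.mean(), state)
    logs = {"loss": data.mean()}
    return logs, new_state


# Shardings for the returned state; `None` leaves the logs unconstrained.
state_shardings = (replicated, replicated)
train_step_fn = jax.jit(
    train_step,
    donate_argnums=0,  # reuse the input state buffers for the outputs
    out_shardings=(None, state_shardings),
)

state = (jnp.ones((4,)), jnp.zeros((4,)))
data = jnp.arange(8.0)
logs, state = train_step_fn(state, data)  # state keeps the pinned shardings
```

Because the returned state carries the same shardings as the inputs, feeding it back into the next call does not force a retrace, which is what the removed `_enforce_jax_state_sharding` helper previously had to guarantee from inside the traced function.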
56 changes: 56 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -1,5 +1,6 @@
from unittest import mock

import jax
import numpy as np
import pytest
from absl.testing import parameterized
@@ -17,12 +18,17 @@
from keras.src.backend import config
from keras.src.backend.common.symbolic_scope import in_symbolic_scope
from keras.src.callbacks.callback import Callback
from keras.src.distribution.distribution_lib import DataParallel
from keras.src.distribution.distribution_lib import DeviceMesh
from keras.src.optimizers.rmsprop import RMSprop
from keras.src.testing import test_case
from keras.src.testing.test_utils import named_product
from keras.src.trainers.data_adapters import py_dataset_adapter

if backend.backend() == "jax":
from keras.src.backend.jax.trainer import JAXTrainer as Trainer
from keras.src.distribution import DataParallel
from keras.src.distribution import DeviceMesh
elif backend.backend() == "torch":
from keras.src.backend.torch.trainer import TorchTrainer as Trainer
elif backend.backend() == "tensorflow":
@@ -2857,3 +2863,53 @@ def predict_step(self, *args):
verbose=0,
)
self.assertLessEqual(tracing_count[0], 2)


class JAXTrainerCorrectnessTest(test_case.TestCase, parameterized.TestCase):
@parameterized.named_parameters(
("single_device", False),
("distributed", True),
)
def test_jit_fit_with_out_shardings_logic(self, distributed):
if keras.backend.backend() != "jax":
self.skipTest("This test requires the JAX backend.")
x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")

distribution = None
if distributed:
if len(jax.local_devices()) < 2:
self.skipTest(
"Distributed test requires at least 2 JAX devices."
)

devices = jax.local_devices()
mesh = DeviceMesh(
shape=(len(devices),), axis_names=("batch",), devices=devices
)
distribution = DataParallel(mesh)

scope = distribution.scope() if distribution else mock.MagicMock()

with scope:
model = models.Sequential(
[
layers.Dense(4, activation="relu", input_shape=(8,)),
layers.Dense(1),
]
)
model.compile(optimizer="adam", loss="mse", jit_compile=True)

if distribution:
expected_shardings = [
v.value.sharding for v in model.trainable_variables
]
self.assertNotEqual(len(set(expected_shardings)), 1)

model.fit(x, y, epochs=2, batch_size=32, verbose=0)

if distribution:
actual_shardings = [
v.value.sharding for v in model.trainable_variables
]
self.assertListEqual(actual_shardings, expected_shardings)
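
For context, a hedged usage sketch (not part of the PR) of how the distributed, jit-compiled path that the new test covers is typically exercised from user code, via the public `keras.distribution` API. It assumes the JAX backend is selected (e.g. `KERAS_BACKEND=jax`):

```python
import jax
import numpy as np
import keras

# Build a 1-D data-parallel mesh over the local devices and activate it.
devices = jax.local_devices()
mesh = keras.distribution.DeviceMesh(
    shape=(len(devices),), axis_names=("batch",), devices=devices
)
keras.distribution.set_distribution(keras.distribution.DataParallel(mesh))

model = keras.Sequential(
    [keras.layers.Dense(4, activation="relu"), keras.layers.Dense(1)]
)
model.compile(optimizer="adam", loss="mse", jit_compile=True)

x = np.random.rand(64, 8).astype("float32")
y = np.random.rand(64, 1).astype("float32")
# With a distribution active, make_train_function() builds the jitted
# train step with out_shardings taken from the variables' shardings.
model.fit(x, y, epochs=1, batch_size=32, verbose=0)
```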