147 changes: 47 additions & 100 deletions keras/src/backend/jax/trainer.py
@@ -153,7 +153,7 @@ def train_step(self, state, data):
             metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
         )
 
-        state = self._enforce_jax_state_sharding(
+        state = (
             trainable_variables,
             non_trainable_variables,
             optimizer_variables,
@@ -185,17 +185,6 @@ def test_step(self, state, data):
             metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
         )
 
-        (
-            trainable_variables,
-            non_trainable_variables,
-            _,
-            metrics_variables,
-        ) = self._enforce_jax_state_sharding(
-            trainable_variables=trainable_variables,
-            non_trainable_variables=non_trainable_variables,
-            optimizer_variables=None,
-            metrics_variables=metrics_variables,
-        )
         state = (
             trainable_variables,
             non_trainable_variables,
@@ -213,17 +202,6 @@ def predict_step(self, state, data):
         outputs, non_trainable_variables = self.stateless_call(
             trainable_variables, non_trainable_variables, x, **kwargs
         )
-        (
-            _,
-            non_trainable_variables,
-            _,
-            _,
-        ) = self._enforce_jax_state_sharding(
-            trainable_variables=None,
-            non_trainable_variables=non_trainable_variables,
-            optimizer_variables=None,
-            metrics_variables=None,
-        )
         return outputs, non_trainable_variables
 
     def _make_function(self, step_function, concatenate_outputs=False):
@@ -281,11 +259,15 @@ def make_train_function(self, force=False):
         if self.train_function is not None and not force:
             return
         if not self.run_eagerly and self.jit_compile:
-            # Note that we mark the state to be donated to jax,
-            # so that jax will reuse the memory buffer for outputs.
-            # This will reduce the memory usage of the training function by
-            # half.
-            train_step = jit(self.train_step, donate_argnums=0)
+            out_shardings = None
+            if distribution_lib.distribution() is not None:
+                state_shardings = self._record_training_state_sharding_spec()
+                out_shardings = (None, state_shardings)
+            train_step = jit(
+                self.train_step,
+                donate_argnums=0,
+                out_shardings=out_shardings,
+            )
         else:
             train_step = self.train_step
 
@@ -297,12 +279,25 @@ def make_test_function(self, force=False):
         if self.test_function is not None and not force:
             return
         if not self.run_eagerly and self.jit_compile:
-            # Note that we mark the state to be donated to jax,
-            # so that jax will reuse the memory buffer for outputs.
-            # This will reduce the memory usage of the training function by
-            # half.
-            test_step = jit(self.test_step, donate_argnums=0)
-
+            out_shardings = None
+            if distribution_lib.distribution() is not None:
+                (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                    _,  # optimizer_shardings
+                    metrics_shardings,
+                ) = self._record_training_state_sharding_spec()
+                state_shardings = (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                    metrics_shardings,
+                )
+                out_shardings = (None, state_shardings)
+            test_step = jit(
+                self.test_step,
+                donate_argnums=0,
+                out_shardings=out_shardings,
+            )
         else:
             test_step = self.test_step
 
@@ -319,7 +314,24 @@ def predict_step(state, data):
             return outputs, (state[0], non_trainable_variables)
 
         if not self.run_eagerly and self.jit_compile:
-            predict_step = jit(predict_step, donate_argnums=0)
+            out_shardings = None
+            if distribution_lib.distribution() is not None:
+                (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                    _,  # optimizer_shardings
+                    _,  # metrics_shardings
+                ) = self._record_training_state_sharding_spec()
+                state_shardings = (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                )
+                out_shardings = (None, state_shardings)
+            predict_step = jit(
+                predict_step,
+                donate_argnums=0,
+                out_shardings=out_shardings,
+            )
 
         _step_function = self._make_function(
             predict_step, concatenate_outputs=True
@@ -402,7 +414,6 @@ def fit(
                 steps=epoch_iterator.num_batches,
                 model=self,
             )
-        self._record_training_state_sharding_spec()
 
         self.make_train_function()
         self.stop_training = False
@@ -518,7 +529,6 @@ def fit(
         if training_finished:
             callbacks.on_train_end(logs=training_logs)
         self._jax_state = None
-        self._clear_jax_state_sharding()
         return self.history
 
     @traceback_utils.filter_traceback
@@ -568,7 +578,6 @@ def evaluate(
                 steps=epoch_iterator.num_batches,
                 model=self,
             )
-        self._record_training_state_sharding_spec()
 
         self.make_test_function()
         self.stop_evaluating = False
@@ -620,9 +629,6 @@ def evaluate(
         logs = self._get_metrics_result_or_logs(logs)
         callbacks.on_test_end(logs)
         self._jax_state = None
-        if not use_cached_eval_dataset:
-            # Only clear sharding if evaluate is not called from `fit`.
-            self._clear_jax_state_sharding()
         if return_dict:
             return logs
         return self._flatten_metrics_in_order(logs)
@@ -664,7 +670,6 @@ def predict(
                 steps=epoch_iterator.num_batches,
                 model=self,
             )
-        self._record_training_state_sharding_spec()
 
         self.make_predict_function()
         self.stop_predicting = False
@@ -723,7 +728,6 @@ def append_to_outputs(batch_outputs, outputs):
         self.jax_state_sync()
         callbacks.on_predict_end()
         self._jax_state = None
-        self._clear_jax_state_sharding()
         return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
 
     def train_on_batch(
@@ -752,7 +756,6 @@ def data():
 
         # Maybe build model
         self._symbolic_build(data_batch=next(data()))
-        self._record_training_state_sharding_spec()
         self.make_train_function()
 
         # Train step
@@ -801,7 +804,6 @@ def data():
 
         # Maybe build model
         self._symbolic_build(data_batch=next(data()))
-        self._record_training_state_sharding_spec()
         self.make_test_function()
 
         # Test step
@@ -834,7 +836,6 @@ def predict_on_batch(self, x):
             # Build model
             with backend.StatelessScope():
                 self(x)
-        self._record_training_state_sharding_spec()
         self.make_predict_function()
 
         state = self._get_jax_state(
@@ -901,60 +902,6 @@ def _record_training_state_sharding_spec(self):
                 v.value.sharding for v in self.metrics_variables
             ]
 
-    def _clear_jax_state_sharding(self):
-        self._trainable_variable_shardings = None
-        self._non_trainable_variable_shardings = None
-        self._optimizer_variable_shardings = None
-        self._metrics_variable_shardings = None
-
-    def _enforce_jax_state_sharding(
-        self,
-        trainable_variables=None,
-        non_trainable_variables=None,
-        optimizer_variables=None,
-        metrics_variables=None,
-    ):
-        """Enforce the sharding spec constraint for all the training state.
-
-        Since the output of the train/eval step will be used as inputs to next
-        step, we need to ensure that they have the same sharding spec, so that
-        nnx.jit/jax.jit won't have to recompile the train/eval function.
-
-        Note that this function will also rely on the recorded sharding spec
-        for each of states.
-
-        This function is expected to be called within the jitted train/eval
-        function, especially around the end of the function.
-        """
-        trainable_variables = trainable_variables or []
-        non_trainable_variables = non_trainable_variables or []
-        optimizer_variables = optimizer_variables or []
-        metrics_variables = metrics_variables or []
-
-        for i in range(len(trainable_variables)):
-            trainable_variables[i] = jax.lax.with_sharding_constraint(
-                trainable_variables[i], self._trainable_variable_shardings[i]
-            )
-        for i in range(len(non_trainable_variables)):
-            non_trainable_variables[i] = jax.lax.with_sharding_constraint(
-                non_trainable_variables[i],
-                self._non_trainable_variable_shardings[i],
-            )
-        for i in range(len(optimizer_variables)):
-            optimizer_variables[i] = jax.lax.with_sharding_constraint(
-                optimizer_variables[i], self._optimizer_variable_shardings[i]
-            )
-        for i in range(len(metrics_variables)):
-            metrics_variables[i] = jax.lax.with_sharding_constraint(
-                metrics_variables[i], self._metrics_variable_shardings[i]
-            )
-        return (
-            trainable_variables,
-            non_trainable_variables,
-            optimizer_variables,
-            metrics_variables,
-        )
-
     def _purge_model_variables(
         self,
         trainable_variables=False,
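Context for the change above: instead of re-applying `jax.lax.with_sharding_constraint` to every variable at the end of each jitted step (the deleted `_enforce_jax_state_sharding`), the trainer now tells `jax.jit` up front what sharding the returned state must have, via `out_shardings`, while still donating the incoming state buffers with `donate_argnums=0`. The sketch below is not Keras code; the toy `train_step`, the one-axis mesh, and the fully replicated sharding are illustrative assumptions, but the `jit(..., donate_argnums=0, out_shardings=(None, state_shardings))` pattern mirrors the one used in the diff.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Toy replicated "state": a single weight vector on a one-axis device mesh.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
replicated = NamedSharding(mesh, P())


def train_step(state, data):
    (w,) = state
    grads = jnp.mean(data, axis=0)  # stand-in for a real gradient
    new_w = w - 0.1 * grads
    loss = jnp.sum(new_w**2)
    return loss, (new_w,)


# Donate the incoming state buffers and pin the returned state to the
# shardings the variables started with, so the output of one step can feed
# the next step without resharding or retracing.
jitted_step = jax.jit(
    train_step,
    donate_argnums=0,
    out_shardings=(None, (replicated,)),  # (loss, state): only state is pinned
)

w0 = jax.device_put(jnp.ones((4,)), replicated)
loss, (w1,) = jitted_step((w0,), jnp.ones((8, 4)))
```

Because the output shardings are fixed at trace time, the removed per-variable `with_sharding_constraint` loop is no longer needed inside the step.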
71 changes: 71 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -1,5 +1,7 @@
+import types
 from unittest import mock
 
+import jax
 import numpy as np
 import pytest
 from absl.testing import parameterized
@@ -16,13 +18,16 @@
 from keras.src import testing
 from keras.src.backend import config
 from keras.src.backend.common.symbolic_scope import in_symbolic_scope
+from keras.src.backend.jax import trainer as jax_trainer_lib
 from keras.src.callbacks.callback import Callback
 from keras.src.optimizers.rmsprop import RMSprop
 from keras.src.testing.test_utils import named_product
 from keras.src.trainers.data_adapters import py_dataset_adapter
 
 if backend.backend() == "jax":
     from keras.src.backend.jax.trainer import JAXTrainer as Trainer
+    from keras.src.distribution import DataParallel
+    from keras.src.distribution import DeviceMesh
 elif backend.backend() == "torch":
     from keras.src.backend.torch.trainer import TorchTrainer as Trainer
 elif backend.backend() == "tensorflow":
@@ -2857,3 +2862,69 @@ def predict_step(self, *args):
             verbose=0,
         )
         self.assertLessEqual(tracing_count[0], 2)
+
+    @parameterized.named_parameters(
+        ("single_device", False),
+        ("distributed", True),
+    )
+    def test_jit_fit_with_out_shardings_logic(self, distributed):
+        def patched_get_jax_state(
+            self_model,
+            trainable_variables=False,
+            non_trainable_variables=False,
+            optimizer_variables=False,
+            metrics_variables=False,
+            purge_model_variables=False,
+        ):
+            state = []
+            if trainable_variables:
+                state.append([v.value for v in self_model.trainable_variables])
+            if non_trainable_variables:
+                state.append(
+                    [v.value for v in self_model.non_trainable_variables]
+                )
+            if optimizer_variables:
+                state.append([v.value for v in self_model.optimizer.variables])
+            if metrics_variables:
+                state.append([v.value for v in self_model.metrics_variables])
+            return tuple(state)
+
+        original_train_step = jax_trainer_lib.JAXTrainer.train_step
+
+        def patched_train_step_wrapper(self, state, data):
+            logs, new_state = original_train_step(self, state, data)
+            return logs, new_state
+
+        x = np.random.rand(64, 8).astype("float32")
+        y = np.random.rand(64, 1).astype("float32")
+
+        if distributed:
+            if len(jax.local_devices()) < 2:
+                self.skipTest(
+                    "Distributed test requires at least 2 JAX devices."
+                )
+
+            devices = jax.local_devices()
+            mesh = DeviceMesh(
+                shape=(len(devices),), axis_names=("batch",), devices=devices
+            )
+            distribution = DataParallel(mesh)
+
+            with (
+                mock.patch(
+                    "keras.src.backend.jax.trainer.JAXTrainer.train_step",
+                    new=patched_train_step_wrapper,
+                ),
+                distribution.scope(),
+            ):
+                model = models.Sequential(
+                    [
+                        layers.Dense(4, activation="relu", input_shape=(8,)),
+                        layers.Dense(1),
+                    ]
+                )
+                model._get_jax_state = types.MethodType(
+                    patched_get_jax_state, model
+                )
+                model.compile(optimizer="adam", loss="mse", jit_compile=True)
+                model.fit(x, y, epochs=2, batch_size=32, verbose=0)
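For reference, a hypothetical end-to-end way to exercise the same distributed path outside the test suite, using the public `keras.distribution` API on the JAX backend (`KERAS_BACKEND=jax`); the model, data, and hyperparameters here are arbitrary placeholders:

```python
import jax
import numpy as np

import keras
from keras.distribution import DataParallel, DeviceMesh, set_distribution

# Shard the batch dimension across all local devices; variables stay
# replicated, so their shardings feed straight into the jit out_shardings.
devices = jax.devices()
mesh = DeviceMesh(
    shape=(len(devices),), axis_names=("batch",), devices=devices
)
set_distribution(DataParallel(device_mesh=mesh))

model = keras.Sequential(
    [keras.layers.Dense(4, activation="relu"), keras.layers.Dense(1)]
)
model.compile(optimizer="adam", loss="mse", jit_compile=True)
model.fit(
    np.random.rand(64, 8).astype("float32"),
    np.random.rand(64, 1).astype("float32"),
    epochs=1,
    batch_size=32,
    verbose=0,
)
```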