141 changes: 52 additions & 89 deletions keras/src/backend/jax/trainer.py
@@ -153,7 +153,7 @@ def train_step(self, state, data):
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
)

state = self._enforce_jax_state_sharding(
state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
@@ -185,17 +185,6 @@ def test_step(self, state, data):
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
)

(
trainable_variables,
non_trainable_variables,
_,
metrics_variables,
) = self._enforce_jax_state_sharding(
trainable_variables=trainable_variables,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=metrics_variables,
)
state = (
trainable_variables,
non_trainable_variables,
@@ -213,17 +202,6 @@ def predict_step(self, state, data):
outputs, non_trainable_variables = self.stateless_call(
trainable_variables, non_trainable_variables, x, **kwargs
)
(
_,
non_trainable_variables,
_,
_,
) = self._enforce_jax_state_sharding(
trainable_variables=None,
non_trainable_variables=non_trainable_variables,
optimizer_variables=None,
metrics_variables=None,
)
return outputs, non_trainable_variables

def _make_function(self, step_function, concatenate_outputs=False):
@@ -277,38 +255,57 @@ def iterator_step(state, iterator):

return iterator_step

def _get_out_shardings_for_step(self, state_shardings):
"""Helper to create the out_shardings for a jitted step function."""
if distribution_lib.distribution() is not None and hasattr(
self, "_metrics_result_structure"
):
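# The step functions return (logs, state): leave the logs unspecified
# (a None entry lets the partitioner choose a placement) and pin the state
# to its recorded shardings so the outputs match the next step's inputs.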
return (
tree.map_structure(lambda _: None, self._metrics_result_structure),
state_shardings,
)
return None

def make_train_function(self, force=False):
if self.train_function is not None and not force:
return

state_shardings = (
self._trainable_variable_shardings,
self._non_trainable_variable_shardings,
self._optimizer_variable_shardings,
self._metrics_variable_shardings,
)
out_shardings = self._get_out_shardings_for_step(state_shardings)

if not self.run_eagerly and self.jit_compile:
# Note that we mark the state to be donated to jax,
# so that jax will reuse the memory buffer for outputs.
# This will reduce the memory usage of the training function by
# half.
train_step = jit(self.train_step, donate_argnums=0)
train_step = jit(
self.train_step, donate_argnums=0, out_shardings=out_shardings
)
else:
train_step = self.train_step

step_function = self._make_function(train_step)

self.train_function = step_function
self.train_function = self._make_function(train_step)
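
For context, a minimal standalone sketch of the two jax.jit features used here (illustrative names only, not the Keras implementation): donate_argnums=0 lets XLA reuse the donated state buffers for the outputs, and out_shardings pins the placement of the outputs so the next call sees identically sharded inputs and the jitted step does not need to recompile.

import functools

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
replicated = NamedSharding(mesh, P())  # keep the toy state replicated

@functools.partial(
    jax.jit, donate_argnums=0, out_shardings=(None, replicated)
)
def step(state, x):
    loss = x.mean()              # None: placement left to the partitioner
    new_state = state + x.sum()  # may reuse the donated state buffer
    return loss, new_state

state = jnp.zeros(())
loss, state = step(state, jnp.ones((8, 4)))  # state keeps the same sharding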

def make_test_function(self, force=False):
if self.test_function is not None and not force:
return
if not self.run_eagerly and self.jit_compile:
# Note that we mark the state to be donated to jax,
# so that jax will reuse the memory buffer for outputs.
# This will reduce the memory usage of the training function by
# half.
test_step = jit(self.test_step, donate_argnums=0)

state_shardings = (
self._trainable_variable_shardings,
self._non_trainable_variable_shardings,
self._metrics_variable_shardings,
)
out_shardings = self._get_out_shardings_for_step(state_shardings)

if not self.run_eagerly and self.jit_compile:
test_step = jit(
self.test_step, donate_argnums=0, out_shardings=out_shardings
)
else:
test_step = self.test_step

step_function = self._make_function(test_step)

self.test_function = step_function
self.test_function = self._make_function(test_step)

def make_predict_function(self, force=False):
if self.predict_function is not None and not force:
Expand All @@ -319,7 +316,21 @@ def predict_step(state, data):
return outputs, (state[0], non_trainable_variables)

if not self.run_eagerly and self.jit_compile:
predict_step = jit(predict_step, donate_argnums=0)
out_shardings = None
# Only use out_shardings when a distribution is active and the model has
# been built, i.e. the sharding attributes have been recorded.
if distribution_lib.distribution() is not None and hasattr(
self, "_trainable_variable_shardings"
):
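# Mirrors predict_step's (outputs, (trainable, non_trainable)) return:
# outputs are left to the partitioner, model state keeps its shardings.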
out_shardings = (
None,
(
self._trainable_variable_shardings,
self._non_trainable_variable_shardings,
),
)
predict_step = jit(
predict_step, donate_argnums=0, out_shardings=out_shardings
)

_step_function = self._make_function(
predict_step, concatenate_outputs=True
@@ -907,54 +918,6 @@ def _clear_jax_state_sharding(self):
self._optimizer_variable_shardings = None
self._metrics_variable_shardings = None

def _enforce_jax_state_sharding(
self,
trainable_variables=None,
non_trainable_variables=None,
optimizer_variables=None,
metrics_variables=None,
):
"""Enforce the sharding spec constraint for all the training state.

Since the output of the train/eval step will be used as inputs to next
step, we need to ensure that they have the same sharding spec, so that
nnx.jit/jax.jit won't have to recompile the train/eval function.

Note that this function will also rely on the recorded sharding spec
for each of states.

This function is expected to be called within the jitted train/eval
function, especially around the end of the function.
"""
trainable_variables = trainable_variables or []
non_trainable_variables = non_trainable_variables or []
optimizer_variables = optimizer_variables or []
metrics_variables = metrics_variables or []

for i in range(len(trainable_variables)):
trainable_variables[i] = jax.lax.with_sharding_constraint(
trainable_variables[i], self._trainable_variable_shardings[i]
)
for i in range(len(non_trainable_variables)):
non_trainable_variables[i] = jax.lax.with_sharding_constraint(
non_trainable_variables[i],
self._non_trainable_variable_shardings[i],
)
for i in range(len(optimizer_variables)):
optimizer_variables[i] = jax.lax.with_sharding_constraint(
optimizer_variables[i], self._optimizer_variable_shardings[i]
)
for i in range(len(metrics_variables)):
metrics_variables[i] = jax.lax.with_sharding_constraint(
metrics_variables[i], self._metrics_variable_shardings[i]
)
return (
trainable_variables,
non_trainable_variables,
optimizer_variables,
metrics_variables,
)
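
In short, the constraint the removed helper applied inside the traced step is now declared once at the jit boundary. A rough equivalence, simplified, where recorded_sharding, recorded_shardings, and step_fn are placeholders:

# Before: re-pin each output leaf inside the traced step.
out = jax.lax.with_sharding_constraint(out, recorded_sharding)
# After: declare the same placement once when building the jitted step.
step = jax.jit(step_fn, donate_argnums=0, out_shardings=recorded_shardings)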

def _purge_model_variables(
self,
trainable_variables=False,
72 changes: 72 additions & 0 deletions keras/src/trainers/trainer_test.py
@@ -23,6 +23,7 @@

if backend.backend() == "jax":
from keras.src.backend.jax.trainer import JAXTrainer as Trainer
from keras.src.distribution import DeviceMesh, TensorLayout, distribution_lib
elif backend.backend() == "torch":
from keras.src.backend.torch.trainer import TorchTrainer as Trainer
elif backend.backend() == "tensorflow":
@@ -2857,3 +2858,74 @@ def predict_step(self, *args):
verbose=0,
)
self.assertLessEqual(tracing_count[0], 2)

class ExampleModelForJAXTrainerShardingTest(models.Model):
def __init__(self, units=3, **kwargs):
super().__init__(**kwargs)
self.dense1 = layers.Dense(4, activation="relu", kernel_initializer="ones")
self.dense2 = layers.Dense(units, activation="softmax", kernel_initializer="ones")

def call(self, x):
return self.dense2(self.dense1(x))

@pytest.mark.skipif(
backend.backend() != "jax",
reason="This is a JAX-specific distribution test.",
)
class JAXTrainerShardingTest(testing.TestCase):

def setUp(self):
super().setUp()
import jax

if jax.device_count() < 2:
self.skipTest(
"Cannot test sharding with less than 2 devices. "
f"Found {jax.device_count()} devices."
)

devices = np.array(jax.devices())
device_mesh = DeviceMesh(
shape=(jax.device_count(),),
axis_names=("batch",),
devices=devices.flatten(),
)
data_layout_2d = TensorLayout(axes=("batch", None), device_mesh=device_mesh)
data_layout_1d = TensorLayout(axes=("batch",), device_mesh=device_mesh)
variable_layout = TensorLayout(axes=(None, None), device_mesh=device_mesh)

def get_layout_for_data(shape):
if not hasattr(shape, "__len__"):
return variable_layout
if len(shape) == 2:
return data_layout_2d
elif len(shape) == 1:
return data_layout_1d
return variable_layout

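# Stand-in for a real Distribution: only the pieces the JAX trainer touches
# (get_data_layout, get_tensor_layout, auto_shard_dataset) are configured,
# with deterministic layouts over the 1-D "batch" mesh.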
mock_dist = mock.MagicMock()
mock_dist.get_data_layout.side_effect = get_layout_for_data
mock_dist.get_tensor_layout.return_value = variable_layout
mock_dist.auto_shard_dataset = False
self.distribution_mock = mock_dist

@pytest.mark.requires_trainable_backend
def test_fit_with_sharding(self):
with mock.patch.object(
distribution_lib, "distribution", return_value=self.distribution_mock
):
model = ExampleModelForJAXTrainerShardingTest(units=3)
model.compile(
optimizer="sgd",
loss="mse",
jit_compile=True
)

x = np.ones((16, 5), dtype="float32")
y = np.zeros((16, 3), dtype="float32")
sw = np.ones((16,), dtype="float32")

history = model.fit(x, y, sample_weight=sw, batch_size=4, epochs=2)

self.assertIn("loss", history.history)
self.assertEqual(len(history.history["loss"]), 2)
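
For reference, the non-mocked path this test emulates looks roughly like the sketch below. It assumes the JAX backend; keras.distribution.DataParallel and set_distribution are the public distribution API, while the model and shapes are illustrative:

import numpy as np

import keras
from keras import distribution

# Shard the batch dimension across all available JAX devices.
distribution.set_distribution(distribution.DataParallel())

model = keras.Sequential(
    [keras.layers.Dense(4, activation="relu"), keras.layers.Dense(3)]
)
model.compile(optimizer="sgd", loss="mse", jit_compile=True)

# With a distribution active and jit_compile=True, the trainer now routes
# the recorded state shardings through jax.jit's out_shardings.
model.fit(np.ones((16, 5)), np.zeros((16, 3)), batch_size=4, epochs=2)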