Merged
193 changes: 82 additions & 111 deletions keras/src/backend/jax/trainer.py
@@ -19,6 +19,7 @@
from keras.src.trainers.data_adapters import data_adapter_utils
from keras.src.trainers.epoch_iterator import EpochIterator
from keras.src.utils import traceback_utils
+ from jax.sharding import PartitionSpec, NamedSharding

if is_nnx_enabled():
from flax import nnx
@@ -153,7 +154,7 @@ def train_step(self, state, data):
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
)

- state = self._enforce_jax_state_sharding(
+ state = (
trainable_variables,
non_trainable_variables,
optimizer_variables,
@@ -185,17 +186,6 @@ def test_step(self, state, data):
metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
)

- (
- trainable_variables,
- non_trainable_variables,
- _,
- metrics_variables,
- ) = self._enforce_jax_state_sharding(
- trainable_variables=trainable_variables,
- non_trainable_variables=non_trainable_variables,
- optimizer_variables=None,
- metrics_variables=metrics_variables,
- )
state = (
trainable_variables,
non_trainable_variables,
@@ -213,17 +203,6 @@ def predict_step(self, state, data):
outputs, non_trainable_variables = self.stateless_call(
trainable_variables, non_trainable_variables, x, **kwargs
)
- (
- _,
- non_trainable_variables,
- _,
- _,
- ) = self._enforce_jax_state_sharding(
- trainable_variables=None,
- non_trainable_variables=non_trainable_variables,
- optimizer_variables=None,
- metrics_variables=None,
- )
return outputs, non_trainable_variables

def _make_function(self, step_function, concatenate_outputs=False):
@@ -281,11 +260,15 @@ def make_train_function(self, force=False):
if self.train_function is not None and not force:
return
if not self.run_eagerly and self.jit_compile:
- # Note that we mark the state to be donated to jax,
- # so that jax will reuse the memory buffer for outputs.
- # This will reduce the memory usage of the training function by
- # half.
- train_step = jit(self.train_step, donate_argnums=0)
+ out_shardings = None
+ if distribution_lib.distribution() is not None:
+ state_shardings = self._get_training_state_shardings()
+ out_shardings = (None, state_shardings)
+ train_step = jit(
+ self.train_step,
+ donate_argnums=0,
+ out_shardings=out_shardings,
+ )
else:
train_step = self.train_step
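
The change above replaces the in-step `jax.lax.with_sharding_constraint` approach with `out_shardings` passed straight to `jax.jit`, while keeping `donate_argnums=0` so the donated input state buffers are reused for the outputs. A minimal, self-contained sketch of the same pattern follows; the one-axis mesh, the toy step function, and every name in it are illustrative only, not Keras APIs.

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical one-axis mesh; on a single device everything is replicated.
mesh = Mesh(jax.devices(), axis_names=("model",))
replicated = NamedSharding(mesh, PartitionSpec())

def toy_step(state, data):
    # `state` is a tuple of arrays, mirroring the (logs, state) layout above.
    (w,) = state
    loss = jnp.mean((data @ w) ** 2)
    return loss, (w - 0.1 * jnp.sign(w),)

# donate_argnums=0 lets XLA reuse the donated state buffers for the outputs;
# out_shardings pins the layout of the returned state, so the next call sees
# inputs with the same sharding and the cached executable is reused.
jit_step = jax.jit(
    toy_step,
    donate_argnums=0,
    out_shardings=(None, (replicated,)),
)

state = (jax.device_put(jnp.ones((4, 1)), replicated),)
loss, state = jit_step(state, jnp.ones((8, 4)))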

@@ -297,12 +280,25 @@ def make_test_function(self, force=False):
if self.test_function is not None and not force:
return
if not self.run_eagerly and self.jit_compile:
- # Note that we mark the state to be donated to jax,
- # so that jax will reuse the memory buffer for outputs.
- # This will reduce the memory usage of the training function by
- # half.
- test_step = jit(self.test_step, donate_argnums=0)
-
+ out_shardings = None
+ if distribution_lib.distribution() is not None:
+ (
+ trainable_shardings,
+ non_trainable_shardings,
+ _, # optimizer_shardings
+ metrics_shardings,
+ ) = self._get_training_state_shardings()
+ state_shardings = (
+ trainable_shardings,
+ non_trainable_shardings,
+ metrics_shardings,
+ )
+ out_shardings = (None, state_shardings)
+ test_step = jit(
+ self.test_step,
+ donate_argnums=0,
+ out_shardings=out_shardings,
+ )
else:
test_step = self.test_step
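
One detail worth noting from the hunk above: the pytree passed as `out_shardings` has to mirror the step's output structure, which is why the optimizer entry is dropped here — `test_step` returns `(logs, (trainable, non_trainable, metrics))` with no optimizer slot. A self-contained sketch of that constraint, with hypothetical names and a dummy state:

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec

mesh = Mesh(jax.devices(), axis_names=("model",))
rep = NamedSharding(mesh, PartitionSpec())

# Pretend full training-state shardings, a 4-tuple like the one returned by
# _get_training_state_shardings(); the optimizer entry is deliberately unused.
trainable_s, non_trainable_s, optimizer_s, metrics_s = (rep,), (rep,), (rep,), (rep,)

def eval_step(state, data):
    trainable, non_trainable, metrics = state
    logs = jnp.sum(data)
    return logs, (trainable, non_trainable, metrics)

# The state part of out_shardings mirrors eval_step's output: three entries.
out_shardings = (None, (trainable_s, non_trainable_s, metrics_s))
jit_eval = jax.jit(eval_step, donate_argnums=0, out_shardings=out_shardings)

state = (
    (jnp.zeros((2,)),),  # trainable
    (jnp.zeros((2,)),),  # non-trainable
    (jnp.zeros(()),),    # metrics
)
logs, state = jit_eval(state, jnp.ones((3,)))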

@@ -319,7 +315,24 @@ def predict_step(state, data):
return outputs, (state[0], non_trainable_variables)

if not self.run_eagerly and self.jit_compile:
- predict_step = jit(predict_step, donate_argnums=0)
+ out_shardings = None
+ if distribution_lib.distribution() is not None:
+ (
+ trainable_shardings,
+ non_trainable_shardings,
+ _, # optimizer_shardings
+ _, # metrics_shardings
+ ) = self._get_training_state_shardings()
+ state_shardings = (
+ trainable_shardings,
+ non_trainable_shardings,
+ )
+ out_shardings = (None, state_shardings)
+ predict_step = jit(
+ predict_step,
+ donate_argnums=0,
+ out_shardings=out_shardings,
+ )

_step_function = self._make_function(
predict_step, concatenate_outputs=True
Expand Down Expand Up @@ -402,7 +415,6 @@ def fit(
steps=epoch_iterator.num_batches,
model=self,
)
- self._record_training_state_sharding_spec()

self.make_train_function()
self.stop_training = False
@@ -518,7 +530,6 @@ def fit(
if training_finished:
callbacks.on_train_end(logs=training_logs)
self._jax_state = None
- self._clear_jax_state_sharding()
return self.history

@traceback_utils.filter_traceback
@@ -568,7 +579,6 @@ def evaluate(
steps=epoch_iterator.num_batches,
model=self,
)
- self._record_training_state_sharding_spec()

self.make_test_function()
self.stop_evaluating = False
@@ -620,9 +630,6 @@ def evaluate(
logs = self._get_metrics_result_or_logs(logs)
callbacks.on_test_end(logs)
self._jax_state = None
- if not use_cached_eval_dataset:
- # Only clear sharding if evaluate is not called from `fit`.
- self._clear_jax_state_sharding()
if return_dict:
return logs
return self._flatten_metrics_in_order(logs)
@@ -664,7 +671,6 @@ def predict(
steps=epoch_iterator.num_batches,
model=self,
)
- self._record_training_state_sharding_spec()

self.make_predict_function()
self.stop_predicting = False
@@ -723,7 +729,6 @@ def append_to_outputs(batch_outputs, outputs):
self.jax_state_sync()
callbacks.on_predict_end()
self._jax_state = None
- self._clear_jax_state_sharding()
return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)

def train_on_batch(
@@ -752,7 +757,6 @@ def data():

# Maybe build model
self._symbolic_build(data_batch=next(data()))
- self._record_training_state_sharding_spec()
self.make_train_function()

# Train step
@@ -801,7 +805,6 @@ def data():

# Maybe build model
self._symbolic_build(data_batch=next(data()))
- self._record_training_state_sharding_spec()
self.make_test_function()

# Test step
@@ -834,7 +837,6 @@ def predict_on_batch(self, x):
# Build model
with backend.StatelessScope():
self(x)
- self._record_training_state_sharding_spec()
self.make_predict_function()

state = self._get_jax_state(
@@ -884,75 +886,44 @@ def jax_state_sync(self):
ref_v.assign(v)
self._jax_state_synced = True

- def _record_training_state_sharding_spec(self):
- self._trainable_variable_shardings = [
- v.value.sharding for v in self.trainable_variables
- ]
- self._non_trainable_variable_shardings = [
- v.value.sharding for v in self.non_trainable_variables
- ]
- if hasattr(self, "optimizer") and self.optimizer is not None:
- self._optimizer_variable_shardings = [
- v.value.sharding for v in self.optimizer.variables
- ]
- else:
- self._optimizer_variable_shardings = []
- self._metrics_variable_shardings = [
- v.value.sharding for v in self.metrics_variables
- ]
-
- def _clear_jax_state_sharding(self):
- self._trainable_variable_shardings = None
- self._non_trainable_variable_shardings = None
- self._optimizer_variable_shardings = None
- self._metrics_variable_shardings = None
-
- def _enforce_jax_state_sharding(
- self,
- trainable_variables=None,
- non_trainable_variables=None,
- optimizer_variables=None,
- metrics_variables=None,
- ):
- """Enforce the sharding spec constraint for all the training state.
-
- Since the output of the train/eval step will be used as inputs to next
- step, we need to ensure that they have the same sharding spec, so that
- nnx.jit/jax.jit won't have to recompile the train/eval function.
-
- Note that this function will also rely on the recorded sharding spec
- for each of states.
+ def _get_training_state_shardings(self):
+ distribution = distribution_lib.distribution()
+ mesh = distribution.device_mesh.backend_mesh
+
+ def get_sharding(variable):
+ partition_spec = PartitionSpec()
+ if hasattr(distribution, "layout_map"):
+ tensor_layout = distribution.layout_map.get(variable.path)
+ if tensor_layout is not None:
+ partition_spec = PartitionSpec(*tensor_layout.axes)
+ return NamedSharding(mesh, partition_spec)
+
+ trainable_shardings = tuple(
+ get_sharding(v) for v in self.trainable_variables
+ )
+ non_trainable_shardings = tuple(
+ get_sharding(v) for v in self.non_trainable_variables
+ )

- This function is expected to be called within the jitted train/eval
- function, especially around the end of the function.
- """
- trainable_variables = trainable_variables or []
- non_trainable_variables = non_trainable_variables or []
- optimizer_variables = optimizer_variables or []
- metrics_variables = metrics_variables or []
-
- for i in range(len(trainable_variables)):
- trainable_variables[i] = jax.lax.with_sharding_constraint(
- trainable_variables[i], self._trainable_variable_shardings[i]
- )
- for i in range(len(non_trainable_variables)):
- non_trainable_variables[i] = jax.lax.with_sharding_constraint(
- non_trainable_variables[i],
- self._non_trainable_variable_shardings[i],
+ if hasattr(self, "optimizer") and self.optimizer:
+ optimizer_shardings = tuple(
+ get_sharding(v) for v in self.optimizer.variables
)
- for i in range(len(optimizer_variables)):
- optimizer_variables[i] = jax.lax.with_sharding_constraint(
- optimizer_variables[i], self._optimizer_variable_shardings[i]
- )
- for i in range(len(metrics_variables)):
- metrics_variables[i] = jax.lax.with_sharding_constraint(
- metrics_variables[i], self._metrics_variable_shardings[i]
+ else:
+ optimizer_shardings = ()
+
+ if hasattr(self, "metrics_variables"):
+ metrics_shardings = tuple(
+ get_sharding(v) for v in self.metrics_variables
)
+ else:
+ metrics_shardings = ()

return (
- trainable_variables,
- non_trainable_variables,
- optimizer_variables,
- metrics_variables,
+ trainable_shardings,
+ non_trainable_shardings,
+ optimizer_shardings,
+ metrics_shardings,
)

def _purge_model_variables(
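
For reference, the new `_get_training_state_shardings` above derives one `NamedSharding` per variable up front from the distribution's layout map, instead of re-applying `jax.lax.with_sharding_constraint` inside the jitted step bodies as the removed `_enforce_jax_state_sharding` did. A standalone sketch of that lookup follows; the mesh axis name, the variable paths, and the plain dict standing in for Keras' `LayoutMap` (which matches paths against its keys rather than requiring exact entries) are all made up for illustration.

import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Hypothetical mesh and layout map.
mesh = Mesh(jax.devices(), axis_names=("model",))
layout_map = {"dense/kernel": ("model", None)}  # variable path -> sharded axes

def get_sharding(path):
    axes = layout_map.get(path)
    # Variables without a layout entry fall back to an empty PartitionSpec,
    # i.e. fully replicated across the mesh.
    spec = PartitionSpec(*axes) if axes is not None else PartitionSpec()
    return NamedSharding(mesh, spec)

print(get_sharding("dense/kernel"))  # partitioned over the "model" axis
print(get_sharding("dense/bias"))    # replicated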