
Commit 817619e

fix for trainer bugs (#19584)
1 parent 98fdd0e · commit 817619e

File tree

6 files changed: +165 -70 lines changed


keras/src/backend/mlx/trainer.py

Lines changed: 135 additions & 67 deletions
@@ -10,6 +10,7 @@
 from keras.src.backend.common.keras_tensor import KerasTensor
 from keras.src.backend.mlx.core import is_tensor
 from keras.src.trainers import trainer as base_trainer
+from keras.src.trainers.data_adapters import array_slicing
 from keras.src.trainers.data_adapters import data_adapter_utils
 from keras.src.trainers.epoch_iterator import EpochIterator
 from keras.src.utils import traceback_utils
@@ -21,6 +22,7 @@ def __init__(self):
         self.train_function = None
         self.test_function = None
         self.predict_function = None
+        self._mlx_state_synced = True

     def _data_to_mlx(self, data):
         def _transform(x):
@@ -55,13 +57,40 @@ def mlx_state_sync(self):
         if metrics_variables:
             for ref_v, v in zip(self.metrics_variables, metrics_variables):
                 ref_v.assign(v)
+        self._mlx_state_synced = True
+
+    def _get_mlx_state(
+        self,
+        trainable_variables=False,
+        non_trainable_variables=False,
+        optimizer_variables=False,
+        metrics_variables=False,
+        purge_model_variables=False,
+    ):
+        state = []
+        if trainable_variables:
+            state.append([v.value for v in self.trainable_variables])
+        if non_trainable_variables:
+            state.append([v.value for v in self.non_trainable_variables])
+        if optimizer_variables:
+            state.append([v.value for v in self.optimizer.variables])
+        if metrics_variables:
+            state.append([v.value for v in self.metrics_variables])
+        if purge_model_variables:
+            self._purge_model_variables(
+                trainable_variables=trainable_variables,
+                non_trainable_variables=non_trainable_variables,
+                optimizer_variables=optimizer_variables,
+                metric_variables=metrics_variables,
+            )
+        return tuple(state)

     def _purge_model_variables(
         self,
-        trainable_variables=True,
-        non_trainable_variables=True,
-        optimizer_variables=True,
-        metric_variables=True,
+        trainable_variables=False,
+        non_trainable_variables=False,
+        optimizer_variables=False,
+        metric_variables=False,
     ):
         """Remove all the model variables so they can be garbage collected and
         the memory reclaimed by MLX.
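
Note: the new `_get_mlx_state` helper centralizes the detach step that `fit()`, `predict()`, and the `*_on_batch` methods previously open-coded: it collects the raw `mx.array` values of the requested variable groups and can purge the backing `Variable` objects so MLX can reclaim their memory while the functional step functions run. The following is a self-contained toy sketch of that detach / purge / sync-back cycle, not code from this commit; the class and variable names are hypothetical.

# Toy sketch (not Keras code) of the detach / purge / sync-back cycle.
class Variable:
    def __init__(self, value):
        self.value = value

    def assign(self, value):
        self.value = value


class TinyModel:
    def __init__(self):
        self.trainable_variables = [Variable(0.0), Variable(1.0)]

    def get_state(self, purge_model_variables=False):
        # Detach the raw values, like _get_mlx_state().
        state = [v.value for v in self.trainable_variables]
        if purge_model_variables:
            for v in self.trainable_variables:
                v.value = None  # drop the buffer so it can be reclaimed
        return state

    def state_sync(self, state):
        # Write detached values back, like mlx_state_sync().
        for ref_v, value in zip(self.trainable_variables, state):
            ref_v.assign(value)


model = TinyModel()
state = model.get_state(purge_model_variables=True)
state = [value + 0.5 for value in state]  # stands in for functional train steps
model.state_sync(state)
print([v.value for v in model.trainable_variables])  # [0.5, 1.5]
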
@@ -117,6 +146,7 @@ def compute_loss_and_updates(
         self,
         trainable_variables,
         non_trainable_variables,
+        metrics_variables,
         x,
         y,
         sample_weight,
@@ -135,22 +165,39 @@ def compute_loss_and_updates(
             return_losses=True,
             **kwargs,
         )
+        if losses:
+            # Make forward pass losses available to compute_loss.
+            self._losses_override.clear()
+            self._losses_override = losses

-        trainable_mapping = zip(self.trainable_variables, trainable_variables)
-        with backend.StatelessScope(state_mapping=trainable_mapping):
-            # Note that this is needed for the regularization loss, which need
-            # the latest value of train/non-trainable variables.
-            loss = self.compute_loss(x, y, y_pred, sample_weight)
+        loss, variables = self.stateless_compute_loss(
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+        )
         if losses:
-            loss += ops.sum(losses)
+            self._losses_override.clear()
+        (trainable_variables, non_trainable_variables, metrics_variables) = (
+            variables
+        )
         unscaled_loss = loss
         if training and self.optimizer is not None:
             # Scale loss with a StatelessScope, to use an update scale variable.
             mapping = list(zip(self.optimizer.variables, optimizer_variables))
             with backend.StatelessScope(state_mapping=mapping):
                 loss = self.optimizer.scale_loss(loss)

-        return loss, unscaled_loss, y_pred, non_trainable_variables
+        return (
+            loss,
+            unscaled_loss,
+            y_pred,
+            non_trainable_variables,
+            metrics_variables,
+        )

     def train_step(self, state, data):
         data = self._data_to_mlx(data)
@@ -169,9 +216,11 @@ def train_step(self, state, data):
             unscaled_loss,
             y_pred,
             non_trainable_variables,
+            metrics_variables,
         ), grads = grad_fn(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
             sample_weight,
@@ -191,9 +240,11 @@ def train_step(self, state, data):
             unscaled_loss,
             y_pred,
             non_trainable_variables,
+            metrics_variables,
         ) = self.compute_loss_and_updates(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
             sample_weight,
@@ -239,9 +290,11 @@ def test_step(self, state, data):
             unscaled_loss,
             y_pred,
             non_trainable_variables,
+            metrics_variables,
         ) = self.compute_loss_and_updates(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
             sample_weight,
@@ -443,7 +496,7 @@ def fit(
                 x,
                 y,
                 sample_weight,
-            ), validation_data = data_adapter_utils.train_validation_split(
+            ), validation_data = array_slicing.train_validation_split(
                 (x, y, sample_weight), validation_split=validation_split
             )

@@ -496,30 +549,27 @@ def fit(
         self.stop_training = False
         self.make_train_function()
         callbacks.on_train_begin()
-
+        initial_epoch = self._initial_epoch or initial_epoch
         for epoch in range(initial_epoch, epochs):
             self.reset_metrics()
             callbacks.on_epoch_begin(epoch)

-            trainable_variables = [v.value for v in self.trainable_variables]
-            non_trainable_variables = [
-                v.value for v in self.non_trainable_variables
-            ]
-            optimizer_variables = [v.value for v in self.optimizer.variables]
-            metrics_variables = [v.value for v in self.metrics_variables]
-
-            self._purge_model_variables()
+            self._mlx_state_synced = True
             for step, data in epoch_iterator.enumerate_epoch():
                 # Callbacks
                 callbacks.on_train_batch_begin(step)
-
                 # Train step
-                state = (
-                    trainable_variables,
-                    non_trainable_variables,
-                    optimizer_variables,
-                    metrics_variables,
-                )
+                if self._mlx_state_synced:
+                    # The state may have been synced by a callback.
+                    state = self._get_mlx_state(
+                        trainable_variables=True,
+                        non_trainable_variables=True,
+                        optimizer_variables=True,
+                        metrics_variables=True,
+                        purge_model_variables=True,
+                    )
+                    self._mlx_state_synced = False
+
                 logs, state = self.train_function(state, data)
                 mx.eval(logs, state)
                 (
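
Note: the behavioral fix in this loop is the `_mlx_state_synced` flag. A callback that accesses `self.model` now forces `mlx_state_sync()`, which writes the detached values back into the `Variable` objects and sets the flag, so the next step rebuilds (and re-purges) `state` from the variables instead of reusing stale detached values. The toy below is not Keras code and uses hypothetical names; it only illustrates the staleness problem the flag guards against.

# Toy illustration of why the detached state must be rebuilt after a callback
# syncs model state mid-epoch.
class FlaggedTrainer:
    def __init__(self):
        self.variable = 0.0           # stands in for a Keras Variable
        self._state_synced = True     # mirrors self._mlx_state_synced

    def state_sync(self, state):
        (self.variable,) = state      # like mlx_state_sync()
        self._state_synced = True

    def fit_steps(self, num_steps, callback):
        state = None
        for _ in range(num_steps):
            if self._state_synced:
                # Rebuild detached state from the (possibly updated) variable.
                state = (self.variable,)
                self._state_synced = False
            state = (state[0] + 1.0,)  # functional "train step"
            callback(self, state)      # may call state_sync()
        self.state_sync(state)


def inspecting_callback(trainer, state):
    # Reading model state mid-epoch forces a sync, as the Callback.model
    # property now does on the MLX backend.
    trainer.state_sync(state)


trainer = FlaggedTrainer()
trainer.fit_steps(3, inspecting_callback)
print(trainer.variable)  # 3.0
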
@@ -547,7 +597,9 @@ def fit(
             self.mlx_state_sync()

             # Override with model metrics instead of last step logs
-            epoch_logs = self._pythonify_logs(self.get_metrics_result())
+            epoch_logs = self._pythonify_logs(
+                self._get_metrics_result_or_logs(logs)
+            )

             # Run validation.
             if validation_data and self._should_eval(epoch, validation_freq):
@@ -687,7 +739,7 @@ def evaluate(
                 break

         self.mlx_state_sync()
-        logs = self._pythonify_logs(self.get_metrics_result())
+        logs = self._pythonify_logs(self._get_metrics_result_or_logs(logs))
         callbacks.on_test_end(logs)
         self._mlx_state = None

@@ -711,8 +763,10 @@ def predict(
         if not all(layer.built for layer in self._flatten_layers()):
             # Build the model on one batch of data.
             for _, data in epoch_iterator.enumerate_epoch():
-                data_batch = data[0]
-                self._symbolic_build(data_batch)
+                # Build model
+                x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data[0])
+                with backend.StatelessScope():
+                    self(x)
                 break

         # Container that configures and calls callbacks.
@@ -746,22 +800,36 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()

-        trainable_variables = [v.value for v in self.trainable_variables]
-        non_trainable_variables = [
-            v.value for v in self.non_trainable_variables
-        ]
-        state = (trainable_variables, non_trainable_variables)
-
+        self._mlx_state_synced = True
         outputs = None
+        non_trainable_variables = None
         for step, data in epoch_iterator.enumerate_epoch():
             callbacks.on_predict_batch_begin(step)
+            if self._mlx_state_synced:
+                # The state may have been synced by a callback.
+                state = self._get_mlx_state(
+                    trainable_variables=True,
+                    non_trainable_variables=True,
+                )
+                self._purge_model_variables(non_trainable_variables=True)
+                self._mlx_state_synced = False
+            else:
+                state = (state[0], non_trainable_variables)
             batch_outputs, state = self.predict_function(state, data)
             mx.eval(batch_outputs, state)
+            (trainable_variables, non_trainable_variables) = state
             outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
+        self._mlx_state = {
+            # I wouldn't recommend modifying non-trainable model state
+            # during predict(), but it's allowed.
+            "non_trainable_variables": non_trainable_variables,
+        }
+        self.mlx_state_sync()
         callbacks.on_predict_end()
+        self._mlx_state = None
         outputs = tree.map_structure(
             backend.convert_to_numpy, outputs
         )  # TODO: This copies but we could avoid it
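
Note: the predict loop now follows the same pattern as fit: state is rebuilt whenever a callback has forced a sync, and non-trainable state produced during prediction is written back via `mlx_state_sync()` at the end. A hedged usage sketch is below, assuming a Keras build that includes the MLX backend; the model, data, and callback are placeholders and do not appear in this commit.

import numpy as np
import keras


class PeekWeights(keras.callbacks.Callback):
    def on_predict_batch_end(self, batch, logs=None):
        # Reading self.model triggers the state sync added in callback.py,
        # which flips _mlx_state_synced so the next batch re-fetches state.
        _ = self.model.weights[0]


model = keras.Sequential([keras.layers.Dense(4)])
x = np.random.rand(32, 8).astype("float32")
# The model is unbuilt, so predict() exercises the StatelessScope build path
# from the hunk above before running the batched predict_function.
preds = model.predict(x, batch_size=8, callbacks=[PeekWeights()])
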
@@ -794,19 +862,14 @@ def train_on_batch(
             self._symbolic_build(data)
         self.make_train_function()

-        trainable_variables = [v.value for v in self.trainable_variables]
-        non_trainable_variables = [
-            v.value for v in self.non_trainable_variables
-        ]
-        optimizer_variables = [v.value for v in self.optimizer.variables]
-        metrics_variables = [v.value for v in self.metrics_variables]
-        # TODO: Why not purge model state?
-        state = (
-            trainable_variables,
-            non_trainable_variables,
-            optimizer_variables,
-            metrics_variables,
+        state = self._get_mlx_state(
+            trainable_variables=True,
+            non_trainable_variables=True,
+            optimizer_variables=True,
+            metrics_variables=True,
+            purge_model_variables=False,
         )
+        self._mlx_state_synced = False
         logs, state = self.train_function(state, [data])
         mx.eval(logs, state)

@@ -846,17 +909,13 @@ def test_on_batch(
         self.make_test_function()

         # Test step
-        trainable_variables = [v.value for v in self.trainable_variables]
-        non_trainable_variables = [
-            v.value for v in self.non_trainable_variables
-        ]
-        metrics_variables = [v.value for v in self.metrics_variables]
-        # TODO: Why not purge model state?
-        state = (
-            trainable_variables,
-            non_trainable_variables,
-            metrics_variables,
+        state = self._get_mlx_state(
+            trainable_variables=True,
+            non_trainable_variables=True,
+            metrics_variables=True,
+            purge_model_variables=False,
         )
+        self._mlx_state_synced = False
         logs, state = self.test_function(state, [data])
         mx.eval(logs, state)

@@ -875,17 +934,26 @@ def test_on_batch(
         return self._flatten_metrics_in_order(logs)

     def predict_on_batch(self, x):
+        if not all(layer.built for layer in self._flatten_layers()):
+            # Build model
+            with backend.StatelessScope():
+                self(x)
         self._symbolic_build(x)
         self.make_predict_function()
-
-        trainable_variables = [v.value for v in self.trainable_variables]
-        non_trainable_variables = [
-            v.value for v in self.non_trainable_variables
-        ]
-        state = (trainable_variables, non_trainable_variables)
+        state = self._get_mlx_state(
+            trainable_variables=True,
+            non_trainable_variables=True,
+            metrics_variables=False,
+            purge_model_variables=False,
+        )
+        self._mlx_state_synced = False
         batch_outputs, state = self.predict_function(state, [(x,)])
         mx.eval(batch_outputs, state)
-
+        trainable_variables, non_trainable_variables = state
+        self._mlx_state = {
+            "non_trainable_variables": non_trainable_variables,
+        }
+        self.mlx_state_sync()
         # TODO: This copies but we could avoid it
         batch_outputs = tree.map_structure(
             backend.convert_to_numpy, batch_outputs

keras/src/callbacks/callback.py

Lines changed: 4 additions & 0 deletions
@@ -83,6 +83,10 @@ def model(self):
             # epoch. We have to force a sync before
             # accessing model state for e.g. checkpointing.
             self._model.jax_state_sync()
+        elif backend.backend() == "mlx" and hasattr(
+            self._model, "mlx_state_sync"
+        ):
+            self._model.mlx_state_sync()
         return self._model

     def on_batch_begin(self, batch, logs=None):
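
Note: this mirrors the existing JAX guard, so any callback that reads `self.model` mid-training on the MLX backend now gets freshly synced variables. For example, `ModelCheckpoint` reads `self.model` inside `on_epoch_end`, so the weights it saves reflect the current detached training state. The snippet below is a hedged sketch with placeholder data and filepath, assuming a Keras build that includes the MLX backend; it is not taken from the commit.

import numpy as np
import keras

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")

x = np.random.rand(64, 3).astype("float32")
y = np.random.rand(64, 1).astype("float32")

# Saving weights requires reading model state; on MLX this access now goes
# through mlx_state_sync() via the Callback.model property.
ckpt = keras.callbacks.ModelCheckpoint(
    "ckpt_epoch_{epoch}.weights.h5", save_weights_only=True
)
model.fit(x, y, epochs=2, batch_size=16, callbacks=[ckpt])
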
