 from keras.src.backend.common.keras_tensor import KerasTensor
 from keras.src.backend.mlx.core import is_tensor
 from keras.src.trainers import trainer as base_trainer
+from keras.src.trainers.data_adapters import array_slicing
 from keras.src.trainers.data_adapters import data_adapter_utils
 from keras.src.trainers.epoch_iterator import EpochIterator
 from keras.src.utils import traceback_utils
@@ -21,6 +22,7 @@ def __init__(self):
         self.train_function = None
         self.test_function = None
         self.predict_function = None
+        self._mlx_state_synced = True
 
     def _data_to_mlx(self, data):
         def _transform(x):
@@ -55,13 +57,40 @@ def mlx_state_sync(self):
         if metrics_variables:
             for ref_v, v in zip(self.metrics_variables, metrics_variables):
                 ref_v.assign(v)
+        self._mlx_state_synced = True
+
+    def _get_mlx_state(
+        self,
+        trainable_variables=False,
+        non_trainable_variables=False,
+        optimizer_variables=False,
+        metrics_variables=False,
+        purge_model_variables=False,
+    ):
+        state = []
+        if trainable_variables:
+            state.append([v.value for v in self.trainable_variables])
+        if non_trainable_variables:
+            state.append([v.value for v in self.non_trainable_variables])
+        if optimizer_variables:
+            state.append([v.value for v in self.optimizer.variables])
+        if metrics_variables:
+            state.append([v.value for v in self.metrics_variables])
+        if purge_model_variables:
+            self._purge_model_variables(
+                trainable_variables=trainable_variables,
+                non_trainable_variables=non_trainable_variables,
+                optimizer_variables=optimizer_variables,
+                metric_variables=metrics_variables,
+            )
+        return tuple(state)
 
     def _purge_model_variables(
         self,
-        trainable_variables=True,
-        non_trainable_variables=True,
-        optimizer_variables=True,
-        metric_variables=True,
+        trainable_variables=False,
+        non_trainable_variables=False,
+        optimizer_variables=False,
+        metric_variables=False,
    ):
         """Remove all the model variables so they can be garbage collected and
         the memory reclaimed by MLX.
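
Note: the new _get_mlx_state helper detaches the requested variable groups from
the model as plain MLX arrays, and with purge_model_variables=True it also
drops the model's own references so MLX can reclaim the memory; the caller
then holds the only copy until mlx_state_sync() writes it back. A minimal
standalone sketch of that detach/purge/re-attach lifecycle (hypothetical toy
class, not the trainer API):

    # Sketch of the pack/purge/sync pattern the helper implements.
    class StateHolder:
        def __init__(self):
            self._vars = {"w": 1.0, "b": 0.0}  # stand-ins for model variables

        def get_state(self, purge=False):
            state = tuple(self._vars.values())  # detach current values
            if purge:
                self._vars = {k: None for k in self._vars}  # free references
            return state

        def state_sync(self, state):
            self._vars = dict(zip(self._vars, state))  # write values back

    holder = StateHolder()
    state = holder.get_state(purge=True)  # caller owns the only copy now
    state = (state[0] + 1.0, state[1])    # functional update, no mutation
    holder.state_sync(state)              # re-attach the updated values
    assert holder._vars == {"w": 2.0, "b": 0.0}
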
@@ -117,6 +146,7 @@ def compute_loss_and_updates(
         self,
         trainable_variables,
         non_trainable_variables,
+        metrics_variables,
         x,
         y,
         sample_weight,
@@ -135,22 +165,39 @@ def compute_loss_and_updates(
             return_losses=True,
             **kwargs,
         )
+        if losses:
+            # Make forward pass losses available to compute_loss.
+            self._losses_override.clear()
+            self._losses_override = losses
 
-        trainable_mapping = zip(self.trainable_variables, trainable_variables)
-        with backend.StatelessScope(state_mapping=trainable_mapping):
-            # Note that this is needed for the regularization loss, which need
-            # the latest value of train/non-trainable variables.
-            loss = self.compute_loss(x, y, y_pred, sample_weight)
+        loss, variables = self.stateless_compute_loss(
+            trainable_variables,
+            non_trainable_variables,
+            metrics_variables,
+            x=x,
+            y=y,
+            y_pred=y_pred,
+            sample_weight=sample_weight,
+        )
         if losses:
-            loss += ops.sum(losses)
+            self._losses_override.clear()
+        (trainable_variables, non_trainable_variables, metrics_variables) = (
+            variables
+        )
         unscaled_loss = loss
         if training and self.optimizer is not None:
             # Scale loss with a StatelessScope, to use an update scale variable.
             mapping = list(zip(self.optimizer.variables, optimizer_variables))
             with backend.StatelessScope(state_mapping=mapping):
                 loss = self.optimizer.scale_loss(loss)
 
-        return loss, unscaled_loss, y_pred, non_trainable_variables
+        return (
+            loss,
+            unscaled_loss,
+            y_pred,
+            non_trainable_variables,
+            metrics_variables,
+        )
 
     def train_step(self, state, data):
         data = self._data_to_mlx(data)
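
Note: the reworked compute_loss_and_updates is purely functional — all state
(now including metrics variables) enters as arguments and leaves in the return
value, which is what keeps the function compatible with MLX's function
transformations. A standalone sketch of that shape (toy names, not the Keras
API):

    import mlx.core as mx

    def loss_and_updates(params, counter, x, y):
        # All state comes in as arguments...
        y_pred = params["w"] * x + params["b"]
        loss = mx.mean((y_pred - y) ** 2)
        # ...and updated state is returned, never mutated in place.
        counter = counter + 1
        return loss, (y_pred, counter)

    params = {"w": mx.array(2.0), "b": mx.array(0.0)}
    loss, (y_pred, counter) = loss_and_updates(
        params, mx.array(0), mx.array([1.0, 2.0]), mx.array([2.0, 4.0])
    )
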
@@ -169,9 +216,11 @@ def train_step(self, state, data):
             unscaled_loss,
             y_pred,
             non_trainable_variables,
+            metrics_variables,
         ), grads = grad_fn(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
             sample_weight,
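
Note: the destructuring above relies on MLX's value_and_grad convention — when
the wrapped function returns a tuple, the gradient is taken with respect to
the first (scalar) element and the remaining elements pass through as
auxiliary outputs. A standalone sketch of that convention (toy loss, assumed
to mirror how grad_fn is built from compute_loss_and_updates):

    import mlx.core as mx

    def compute_loss(params, x, y):
        y_pred = params["w"] * x + params["b"]
        loss = mx.mean((y_pred - y) ** 2)
        return loss, y_pred  # scalar loss first, auxiliary values after

    grad_fn = mx.value_and_grad(compute_loss)  # grads w.r.t. first argument
    params = {"w": mx.array(0.5), "b": mx.array(0.0)}
    (loss, y_pred), grads = grad_fn(
        params, mx.array([1.0, 2.0]), mx.array([2.0, 4.0])
    )
    mx.eval(loss, grads)  # force the lazy computation, as train_step does
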
@@ -191,9 +240,11 @@ def train_step(self, state, data):
             unscaled_loss,
             y_pred,
             non_trainable_variables,
+            metrics_variables,
         ) = self.compute_loss_and_updates(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
             sample_weight,
@@ -239,9 +290,11 @@ def test_step(self, state, data):
             unscaled_loss,
             y_pred,
             non_trainable_variables,
+            metrics_variables,
         ) = self.compute_loss_and_updates(
             trainable_variables,
             non_trainable_variables,
+            metrics_variables,
             x,
             y,
             sample_weight,
@@ -443,7 +496,7 @@ def fit(
                 x,
                 y,
                 sample_weight,
-            ), validation_data = data_adapter_utils.train_validation_split(
+            ), validation_data = array_slicing.train_validation_split(
                 (x, y, sample_weight), validation_split=validation_split
             )
 
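
Note: train_validation_split moved from data_adapter_utils to the
array_slicing module; the behavior is unchanged — each sliceable array in the
structure is split along its first axis, with the trailing fraction becoming
validation data. A quick usage sketch (assuming the keras.src internal import
path stays importable):

    import numpy as np
    from keras.src.trainers.data_adapters import array_slicing

    x, y = np.arange(10).reshape(10, 1), np.arange(10)
    (x_t, y_t, _), (x_v, y_v, _) = array_slicing.train_validation_split(
        (x, y, None), validation_split=0.2  # None sample_weight passes through
    )
    assert len(x_t) == 8 and len(x_v) == 2
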
@@ -496,30 +549,27 @@ def fit(
         self.stop_training = False
         self.make_train_function()
         callbacks.on_train_begin()
-
+        initial_epoch = self._initial_epoch or initial_epoch
         for epoch in range(initial_epoch, epochs):
             self.reset_metrics()
             callbacks.on_epoch_begin(epoch)
 
-            trainable_variables = [v.value for v in self.trainable_variables]
-            non_trainable_variables = [
-                v.value for v in self.non_trainable_variables
-            ]
-            optimizer_variables = [v.value for v in self.optimizer.variables]
-            metrics_variables = [v.value for v in self.metrics_variables]
-
-            self._purge_model_variables()
+            self._mlx_state_synced = True
             for step, data in epoch_iterator.enumerate_epoch():
                 # Callbacks
                 callbacks.on_train_batch_begin(step)
-
                 # Train step
-                state = (
-                    trainable_variables,
-                    non_trainable_variables,
-                    optimizer_variables,
-                    metrics_variables,
-                )
+                if self._mlx_state_synced:
+                    # The state may have been synced by a callback.
+                    state = self._get_mlx_state(
+                        trainable_variables=True,
+                        non_trainable_variables=True,
+                        optimizer_variables=True,
+                        metrics_variables=True,
+                        purge_model_variables=True,
+                    )
+                    self._mlx_state_synced = False
+
                 logs, state = self.train_function(state, data)
                 mx.eval(logs, state)
                 (
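
Note on the flag protocol above: the loop fetches detached state once, and
only re-fetches when something — typically a callback that read or wrote model
weights, which triggers mlx_state_sync() — marks the model's copy as fresh
again. A standalone sketch of the handshake (toy class, not the trainer API):

    class LoopSketch:
        def __init__(self):
            self.value = 0      # stands in for model variables
            self.synced = True  # stands in for _mlx_state_synced

        def fetch_state(self):
            self.synced = False
            return self.value

        def sync(self, state):
            self.value = state  # write detached state back to the model
            self.synced = True

    loop, state = LoopSketch(), None
    for step in range(3):
        if loop.synced:  # re-fetch only when the model holds fresh state
            state = loop.fetch_state()
        state = state + 1  # the functional train step
        if step == 1:
            loop.sync(state)  # e.g. a checkpoint callback forces a sync
    assert loop.value == 2 and state == 3
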
@@ -547,7 +597,9 @@ def fit(
             self.mlx_state_sync()
 
             # Override with model metrics instead of last step logs
-            epoch_logs = self._pythonify_logs(self.get_metrics_result())
+            epoch_logs = self._pythonify_logs(
+                self._get_metrics_result_or_logs(logs)
+            )
 
             # Run validation.
             if validation_data and self._should_eval(epoch, validation_freq):
@@ -687,7 +739,7 @@ def evaluate(
                 break
 
         self.mlx_state_sync()
-        logs = self._pythonify_logs(self.get_metrics_result())
+        logs = self._pythonify_logs(self._get_metrics_result_or_logs(logs))
         callbacks.on_test_end(logs)
         self._mlx_state = None
 
@@ -711,8 +763,10 @@ def predict(
         if not all(layer.built for layer in self._flatten_layers()):
             # Build the model on one batch of data.
             for _, data in epoch_iterator.enumerate_epoch():
-                data_batch = data[0]
-                self._symbolic_build(data_batch)
+                # Build model
+                x, _, _ = data_adapter_utils.unpack_x_y_sample_weight(data[0])
+                with backend.StatelessScope():
+                    self(x)
                 break
 
         # Container that configures and calls callbacks.
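
Note: building with a plain forward pass under StatelessScope creates the
model's layers without committing any state updates from the call; this
mirrors the approach the jax trainer takes. A minimal sketch (assuming the
internal keras.src import path):

    import keras
    from keras.src import backend

    layer = keras.layers.Dense(2)
    with backend.StatelessScope():
        layer(keras.ops.ones((1, 3)))  # creates the weights
    assert layer.built  # built, with no state mutations committed
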
@@ -746,22 +800,36 @@ def append_to_outputs(batch_outputs, outputs):
         self.stop_predicting = False
         callbacks.on_predict_begin()
 
-        trainable_variables = [v.value for v in self.trainable_variables]
-        non_trainable_variables = [
-            v.value for v in self.non_trainable_variables
-        ]
-        state = (trainable_variables, non_trainable_variables)
-
+        self._mlx_state_synced = True
         outputs = None
+        non_trainable_variables = None
         for step, data in epoch_iterator.enumerate_epoch():
             callbacks.on_predict_batch_begin(step)
+            if self._mlx_state_synced:
+                # The state may have been synced by a callback.
+                state = self._get_mlx_state(
+                    trainable_variables=True,
+                    non_trainable_variables=True,
+                )
+                self._purge_model_variables(non_trainable_variables=True)
+                self._mlx_state_synced = False
+            else:
+                state = (state[0], non_trainable_variables)
             batch_outputs, state = self.predict_function(state, data)
             mx.eval(batch_outputs, state)
+            (trainable_variables, non_trainable_variables) = state
             outputs = append_to_outputs(batch_outputs, outputs)
             callbacks.on_predict_batch_end(step, {"outputs": batch_outputs})
             if self.stop_predicting:
                 break
+        self._mlx_state = {
+            # I wouldn't recommend modifying non-trainable model state
+            # during predict(), but it's allowed.
+            "non_trainable_variables": non_trainable_variables,
+        }
+        self.mlx_state_sync()
         callbacks.on_predict_end()
+        self._mlx_state = None
         outputs = tree.map_structure(
             backend.convert_to_numpy, outputs
         )  # TODO: This copies but we could avoid it
@@ -794,19 +862,14 @@ def train_on_batch(
             self._symbolic_build(data)
         self.make_train_function()
 
-        trainable_variables = [v.value for v in self.trainable_variables]
-        non_trainable_variables = [
-            v.value for v in self.non_trainable_variables
-        ]
-        optimizer_variables = [v.value for v in self.optimizer.variables]
-        metrics_variables = [v.value for v in self.metrics_variables]
-        # TODO: Why not purge model state?
-        state = (
-            trainable_variables,
-            non_trainable_variables,
-            optimizer_variables,
-            metrics_variables,
+        state = self._get_mlx_state(
+            trainable_variables=True,
+            non_trainable_variables=True,
+            optimizer_variables=True,
+            metrics_variables=True,
+            purge_model_variables=False,
         )
+        self._mlx_state_synced = False
         logs, state = self.train_function(state, [data])
         mx.eval(logs, state)
 
@@ -846,17 +909,13 @@ def test_on_batch(
         self.make_test_function()
 
         # Test step
-        trainable_variables = [v.value for v in self.trainable_variables]
-        non_trainable_variables = [
-            v.value for v in self.non_trainable_variables
-        ]
-        metrics_variables = [v.value for v in self.metrics_variables]
-        # TODO: Why not purge model state?
-        state = (
-            trainable_variables,
-            non_trainable_variables,
-            metrics_variables,
+        state = self._get_mlx_state(
+            trainable_variables=True,
+            non_trainable_variables=True,
+            metrics_variables=True,
+            purge_model_variables=False,
         )
+        self._mlx_state_synced = False
         logs, state = self.test_function(state, [data])
         mx.eval(logs, state)
 
@@ -875,17 +934,26 @@ def test_on_batch(
         return self._flatten_metrics_in_order(logs)
 
     def predict_on_batch(self, x):
+        if not all(layer.built for layer in self._flatten_layers()):
+            # Build model
+            with backend.StatelessScope():
+                self(x)
         self._symbolic_build(x)
         self.make_predict_function()
-
-        trainable_variables = [v.value for v in self.trainable_variables]
-        non_trainable_variables = [
-            v.value for v in self.non_trainable_variables
-        ]
-        state = (trainable_variables, non_trainable_variables)
+        state = self._get_mlx_state(
+            trainable_variables=True,
+            non_trainable_variables=True,
+            metrics_variables=False,
+            purge_model_variables=False,
+        )
+        self._mlx_state_synced = False
         batch_outputs, state = self.predict_function(state, [(x,)])
         mx.eval(batch_outputs, state)
-
+        trainable_variables, non_trainable_variables = state
+        self._mlx_state = {
+            "non_trainable_variables": non_trainable_variables,
+        }
+        self.mlx_state_sync()
         # TODO: This copies but we could avoid it
         batch_outputs = tree.map_structure(
             backend.convert_to_numpy, batch_outputs
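
Note: one MLX-specific detail running through all of these call sites is that
MLX arrays are lazy, so every step ends with an mx.eval over the logs and the
new state to materialize them before Python unpacks or reuses them. A
standalone illustration:

    import mlx.core as mx

    a = mx.array([1.0, 2.0])
    b = a * 2 + 1  # no computation yet, just a deferred graph
    mx.eval(b)     # forces evaluation; the trainer does this once per step
    print(b)       # array([3, 5], dtype=float32)
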