@@ -153,7 +153,7 @@ def train_step(self, state, data):
             metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
         )
 
-        state = self._enforce_jax_state_sharding(
+        state = (
             trainable_variables,
             non_trainable_variables,
             optimizer_variables,
@@ -185,17 +185,6 @@ def test_step(self, state, data):
             metrics_variables, unscaled_loss, x, y, y_pred, sample_weight
         )
 
-        (
-            trainable_variables,
-            non_trainable_variables,
-            _,
-            metrics_variables,
-        ) = self._enforce_jax_state_sharding(
-            trainable_variables=trainable_variables,
-            non_trainable_variables=non_trainable_variables,
-            optimizer_variables=None,
-            metrics_variables=metrics_variables,
-        )
         state = (
             trainable_variables,
             non_trainable_variables,
@@ -213,17 +202,6 @@ def predict_step(self, state, data):
         outputs, non_trainable_variables = self.stateless_call(
             trainable_variables, non_trainable_variables, x, **kwargs
         )
-        (
-            _,
-            non_trainable_variables,
-            _,
-            _,
-        ) = self._enforce_jax_state_sharding(
-            trainable_variables=None,
-            non_trainable_variables=non_trainable_variables,
-            optimizer_variables=None,
-            metrics_variables=None,
-        )
         return outputs, non_trainable_variables
 
     def _make_function(self, step_function, concatenate_outputs=False):
@@ -281,11 +259,15 @@ def make_train_function(self, force=False):
         if self.train_function is not None and not force:
             return
         if not self.run_eagerly and self.jit_compile:
-            # Note that we mark the state to be donated to jax,
-            # so that jax will reuse the memory buffer for outputs.
-            # This will reduce the memory usage of the training function by
-            # half.
-            train_step = jit(self.train_step, donate_argnums=0)
+            out_shardings = None
+            if distribution_lib.distribution() is not None:
+                state_shardings = self._get_state_sharding_spec()
+                out_shardings = (None, state_shardings)
+            train_step = jit(
+                self.train_step,
+                donate_argnums=0,
+                out_shardings=out_shardings,
+            )
         else:
             train_step = self.train_step
 
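
Note: the hunk above declares the desired layout of the returned state directly on `jax.jit` via `out_shardings`, while keeping `donate_argnums=0` for buffer donation. Below is a minimal, standalone sketch of that jit pattern (not Keras code; the mesh and the `step`, `state`, and `data` names are illustrative), assuming a single-host device mesh:

import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Single-host mesh over all local devices; the toy "state" is a
# replicated scalar and the data is split along the batch axis.
mesh = Mesh(np.array(jax.devices()), axis_names=("batch",))
replicated = NamedSharding(mesh, P())
batch_sharded = NamedSharding(mesh, P("batch"))

def step(state, data):
    loss = jnp.mean((data * state) ** 2)
    return loss, state - 0.01 * loss  # (logs, new_state), like train_step

# Donate the incoming state buffer and pin the sharding of the returned
# state, mirroring jit(self.train_step, donate_argnums=0, out_shardings=...).
# (On CPU backends donation may be a no-op and JAX may warn; harmless here.)
jit_step = jax.jit(step, donate_argnums=0, out_shardings=(None, replicated))

state = jax.device_put(jnp.asarray(1.0), replicated)
data = jax.device_put(
    jnp.arange(4 * jax.device_count(), dtype="float32"), batch_sharded
)
for _ in range(3):
    loss, state = jit_step(state, data)

Because the returned state always comes back with the declared sharding, feeding it into the next call never changes the input layouts, so the compiled step is reused instead of being retraced.
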
@@ -297,12 +279,25 @@ def make_test_function(self, force=False):
         if self.test_function is not None and not force:
             return
         if not self.run_eagerly and self.jit_compile:
-            # Note that we mark the state to be donated to jax,
-            # so that jax will reuse the memory buffer for outputs.
-            # This will reduce the memory usage of the training function by
-            # half.
-            test_step = jit(self.test_step, donate_argnums=0)
-
+            out_shardings = None
+            if distribution_lib.distribution() is not None:
+                (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                    _,  # optimizer_shardings
+                    metrics_shardings,
+                ) = self._get_state_sharding_spec()
+                state_shardings = (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                    metrics_shardings,
+                )
+                out_shardings = (None, state_shardings)
+            test_step = jit(
+                self.test_step,
+                donate_argnums=0,
+                out_shardings=out_shardings,
+            )
         else:
             test_step = self.test_step
 
@@ -319,7 +314,24 @@ def predict_step(state, data):
             return outputs, (state[0], non_trainable_variables)
 
         if not self.run_eagerly and self.jit_compile:
-            predict_step = jit(predict_step, donate_argnums=0)
+            out_shardings = None
+            if distribution_lib.distribution() is not None:
+                (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                    _,  # optimizer_shardings
+                    _,  # metrics_shardings
+                ) = self._get_state_sharding_spec()
+                state_shardings = (
+                    trainable_shardings,
+                    non_trainable_shardings,
+                )
+                out_shardings = (None, state_shardings)
+            predict_step = jit(
+                predict_step,
+                donate_argnums=0,
+                out_shardings=out_shardings,
+            )
 
         _step_function = self._make_function(
             predict_step, concatenate_outputs=True
@@ -402,7 +414,6 @@ def fit(
                 steps=epoch_iterator.num_batches,
                 model=self,
             )
-        self._record_training_state_sharding_spec()
 
         self.make_train_function()
         self.stop_training = False
@@ -518,7 +529,6 @@ def fit(
         if training_finished:
             callbacks.on_train_end(logs=training_logs)
         self._jax_state = None
-        self._clear_jax_state_sharding()
         return self.history
 
     @traceback_utils.filter_traceback
@@ -568,7 +578,6 @@ def evaluate(
                 steps=epoch_iterator.num_batches,
                 model=self,
             )
-        self._record_training_state_sharding_spec()
 
         self.make_test_function()
         self.stop_evaluating = False
@@ -620,9 +629,6 @@ def evaluate(
         logs = self._get_metrics_result_or_logs(logs)
         callbacks.on_test_end(logs)
         self._jax_state = None
-        if not use_cached_eval_dataset:
-            # Only clear sharding if evaluate is not called from `fit`.
-            self._clear_jax_state_sharding()
         if return_dict:
             return logs
         return self._flatten_metrics_in_order(logs)
@@ -664,7 +670,6 @@ def predict(
                 steps=epoch_iterator.num_batches,
                 model=self,
             )
-        self._record_training_state_sharding_spec()
 
         self.make_predict_function()
         self.stop_predicting = False
@@ -723,7 +728,6 @@ def append_to_outputs(batch_outputs, outputs):
         self.jax_state_sync()
         callbacks.on_predict_end()
         self._jax_state = None
-        self._clear_jax_state_sharding()
         return tree.map_structure_up_to(batch_outputs, np.concatenate, outputs)
 
     def train_on_batch(
@@ -752,7 +756,6 @@ def data():
 
         # Maybe build model
         self._symbolic_build(data_batch=next(data()))
-        self._record_training_state_sharding_spec()
         self.make_train_function()
 
         # Train step
@@ -801,7 +804,6 @@ def data():
 
         # Maybe build model
        self._symbolic_build(data_batch=next(data()))
-        self._record_training_state_sharding_spec()
        self.make_test_function()
 
        # Test step
@@ -834,7 +836,6 @@ def predict_on_batch(self, x):
             # Build model
             with backend.StatelessScope():
                 self(x)
-        self._record_training_state_sharding_spec()
         self.make_predict_function()
 
         state = self._get_jax_state(
@@ -884,75 +885,25 @@ def jax_state_sync(self):
                 ref_v.assign(v)
         self._jax_state_synced = True
 
-    def _record_training_state_sharding_spec(self):
-        self._trainable_variable_shardings = [
+    def _get_state_sharding_spec(self):
+        trainable_shardings = [
             v.value.sharding for v in self.trainable_variables
         ]
-        self._non_trainable_variable_shardings = [
+        non_trainable_shardings = [
             v.value.sharding for v in self.non_trainable_variables
         ]
         if hasattr(self, "optimizer") and self.optimizer is not None:
-            self._optimizer_variable_shardings = [
+            optimizer_shardings = [
                 v.value.sharding for v in self.optimizer.variables
             ]
         else:
-            self._optimizer_variable_shardings = []
-        self._metrics_variable_shardings = [
-            v.value.sharding for v in self.metrics_variables
-        ]
-
-    def _clear_jax_state_sharding(self):
-        self._trainable_variable_shardings = None
-        self._non_trainable_variable_shardings = None
-        self._optimizer_variable_shardings = None
-        self._metrics_variable_shardings = None
-
-    def _enforce_jax_state_sharding(
-        self,
-        trainable_variables=None,
-        non_trainable_variables=None,
-        optimizer_variables=None,
-        metrics_variables=None,
-    ):
-        """Enforce the sharding spec constraint for all the training state.
-
-        Since the output of the train/eval step will be used as inputs to next
-        step, we need to ensure that they have the same sharding spec, so that
-        nnx.jit/jax.jit won't have to recompile the train/eval function.
-
-        Note that this function will also rely on the recorded sharding spec
-        for each of states.
-
-        This function is expected to be called within the jitted train/eval
-        function, especially around the end of the function.
-        """
-        trainable_variables = trainable_variables or []
-        non_trainable_variables = non_trainable_variables or []
-        optimizer_variables = optimizer_variables or []
-        metrics_variables = metrics_variables or []
-
-        for i in range(len(trainable_variables)):
-            trainable_variables[i] = jax.lax.with_sharding_constraint(
-                trainable_variables[i], self._trainable_variable_shardings[i]
-            )
-        for i in range(len(non_trainable_variables)):
-            non_trainable_variables[i] = jax.lax.with_sharding_constraint(
-                non_trainable_variables[i],
-                self._non_trainable_variable_shardings[i],
-            )
-        for i in range(len(optimizer_variables)):
-            optimizer_variables[i] = jax.lax.with_sharding_constraint(
-                optimizer_variables[i], self._optimizer_variable_shardings[i]
-            )
-        for i in range(len(metrics_variables)):
-            metrics_variables[i] = jax.lax.with_sharding_constraint(
-                metrics_variables[i], self._metrics_variable_shardings[i]
-            )
+            optimizer_shardings = []
+        metrics_shardings = [v.value.sharding for v in self.metrics_variables]
         return (
-            trainable_variables,
-            non_trainable_variables,
-            optimizer_variables,
-            metrics_variables,
+            trainable_shardings,
+            non_trainable_shardings,
+            optimizer_shardings,
+            metrics_shardings,
        )
 
     def _purge_model_variables(
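
For readers comparing the two approaches in the hunk above: the removed `_enforce_jax_state_sharding` re-applied `jax.lax.with_sharding_constraint` to every state variable inside the jitted step, while the new `_get_state_sharding_spec` reads each variable's existing `.sharding` once and hands the resulting structure to `jax.jit(..., out_shardings=...)`. A standalone sketch of the two styles follows (illustrative names only, not Keras code):

import functools

import jax
import jax.numpy as jnp

# Toy "state": two arrays standing in for model variables.
params = [jnp.ones((4, 4)), jnp.zeros((4,))]

# What _get_state_sharding_spec does per variable group: read the
# sharding each array already carries.
param_shardings = [p.sharding for p in params]

# Old style (removed above): constrain every output inside the step.
@jax.jit
def step_with_constraint(params):
    new_params = [p * 0.9 for p in params]
    return [
        jax.lax.with_sharding_constraint(p, s)
        for p, s in zip(new_params, param_shardings)
    ]

# New style: declare the output layout on jit itself; the (loss, state)
# structure mirrors the (logs, state) outputs of the train/test steps.
@functools.partial(jax.jit, out_shardings=(None, param_shardings))
def step(params):
    new_params = [p * 0.9 for p in params]
    loss = sum(jnp.sum(p) for p in new_params)
    return loss, new_params

loss, params = step(params)

Both styles keep the state layout stable across steps; moving the declaration to `out_shardings` keeps the step body free of per-variable constraints and enforces the layout at the jit output boundary instead.
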