@@ -266,48 +266,46 @@ def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
         out = []
         for i in range(current_batch, batches):
-            out.extend(
-                [
-                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                    {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
-                    {"name": "on_train_batch_start", "args": (ANY, i)},
-                    {"name": "forward", "args": (ANY,)},
-                    {"name": "training_step", "args": (ANY, i)},
-                    {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
-                    {"name": "on_before_zero_grad", "args": (ANY,)},
-                    {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
-                    {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
-                    {"name": "on_before_backward", "args": (ANY,)},
-                    # DeepSpeed handles backward internally
-                    *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
-                    {"name": "Callback.on_after_backward", "args": (trainer, model)},
-                    {"name": "on_after_backward"},
-                    # note: unscaling happens here in the case of AMP
-                    {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
-                    {"name": "on_before_optimizer_step", "args": (ANY,)},
-                    {
-                        "name": "clip_gradients",
-                        "args": (ANY,),
-                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                    },
-                    {
-                        "name": "configure_gradient_clipping",
-                        "args": (ANY,),
-                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                    },
-                    # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
-                    # the actual call to `Precision.optimizer_step`
-                    {
-                        "name": "optimizer_step",
-                        "args": (current_epoch, i, ANY, ANY),
-                    },
-                    *([{"name": "lr_scheduler_step", "args": ANY}] if i == (trainer.num_training_batches - 1) else []),
-                    {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
-                    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
-                ]
-            )
+            out.extend([
+                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
+                {"name": "on_train_batch_start", "args": (ANY, i)},
+                {"name": "forward", "args": (ANY,)},
+                {"name": "training_step", "args": (ANY, i)},
+                {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
+                {"name": "on_before_zero_grad", "args": (ANY,)},
+                {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
+                {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
+                {"name": "on_before_backward", "args": (ANY,)},
+                # DeepSpeed handles backward internally
+                *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
+                {"name": "Callback.on_after_backward", "args": (trainer, model)},
+                {"name": "on_after_backward"},
+                # note: unscaling happens here in the case of AMP
+                {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
+                {"name": "on_before_optimizer_step", "args": (ANY,)},
+                {
+                    "name": "clip_gradients",
+                    "args": (ANY,),
+                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                },
+                {
+                    "name": "configure_gradient_clipping",
+                    "args": (ANY,),
+                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                },
+                # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
+                # the actual call to `Precision.optimizer_step`
+                {
+                    "name": "optimizer_step",
+                    "args": (current_epoch, i, ANY, ANY),
+                },
+                *([{"name": "lr_scheduler_step", "args": ANY}] if i == (trainer.num_training_batches - 1) else []),
+                {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
+                {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
+            ])
         return out

     @staticmethod
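For context on how an expected-call list like the one built above is typically consumed: `unittest.mock.ANY` compares equal to any value, so each expected dict can be matched against a recorded hook call even when the concrete batch, loss tensor, or optimizer object is not known ahead of time. The sketch below is illustrative only; the `assert_hook_calls_match` helper and the `recorded` structure are assumptions for this example, not part of this commit.

```python
from unittest.mock import ANY

# Hypothetical helper (not from this commit): compare hook calls recorded as
# {"name": ..., "args": ..., "kwargs": ...} dicts against an expected list such
# as the one returned by _auto_train_batch.
def assert_hook_calls_match(recorded, expected):
    assert len(recorded) == len(expected), f"{len(recorded)} calls, expected {len(expected)}"
    for got, want in zip(recorded, expected):
        assert got["name"] == want["name"]
        # ANY inside the expected tuples makes these equality checks succeed
        # for arbitrary batches, tensors, and optimizer objects.
        if "args" in want:
            assert got.get("args", ()) == want["args"]
        if "kwargs" in want:
            assert got.get("kwargs", {}) == want["kwargs"]


# Example: ANY matches any concrete value in an equality comparison.
recorded = [{"name": "training_step", "args": ("some-batch", 0)}]
expected = [{"name": "training_step", "args": (ANY, 0)}]
assert_hook_calls_match(recorded, expected)
```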