@@ -266,48 +266,46 @@ def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
         out = []
         for i in range(current_batch, batches):
-            out.extend(
-                [
-                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                    {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
-                    {"name": "on_train_batch_start", "args": (ANY, i)},
-                    {"name": "forward", "args": (ANY,)},
-                    {"name": "training_step", "args": (ANY, i)},
-                    {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
-                    {"name": "on_before_zero_grad", "args": (ANY,)},
-                    {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
-                    {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
-                    {"name": "on_before_backward", "args": (ANY,)},
-                    # DeepSpeed handles backward internally
-                    *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
-                    {"name": "Callback.on_after_backward", "args": (trainer, model)},
-                    {"name": "on_after_backward"},
-                    # note: unscaling happens here in the case of AMP
-                    {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
-                    {"name": "on_before_optimizer_step", "args": (ANY,)},
-                    {
-                        "name": "clip_gradients",
-                        "args": (ANY,),
-                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                    },
-                    {
-                        "name": "configure_gradient_clipping",
-                        "args": (ANY,),
-                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                    },
-                    # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
-                    # the actual call to `Precision.optimizer_step`
-                    {
-                        "name": "optimizer_step",
-                        "args": (current_epoch, i, ANY, ANY),
-                    },
-                    *([{"name": "lr_scheduler_step", "args": ANY}] if i == (trainer.num_training_batches - 1) else []),
-                    {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
-                    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
-                ]
-            )
+            out.extend([
+                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
+                {"name": "on_train_batch_start", "args": (ANY, i)},
+                {"name": "forward", "args": (ANY,)},
+                {"name": "training_step", "args": (ANY, i)},
+                {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
+                {"name": "on_before_zero_grad", "args": (ANY,)},
+                {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
+                {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
+                {"name": "on_before_backward", "args": (ANY,)},
+                # DeepSpeed handles backward internally
+                *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
+                {"name": "Callback.on_after_backward", "args": (trainer, model)},
+                {"name": "on_after_backward"},
+                # note: unscaling happens here in the case of AMP
+                {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
+                {"name": "on_before_optimizer_step", "args": (ANY,)},
+                {
+                    "name": "clip_gradients",
+                    "args": (ANY,),
+                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                },
+                {
+                    "name": "configure_gradient_clipping",
+                    "args": (ANY,),
+                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                },
+                # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
+                # the actual call to `Precision.optimizer_step`
+                {
+                    "name": "optimizer_step",
+                    "args": (current_epoch, i, ANY, ANY),
+                },
+                *([{"name": "lr_scheduler_step", "args": ANY}] if i == (trainer.num_training_batches - 1) else []),
+                {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
+                {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
+            ])
         return out

     @staticmethod
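For context on the comparison these expectation dicts feed into: each entry uses `unittest.mock.ANY` as a wildcard, so a test can pin hook names, argument positions, and call order without pinning tensors or batch contents. Below is a minimal, self-contained sketch of that comparison pattern; the `recorded` data is made up for illustration and is not part of this diff or of the surrounding test file.

```python
from unittest.mock import ANY

# A test model would append one dict per hook invocation, e.g. the hook name
# plus its positional args; here the "recorded" calls are hard-coded examples.
recorded = [
    {"name": "on_train_batch_start", "args": ({"x": [1.0, 2.0]}, 0)},
    {"name": "training_step", "args": ({"x": [1.0, 2.0]}, 0)},
]

# The expectation mirrors the structure above, with ANY standing in for values
# the test does not care about (the batch contents in this case).
expected = [
    {"name": "on_train_batch_start", "args": (ANY, 0)},
    {"name": "training_step", "args": (ANY, 0)},
]

# ANY compares equal to anything, so this asserts only the hook names,
# batch indices, and call order.
assert recorded == expected
```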