@@ -21,11 +21,10 @@
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
 from lightning.pytorch.utilities.model_helpers import is_overridden
+from tests_pytorch.helpers.runif import RunIf
 from torch import Tensor
 from torch.utils.data import DataLoader

-from tests_pytorch.helpers.runif import RunIf
-

 class HookedDataModule(BoringDataModule):
     def __init__(self, called):
@@ -267,79 +266,83 @@ def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
         out = []
         for i in range(current_batch, batches):
-            out.extend([
-                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
-                {"name": "on_train_batch_start", "args": (ANY, i)},
-                {"name": "forward", "args": (ANY,)},
-                {"name": "training_step", "args": (ANY, i)},
-                {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
-                {"name": "on_before_zero_grad", "args": (ANY,)},
-                {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
-                {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
-                {"name": "on_before_backward", "args": (ANY,)},
-                # DeepSpeed handles backward internally
-                *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
-                {"name": "Callback.on_after_backward", "args": (trainer, model)},
-                {"name": "on_after_backward"},
-                # note: unscaling happens here in the case of AMP
-                {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
-                {"name": "on_before_optimizer_step", "args": (ANY,)},
-                {
-                    "name": "clip_gradients",
-                    "args": (ANY,),
-                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                },
-                {
-                    "name": "configure_gradient_clipping",
-                    "args": (ANY,),
-                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                },
-                # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
-                # the actual call to `Precision.optimizer_step`
-                {
-                    "name": "optimizer_step",
-                    "args": (current_epoch, i, ANY, ANY),
-                },
-                *(
-                    [{"name": "lr_scheduler_step", "args": (ANY, None)}]
-                    if i == (trainer.num_training_batches - 1)
-                    else []
-                ),
-                {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
-                {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
-            ])
+            out.extend(
+                [
+                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                    {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
+                    {"name": "on_train_batch_start", "args": (ANY, i)},
+                    {"name": "forward", "args": (ANY,)},
+                    {"name": "training_step", "args": (ANY, i)},
+                    {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
+                    {"name": "on_before_zero_grad", "args": (ANY,)},
+                    {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
+                    {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
+                    {"name": "on_before_backward", "args": (ANY,)},
+                    # DeepSpeed handles backward internally
+                    *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
+                    {"name": "Callback.on_after_backward", "args": (trainer, model)},
+                    {"name": "on_after_backward"},
+                    # note: unscaling happens here in the case of AMP
+                    {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
+                    {"name": "on_before_optimizer_step", "args": (ANY,)},
+                    {
+                        "name": "clip_gradients",
+                        "args": (ANY,),
+                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                    },
+                    {
+                        "name": "configure_gradient_clipping",
+                        "args": (ANY,),
+                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                    },
+                    # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
+                    # the actual call to `Precision.optimizer_step`
+                    {
+                        "name": "optimizer_step",
+                        "args": (current_epoch, i, ANY, ANY),
+                    },
+                    *(
+                        [{"name": "lr_scheduler_step", "args": (ANY, None)}]
+                        if i == (trainer.num_training_batches - 1)
+                        else []
+                    ),
+                    {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
+                    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
+                ]
+            )
         return out

     @staticmethod
     def _manual_train_batch(trainer, model, batches, device, **kwargs):
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
         out = []
         for i in range(batches):
-            out.extend([
-                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
-                {"name": "on_train_batch_start", "args": (ANY, i)},
-                {"name": "forward", "args": (ANY,)},
-                {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
-                {"name": "on_before_backward", "args": (ANY,)},
-                # DeepSpeed handles backward internally
-                *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
-                {"name": "Callback.on_after_backward", "args": (trainer, model)},
-                {"name": "on_after_backward"},
-                # `manual_backward` calls the previous 3
-                {"name": "manual_backward", "args": (ANY,)},
-                {"name": "closure"},
-                {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
-                {"name": "on_before_optimizer_step", "args": (ANY,)},
-                {"name": "training_step", "args": (ANY, i)},
-                {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
-                {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
-            ])
+            out.extend(
+                [
+                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                    {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
+                    {"name": "on_train_batch_start", "args": (ANY, i)},
+                    {"name": "forward", "args": (ANY,)},
+                    {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
+                    {"name": "on_before_backward", "args": (ANY,)},
+                    # DeepSpeed handles backward internally
+                    *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
+                    {"name": "Callback.on_after_backward", "args": (trainer, model)},
+                    {"name": "on_after_backward"},
+                    # `manual_backward` calls the previous 3
+                    {"name": "manual_backward", "args": (ANY,)},
+                    {"name": "closure"},
+                    {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
+                    {"name": "on_before_optimizer_step", "args": (ANY,)},
+                    {"name": "training_step", "args": (ANY, i)},
+                    {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
+                    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
+                ]
+            )
         return out

     @staticmethod
@@ -357,34 +360,38 @@ def _eval_batch(fn, trainer, model, batches, key, device):
         out = []
         outputs = {key: ANY}
         for i in range(batches):
-            out.extend([
-                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                {"name": f"Callback.on_{fn}_batch_start", "args": (trainer, model, ANY, i)},
-                {"name": f"on_{fn}_batch_start", "args": (ANY, i)},
-                {"name": "forward", "args": (ANY,)},
-                {"name": f"{fn}_step", "args": (ANY, i)},
-                {"name": f"Callback.on_{fn}_batch_end", "args": (trainer, model, outputs, ANY, i)},
-                {"name": f"on_{fn}_batch_end", "args": (outputs, ANY, i)},
-            ])
+            out.extend(
+                [
+                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                    {"name": f"Callback.on_{fn}_batch_start", "args": (trainer, model, ANY, i)},
+                    {"name": f"on_{fn}_batch_start", "args": (ANY, i)},
+                    {"name": "forward", "args": (ANY,)},
+                    {"name": f"{fn}_step", "args": (ANY, i)},
+                    {"name": f"Callback.on_{fn}_batch_end", "args": (trainer, model, outputs, ANY, i)},
+                    {"name": f"on_{fn}_batch_end", "args": (outputs, ANY, i)},
+                ]
+            )
         return out

     @staticmethod
     def _predict_batch(trainer, model, batches, device):
         out = []
         for i in range(batches):
-            out.extend([
-                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                {"name": "Callback.on_predict_batch_start", "args": (trainer, model, ANY, i)},
-                {"name": "on_predict_batch_start", "args": (ANY, i)},
-                {"name": "forward", "args": (ANY,)},
-                {"name": "predict_step", "args": (ANY, i)},
-                {"name": "Callback.on_predict_batch_end", "args": (trainer, model, ANY, ANY, i)},
-                {"name": "on_predict_batch_end", "args": (ANY, ANY, i)},
-            ])
+            out.extend(
+                [
+                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                    {"name": "Callback.on_predict_batch_start", "args": (trainer, model, ANY, i)},
+                    {"name": "on_predict_batch_start", "args": (ANY, i)},
+                    {"name": "forward", "args": (ANY,)},
+                    {"name": "predict_step", "args": (ANY, i)},
+                    {"name": "Callback.on_predict_batch_end", "args": (trainer, model, ANY, ANY, i)},
+                    {"name": "on_predict_batch_end", "args": (ANY, ANY, i)},
+                ]
+            )
         return out

     # override so that it gets called
@@ -472,11 +479,11 @@ def training_step(self, batch, batch_idx):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
         # DeepSpeed needs the batch size to figure out throughput logging
         *([{"name": "train_dataloader"}] if using_deepspeed else []),
-        {"name": "configure_model"},
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},
@@ -569,9 +576,9 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
-        {"name": "configure_model"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -647,9 +654,9 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
        {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
-        {"name": "configure_model"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -714,9 +721,9 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": verb}},
         {"name": "setup", "kwargs": {"stage": verb}},
-        {"name": "configure_model"},
         {"name": "zero_grad"},
         *(hooks if batches else []),
         {"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": verb}},
@@ -737,9 +744,9 @@ def test_trainer_model_hook_system_predict(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "predict"}},
         {"name": "setup", "kwargs": {"stage": "predict"}},
-        {"name": "configure_model"},
         {"name": "zero_grad"},
         {"name": "predict_dataloader"},
         {"name": "train", "args": (False,)},
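Note on the reordering in the last five hunks: the moved `configure_model` entries encode that the `configure_model` hook now runs immediately after `prepare_data` and before `Callback.setup`/`setup`, for fit, resumed fit, validate/test, and predict alike. Below is a minimal sketch of how one could observe the new order, assuming `lightning.pytorch` 2.x and its `BoringModel` demo class; the `RecordingModel` helper and the Trainer flags are illustrative assumptions, not part of this patch.

# Hypothetical check, not part of the test suite: record hook order on a
# BoringModel subclass and assert `configure_model` fires before `setup`.
from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class RecordingModel(BoringModel):
    def __init__(self):
        super().__init__()
        self.calls = []

    def prepare_data(self):
        self.calls.append("prepare_data")

    def configure_model(self):
        # With this change, expected to run right after `prepare_data`.
        self.calls.append("configure_model")

    def setup(self, stage):
        self.calls.append("setup")


if __name__ == "__main__":
    model = RecordingModel()
    trainer = Trainer(fast_dev_run=1, logger=False, enable_checkpointing=False)
    trainer.fit(model)
    assert model.calls.index("configure_model") < model.calls.index("setup")
    print(model.calls)  # expected: ['prepare_data', 'configure_model', 'setup']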