from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__
from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
from lightning.pytorch.utilities.model_helpers import is_overridden
-from tests_pytorch.helpers.runif import RunIf
from torch import Tensor
from torch.utils.data import DataLoader

+from tests_pytorch.helpers.runif import RunIf
+


class HookedDataModule(BoringDataModule):
    def __init__(self, called):
@@ -266,83 +267,79 @@ def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_
        using_deepspeed = kwargs.get("strategy") == "deepspeed"
        out = []
        for i in range(current_batch, batches):
-            out.extend(
-                [
-                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                    {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
-                    {"name": "on_train_batch_start", "args": (ANY, i)},
-                    {"name": "forward", "args": (ANY,)},
-                    {"name": "training_step", "args": (ANY, i)},
-                    {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
-                    {"name": "on_before_zero_grad", "args": (ANY,)},
-                    {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
-                    {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
-                    {"name": "on_before_backward", "args": (ANY,)},
-                    # DeepSpeed handles backward internally
-                    *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
-                    {"name": "Callback.on_after_backward", "args": (trainer, model)},
-                    {"name": "on_after_backward"},
-                    # note: unscaling happens here in the case of AMP
-                    {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
-                    {"name": "on_before_optimizer_step", "args": (ANY,)},
-                    {
-                        "name": "clip_gradients",
-                        "args": (ANY,),
-                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                    },
-                    {
-                        "name": "configure_gradient_clipping",
-                        "args": (ANY,),
-                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                    },
-                    # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
-                    # the actual call to `Precision.optimizer_step`
-                    {
-                        "name": "optimizer_step",
-                        "args": (current_epoch, i, ANY, ANY),
-                    },
-                    *(
-                        [{"name": "lr_scheduler_step", "args": (ANY, None)}]
-                        if i == (trainer.num_training_batches - 1)
-                        else []
-                    ),
-                    {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
-                    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
-                ]
-            )
+            out.extend([
+                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
+                {"name": "on_train_batch_start", "args": (ANY, i)},
+                {"name": "forward", "args": (ANY,)},
+                {"name": "training_step", "args": (ANY, i)},
+                {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
+                {"name": "on_before_zero_grad", "args": (ANY,)},
+                {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
+                {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
+                {"name": "on_before_backward", "args": (ANY,)},
+                # DeepSpeed handles backward internally
+                *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
+                {"name": "Callback.on_after_backward", "args": (trainer, model)},
+                {"name": "on_after_backward"},
+                # note: unscaling happens here in the case of AMP
+                {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
+                {"name": "on_before_optimizer_step", "args": (ANY,)},
+                {
+                    "name": "clip_gradients",
+                    "args": (ANY,),
+                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                },
+                {
+                    "name": "configure_gradient_clipping",
+                    "args": (ANY,),
+                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                },
+                # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
+                # the actual call to `Precision.optimizer_step`
+                {
+                    "name": "optimizer_step",
+                    "args": (current_epoch, i, ANY, ANY),
+                },
+                *(
+                    [{"name": "lr_scheduler_step", "args": (ANY, None)}]
+                    if i == (trainer.num_training_batches - 1)
+                    else []
+                ),
+                {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
+                {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
+            ])
        return out

    @staticmethod
    def _manual_train_batch(trainer, model, batches, device, **kwargs):
        using_deepspeed = kwargs.get("strategy") == "deepspeed"
        out = []
        for i in range(batches):
-            out.extend(
-                [
-                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                    {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
-                    {"name": "on_train_batch_start", "args": (ANY, i)},
-                    {"name": "forward", "args": (ANY,)},
-                    {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
-                    {"name": "on_before_backward", "args": (ANY,)},
-                    # DeepSpeed handles backward internally
-                    *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
-                    {"name": "Callback.on_after_backward", "args": (trainer, model)},
-                    {"name": "on_after_backward"},
-                    # `manual_backward` calls the previous 3
-                    {"name": "manual_backward", "args": (ANY,)},
-                    {"name": "closure"},
-                    {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
-                    {"name": "on_before_optimizer_step", "args": (ANY,)},
-                    {"name": "training_step", "args": (ANY, i)},
-                    {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
-                    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
-                ]
-            )
+            out.extend([
+                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
+                {"name": "on_train_batch_start", "args": (ANY, i)},
+                {"name": "forward", "args": (ANY,)},
+                {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
+                {"name": "on_before_backward", "args": (ANY,)},
+                # DeepSpeed handles backward internally
+                *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
+                {"name": "Callback.on_after_backward", "args": (trainer, model)},
+                {"name": "on_after_backward"},
+                # `manual_backward` calls the previous 3
+                {"name": "manual_backward", "args": (ANY,)},
+                {"name": "closure"},
+                {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
+                {"name": "on_before_optimizer_step", "args": (ANY,)},
+                {"name": "training_step", "args": (ANY, i)},
+                {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
+                {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
+            ])
        return out

    @staticmethod
@@ -360,38 +357,34 @@ def _eval_batch(fn, trainer, model, batches, key, device):
        out = []
        outputs = {key: ANY}
        for i in range(batches):
-            out.extend(
-                [
-                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                    {"name": f"Callback.on_{fn}_batch_start", "args": (trainer, model, ANY, i)},
-                    {"name": f"on_{fn}_batch_start", "args": (ANY, i)},
-                    {"name": "forward", "args": (ANY,)},
-                    {"name": f"{fn}_step", "args": (ANY, i)},
-                    {"name": f"Callback.on_{fn}_batch_end", "args": (trainer, model, outputs, ANY, i)},
-                    {"name": f"on_{fn}_batch_end", "args": (outputs, ANY, i)},
-                ]
-            )
+            out.extend([
+                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                {"name": f"Callback.on_{fn}_batch_start", "args": (trainer, model, ANY, i)},
+                {"name": f"on_{fn}_batch_start", "args": (ANY, i)},
+                {"name": "forward", "args": (ANY,)},
+                {"name": f"{fn}_step", "args": (ANY, i)},
+                {"name": f"Callback.on_{fn}_batch_end", "args": (trainer, model, outputs, ANY, i)},
+                {"name": f"on_{fn}_batch_end", "args": (outputs, ANY, i)},
+            ])
        return out

    @staticmethod
    def _predict_batch(trainer, model, batches, device):
        out = []
        for i in range(batches):
-            out.extend(
-                [
-                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                    {"name": "Callback.on_predict_batch_start", "args": (trainer, model, ANY, i)},
-                    {"name": "on_predict_batch_start", "args": (ANY, i)},
-                    {"name": "forward", "args": (ANY,)},
-                    {"name": "predict_step", "args": (ANY, i)},
-                    {"name": "Callback.on_predict_batch_end", "args": (trainer, model, ANY, ANY, i)},
-                    {"name": "on_predict_batch_end", "args": (ANY, ANY, i)},
-                ]
-            )
+            out.extend([
+                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                {"name": "Callback.on_predict_batch_start", "args": (trainer, model, ANY, i)},
+                {"name": "on_predict_batch_start", "args": (ANY, i)},
+                {"name": "forward", "args": (ANY,)},
+                {"name": "predict_step", "args": (ANY, i)},
+                {"name": "Callback.on_predict_batch_end", "args": (trainer, model, ANY, ANY, i)},
+                {"name": "on_predict_batch_end", "args": (ANY, ANY, i)},
+            ])
        return out

    # override so that it gets called
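For context beyond the diff itself: expected-hook lists like the ones built by `_auto_train_batch` above are typically compared against a sequence of hook calls recorded while a small Trainer run executes. The sketch below illustrates that pattern under stated assumptions; the `HookedModel` recorder, the Trainer arguments, and the single recorded hook are hypothetical and for illustration only, not the test suite's actual implementation.

# minimal sketch, assuming a recorder model; not part of the diff above
from unittest.mock import ANY

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class HookedModel(BoringModel):  # hypothetical: appends every observed hook call to `called`
    def __init__(self, called):
        super().__init__()
        self.called = called

    def on_train_batch_start(self, batch, batch_idx):
        # record the hook name and its positional arguments, mirroring the dicts in the expected lists
        self.called.append({"name": "on_train_batch_start", "args": (batch, batch_idx)})


called = []
model = HookedModel(called)
trainer = Trainer(
    max_epochs=1,
    limit_train_batches=2,
    limit_val_batches=0,
    logger=False,
    enable_checkpointing=False,
    enable_progress_bar=False,
)
trainer.fit(model)

# `ANY` in the expected entry compares equal to the actual batch tensor,
# so only the hook name and batch index are effectively checked
expected = {"name": "on_train_batch_start", "args": (ANY, 0)}
assert any(expected == c for c in called)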