
Commit 8ba644a

Alan Chu authored and committed
change assertion order for setup() and configure_model() in test_hooks.py
1 parent bfa0fd4 commit 8ba644a
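
For context: the tests in this file verify hook ordering by recording every hook call into a shared `called` list (see the `HookedDataModule` helper in the diff below) and comparing it against an `expected` sequence of {"name": ..., "args": ..., "kwargs": ...} entries. A minimal sketch of that pattern, using simplified, hypothetical names and assuming a Lightning version with the ordering this commit encodes:

from lightning.pytorch import Trainer
from lightning.pytorch.demos.boring_classes import BoringModel


class RecordingModel(BoringModel):
    # hypothetical, stripped-down version of this file's recording helpers
    def __init__(self, called):
        super().__init__()
        self.called = called

    def configure_model(self):
        self.called.append({"name": "configure_model"})

    def setup(self, stage):
        self.called.append({"name": "setup", "kwargs": {"stage": stage}})


called = []
Trainer(fast_dev_run=1, logger=False, enable_checkpointing=False).fit(RecordingModel(called))
# The order this commit asserts: configure_model is recorded before setup.
assert [c["name"] for c in called] == ["configure_model", "setup"]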

1 file changed

tests/tests_pytorch/models/test_hooks.py

Lines changed: 102 additions & 95 deletions
@@ -21,11 +21,10 @@
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__
 from lightning.pytorch.demos.boring_classes import BoringDataModule, BoringModel, RandomDataset
 from lightning.pytorch.utilities.model_helpers import is_overridden
+from tests_pytorch.helpers.runif import RunIf
 from torch import Tensor
 from torch.utils.data import DataLoader

-from tests_pytorch.helpers.runif import RunIf
-

 class HookedDataModule(BoringDataModule):
     def __init__(self, called):
@@ -267,79 +266,83 @@ def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
         out = []
         for i in range(current_batch, batches):
-            out.extend([
-                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
-                {"name": "on_train_batch_start", "args": (ANY, i)},
-                {"name": "forward", "args": (ANY,)},
-                {"name": "training_step", "args": (ANY, i)},
-                {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
-                {"name": "on_before_zero_grad", "args": (ANY,)},
-                {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
-                {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
-                {"name": "on_before_backward", "args": (ANY,)},
-                # DeepSpeed handles backward internally
-                *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
-                {"name": "Callback.on_after_backward", "args": (trainer, model)},
-                {"name": "on_after_backward"},
-                # note: unscaling happens here in the case of AMP
-                {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
-                {"name": "on_before_optimizer_step", "args": (ANY,)},
-                {
-                    "name": "clip_gradients",
-                    "args": (ANY,),
-                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                },
-                {
-                    "name": "configure_gradient_clipping",
-                    "args": (ANY,),
-                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                },
-                # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
-                # the actual call to `Precision.optimizer_step`
-                {
-                    "name": "optimizer_step",
-                    "args": (current_epoch, i, ANY, ANY),
-                },
-                *(
-                    [{"name": "lr_scheduler_step", "args": (ANY, None)}]
-                    if i == (trainer.num_training_batches - 1)
-                    else []
-                ),
-                {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
-                {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
-            ])
+            out.extend(
+                [
+                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                    {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
+                    {"name": "on_train_batch_start", "args": (ANY, i)},
+                    {"name": "forward", "args": (ANY,)},
+                    {"name": "training_step", "args": (ANY, i)},
+                    {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
+                    {"name": "on_before_zero_grad", "args": (ANY,)},
+                    {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
+                    {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
+                    {"name": "on_before_backward", "args": (ANY,)},
+                    # DeepSpeed handles backward internally
+                    *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
+                    {"name": "Callback.on_after_backward", "args": (trainer, model)},
+                    {"name": "on_after_backward"},
+                    # note: unscaling happens here in the case of AMP
+                    {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
+                    {"name": "on_before_optimizer_step", "args": (ANY,)},
+                    {
+                        "name": "clip_gradients",
+                        "args": (ANY,),
+                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                    },
+                    {
+                        "name": "configure_gradient_clipping",
+                        "args": (ANY,),
+                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                    },
+                    # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
+                    # the actual call to `Precision.optimizer_step`
+                    {
+                        "name": "optimizer_step",
+                        "args": (current_epoch, i, ANY, ANY),
+                    },
+                    *(
+                        [{"name": "lr_scheduler_step", "args": (ANY, None)}]
+                        if i == (trainer.num_training_batches - 1)
+                        else []
+                    ),
+                    {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
+                    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
+                ]
+            )
         return out

     @staticmethod
     def _manual_train_batch(trainer, model, batches, device, **kwargs):
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
         out = []
         for i in range(batches):
-            out.extend([
-                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
-                {"name": "on_train_batch_start", "args": (ANY, i)},
-                {"name": "forward", "args": (ANY,)},
-                {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
-                {"name": "on_before_backward", "args": (ANY,)},
-                # DeepSpeed handles backward internally
-                *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
-                {"name": "Callback.on_after_backward", "args": (trainer, model)},
-                {"name": "on_after_backward"},
-                # `manual_backward` calls the previous 3
-                {"name": "manual_backward", "args": (ANY,)},
-                {"name": "closure"},
-                {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
-                {"name": "on_before_optimizer_step", "args": (ANY,)},
-                {"name": "training_step", "args": (ANY, i)},
-                {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
-                {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
-            ])
+            out.extend(
+                [
+                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                    {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
+                    {"name": "on_train_batch_start", "args": (ANY, i)},
+                    {"name": "forward", "args": (ANY,)},
+                    {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
+                    {"name": "on_before_backward", "args": (ANY,)},
+                    # DeepSpeed handles backward internally
+                    *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
+                    {"name": "Callback.on_after_backward", "args": (trainer, model)},
+                    {"name": "on_after_backward"},
+                    # `manual_backward` calls the previous 3
+                    {"name": "manual_backward", "args": (ANY,)},
+                    {"name": "closure"},
+                    {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
+                    {"name": "on_before_optimizer_step", "args": (ANY,)},
+                    {"name": "training_step", "args": (ANY, i)},
+                    {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
+                    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
+                ]
+            )
         return out

     @staticmethod
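
Aside from the cosmetic out.extend(...) reflow, note the comment above: `manual_backward` triggers the three preceding hooks. A minimal manual-optimization training_step that would produce the expected sequence (an illustrative sketch with a hypothetical model, not code from this diff):

from lightning.pytorch.demos.boring_classes import BoringModel


class ManualOptModel(BoringModel):
    # hypothetical sketch; the test's real model is defined in this file
    def __init__(self):
        super().__init__()
        self.automatic_optimization = False

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        loss = self(batch).sum()
        opt.zero_grad()
        # runs on_before_backward, backward, on_after_backward --
        # the "previous 3" referenced by the comment in _manual_train_batch
        self.manual_backward(loss)
        opt.step()
        return {"loss": loss}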
@@ -357,34 +360,38 @@ def _eval_batch(fn, trainer, model, batches, key, device)
         out = []
         outputs = {key: ANY}
         for i in range(batches):
-            out.extend([
-                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                {"name": f"Callback.on_{fn}_batch_start", "args": (trainer, model, ANY, i)},
-                {"name": f"on_{fn}_batch_start", "args": (ANY, i)},
-                {"name": "forward", "args": (ANY,)},
-                {"name": f"{fn}_step", "args": (ANY, i)},
-                {"name": f"Callback.on_{fn}_batch_end", "args": (trainer, model, outputs, ANY, i)},
-                {"name": f"on_{fn}_batch_end", "args": (outputs, ANY, i)},
-            ])
+            out.extend(
+                [
+                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                    {"name": f"Callback.on_{fn}_batch_start", "args": (trainer, model, ANY, i)},
+                    {"name": f"on_{fn}_batch_start", "args": (ANY, i)},
+                    {"name": "forward", "args": (ANY,)},
+                    {"name": f"{fn}_step", "args": (ANY, i)},
+                    {"name": f"Callback.on_{fn}_batch_end", "args": (trainer, model, outputs, ANY, i)},
+                    {"name": f"on_{fn}_batch_end", "args": (outputs, ANY, i)},
+                ]
+            )
         return out

     @staticmethod
     def _predict_batch(trainer, model, batches, device):
         out = []
         for i in range(batches):
-            out.extend([
-                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                {"name": "Callback.on_predict_batch_start", "args": (trainer, model, ANY, i)},
-                {"name": "on_predict_batch_start", "args": (ANY, i)},
-                {"name": "forward", "args": (ANY,)},
-                {"name": "predict_step", "args": (ANY, i)},
-                {"name": "Callback.on_predict_batch_end", "args": (trainer, model, ANY, ANY, i)},
-                {"name": "on_predict_batch_end", "args": (ANY, ANY, i)},
-            ])
+            out.extend(
+                [
+                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                    {"name": "Callback.on_predict_batch_start", "args": (trainer, model, ANY, i)},
+                    {"name": "on_predict_batch_start", "args": (ANY, i)},
+                    {"name": "forward", "args": (ANY,)},
+                    {"name": "predict_step", "args": (ANY, i)},
+                    {"name": "Callback.on_predict_batch_end", "args": (trainer, model, ANY, ANY, i)},
+                    {"name": "on_predict_batch_end", "args": (ANY, ANY, i)},
+                ]
+            )
         return out

     # override so that it gets called
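
The f-string entries in `_eval_batch` expand to concrete hook names per stage; for example, with fn set to "validation" (a sketch of the expansion, not code from the test):

fn = "validation"
assert f"Callback.on_{fn}_batch_start" == "Callback.on_validation_batch_start"
assert f"{fn}_step" == "validation_step"
assert f"on_{fn}_batch_end" == "on_validation_batch_end"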
@@ -472,11 +479,11 @@ def training_step(self, batch, batch_idx):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
         # DeepSpeed needs the batch size to figure out throughput logging
         *([{"name": "train_dataloader"}] if using_deepspeed else []),
-        {"name": "configure_model"},
         {"name": "configure_optimizers"},
         {"name": "Callback.on_fit_start", "args": (trainer, model)},
         {"name": "on_fit_start"},
@@ -569,9 +576,9 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
-        {"name": "configure_model"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -647,9 +654,9 @@ def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "fit"}},
         {"name": "setup", "kwargs": {"stage": "fit"}},
-        {"name": "configure_model"},
         {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
         {"name": "Callback.on_load_checkpoint", "args": (trainer, model, loaded_ckpt)},
         {"name": "Callback.load_state_dict", "args": ({"foo": True},)},
@@ -714,9 +721,9 @@ def test_trainer_model_hook_system_eval(tmp_path, override_on_x_model_train, bat
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": verb}},
         {"name": "setup", "kwargs": {"stage": verb}},
-        {"name": "configure_model"},
         {"name": "zero_grad"},
         *(hooks if batches else []),
         {"name": "Callback.teardown", "args": (trainer, model), "kwargs": {"stage": verb}},
@@ -737,9 +744,9 @@ def test_trainer_model_hook_system_predict(tmp_path):
     expected = [
         {"name": "configure_callbacks"},
         {"name": "prepare_data"},
+        {"name": "configure_model"},
         {"name": "Callback.setup", "args": (trainer, model), "kwargs": {"stage": "predict"}},
         {"name": "setup", "kwargs": {"stage": "predict"}},
-        {"name": "configure_model"},
         {"name": "zero_grad"},
         {"name": "predict_dataloader"},
         {"name": "train", "args": (False,)},
