Commit 965fc03

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent f10b897 commit 965fc03

File tree

1 file changed: +40 −42 lines


tests/tests_pytorch/models/test_hooks.py

Lines changed: 40 additions & 42 deletions
@@ -266,48 +266,46 @@ def _auto_train_batch(trainer, model, batches, device, current_epoch=0, current_
         using_deepspeed = kwargs.get("strategy") == "deepspeed"
         out = []
         for i in range(current_batch, batches):
-            out.extend(
-                [
-                    {"name": "on_before_batch_transfer", "args": (ANY, 0)},
-                    {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
-                    {"name": "on_after_batch_transfer", "args": (ANY, 0)},
-                    {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
-                    {"name": "on_train_batch_start", "args": (ANY, i)},
-                    {"name": "forward", "args": (ANY,)},
-                    {"name": "training_step", "args": (ANY, i)},
-                    {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
-                    {"name": "on_before_zero_grad", "args": (ANY,)},
-                    {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
-                    {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
-                    {"name": "on_before_backward", "args": (ANY,)},
-                    # DeepSpeed handles backward internally
-                    *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
-                    {"name": "Callback.on_after_backward", "args": (trainer, model)},
-                    {"name": "on_after_backward"},
-                    # note: unscaling happens here in the case of AMP
-                    {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
-                    {"name": "on_before_optimizer_step", "args": (ANY,)},
-                    {
-                        "name": "clip_gradients",
-                        "args": (ANY,),
-                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                    },
-                    {
-                        "name": "configure_gradient_clipping",
-                        "args": (ANY,),
-                        "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
-                    },
-                    # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
-                    # the actual call to `Precision.optimizer_step`
-                    {
-                        "name": "optimizer_step",
-                        "args": (current_epoch, i, ANY, ANY),
-                    },
-                    *([{"name": "lr_scheduler_step", "args": ANY}] if i == (trainer.num_training_batches - 1) else []),
-                    {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
-                    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
-                ]
-            )
+            out.extend([
+                {"name": "on_before_batch_transfer", "args": (ANY, 0)},
+                {"name": "transfer_batch_to_device", "args": (ANY, device, 0)},
+                {"name": "on_after_batch_transfer", "args": (ANY, 0)},
+                {"name": "Callback.on_train_batch_start", "args": (trainer, model, ANY, i)},
+                {"name": "on_train_batch_start", "args": (ANY, i)},
+                {"name": "forward", "args": (ANY,)},
+                {"name": "training_step", "args": (ANY, i)},
+                {"name": "Callback.on_before_zero_grad", "args": (trainer, model, ANY)},
+                {"name": "on_before_zero_grad", "args": (ANY,)},
+                {"name": "optimizer_zero_grad", "args": (current_epoch, i, ANY)},
+                {"name": "Callback.on_before_backward", "args": (trainer, model, ANY)},
+                {"name": "on_before_backward", "args": (ANY,)},
+                # DeepSpeed handles backward internally
+                *([{"name": "backward", "args": (ANY,)}] if not using_deepspeed else []),
+                {"name": "Callback.on_after_backward", "args": (trainer, model)},
+                {"name": "on_after_backward"},
+                # note: unscaling happens here in the case of AMP
+                {"name": "Callback.on_before_optimizer_step", "args": (trainer, model, ANY)},
+                {"name": "on_before_optimizer_step", "args": (ANY,)},
+                {
+                    "name": "clip_gradients",
+                    "args": (ANY,),
+                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                },
+                {
+                    "name": "configure_gradient_clipping",
+                    "args": (ANY,),
+                    "kwargs": {"gradient_clip_val": None, "gradient_clip_algorithm": None},
+                },
+                # this is after because it refers to the `LightningModule.optimizer_step` hook which encapsulates
+                # the actual call to `Precision.optimizer_step`
+                {
+                    "name": "optimizer_step",
+                    "args": (current_epoch, i, ANY, ANY),
+                },
+                *([{"name": "lr_scheduler_step", "args": ANY}] if i == (trainer.num_training_batches - 1) else []),
+                {"name": "Callback.on_train_batch_end", "args": (trainer, model, {"loss": ANY}, ANY, i)},
+                {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, i)},
+            ])
         return out

     @staticmethod

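Note: the expectation entries in this diff pair each hook name with the arguments it should receive; ANY comes from unittest.mock and compares equal to anything, so batch tensors and optimizer objects match without exact inspection. A minimal sketch of how such an expectation list can be checked against recorded hook calls follows; the recorded values and the way they are collected are illustrative, not taken from this test file.

from unittest.mock import ANY

# Illustrative "recorded" calls, e.g. appended by a test LightningModule each
# time a hook fires (the real test builds this list with its own machinery).
recorded = [
    {"name": "on_train_batch_start", "args": ("a batch tensor", 0)},
    {"name": "training_step", "args": ("a batch tensor", 0)},
    {"name": "on_train_batch_end", "args": ({"loss": 0.5}, "a batch tensor", 0)},
]

# Expectations in the same style as `_auto_train_batch`: ANY acts as a
# wildcard, so the comparison only pins down hook order and the
# non-wildcard arguments.
expected = [
    {"name": "on_train_batch_start", "args": (ANY, 0)},
    {"name": "training_step", "args": (ANY, 0)},
    {"name": "on_train_batch_end", "args": ({"loss": ANY}, ANY, 0)},
]

assert recorded == expected  # order and argument shapes must both line up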