Skip to content

Commit 9af43d7

Browse files
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 2eaa4c9 commit 9af43d7

File tree

1 file changed

+40
-40
lines changed

1 file changed

+40
-40
lines changed

src/lightning/pytorch/trainer/trainer.py

Lines changed: 40 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1697,61 +1697,61 @@ def _results(self) -> Optional[_ResultCollection]:
16971697
def estimated_stepping_batches(self) -> Union[int, float]:
16981698
r"""The estimated number of batches that will ``optimizer.step()`` during training.
16991699
1700-
This accounts for gradient accumulation and the current trainer configuration. This might be used when setting
1701-
up your training dataloader, if it hasn't been set up already.
1700+
This accounts for gradient accumulation and the current trainer configuration. This might be used when setting
1701+
up your training dataloader, if it hasn't been set up already.
17021702
1703-
.. code-block:: python
1703+
.. code-block:: python
17041704
1705-
def configure_optimizers(self):
1706-
optimizer = ...
1707-
stepping_batches = self.trainer.estimated_stepping_batches
1708-
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=stepping_batches)
1709-
return [val_check_interval
1710-
.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/val_check_interval.mp4
1711-
:poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/val_check_interval.jpg
1712-
:width: 400
1713-
:muted:
1705+
def configure_optimizers(self):
1706+
optimizer = ...
1707+
stepping_batches = self.trainer.estimated_stepping_batches
1708+
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=1e-3, total_steps=stepping_batches)
1709+
return [val_check_interval
1710+
.. video:: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/val_check_interval.mp4
1711+
:poster: https://pl-bolts-doc-images.s3.us-east-2.amazonaws.com/pl_docs/trainer_flags/thumb/val_check_interval.jpg
1712+
:width: 400
1713+
:muted:
17141714
1715-
How often within one training epoch to check the validation set. Can specify as float or int.
1715+
How often within one training epoch to check the validation set. Can specify as float or int.
17161716
1717-
pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch.
1718-
pass an int to check after a fixed number of training batches. An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or iteration-based training.
1719-
.. testcode::
1717+
pass a float in the range [0.0, 1.0] to check after a fraction of the training epoch.
1718+
pass an int to check after a fixed number of training batches. An int value can only be higher than the number of training batches when check_val_every_n_epoch=None, which validates after every N training batches across epochs or iteration-based training.
1719+
.. testcode::
17201720
1721-
# default used by the Trainer
1722-
trainer = Trainer(val_check_interval=1.0)
1721+
# default used by the Trainer
1722+
trainer = Trainer(val_check_interval=1.0)
17231723
1724-
# check validation set 4 times during a training epoch
1725-
trainer = Trainer(val_check_interval=0.25)
1724+
# check validation set 4 times during a training epoch
1725+
trainer = Trainer(val_check_interval=0.25)
17261726
1727-
# check validation set every 1000 training batches in the current epoch
1728-
trainer = Trainer(val_check_interval=1000)
1727+
# check validation set every 1000 training batches in the current epoch
1728+
trainer = Trainer(val_check_interval=1000)
17291729
1730-
# check validation set every 1000 training batches across complete epochs or during iteration-based training
1731-
# use this when using iterableDataset and your dataset has no length
1732-
# (ie: production cases with streaming data)
1733-
trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)
1730+
# check validation set every 1000 training batches across complete epochs or during iteration-based training
1731+
# use this when using iterableDataset and your dataset has no length
1732+
# (ie: production cases with streaming data)
1733+
trainer = Trainer(val_check_interval=1000, check_val_every_n_epoch=None)
17341734
17351735
1736-
# Here is the computation to estimate the total number of batches seen within an epoch.
1736+
# Here is the computation to estimate the total number of batches seen within an epoch.
17371737
1738-
# Find the total number of train batches
1739-
total_train_batches = total_train_samples // (train_batch_size * world_size)
1738+
# Find the total number of train batches
1739+
total_train_batches = total_train_samples // (train_batch_size * world_size)
17401740
1741-
# Compute how many times we will call validation during the training loop
1742-
val_check_batch = max(1, int(total_train_batches * val_check_interval))
1743-
val_checks_per_epoch = total_train_batches / val_check_batch
1741+
# Compute how many times we will call validation during the training loop
1742+
val_check_batch = max(1, int(total_train_batches * val_check_interval))
1743+
val_checks_per_epoch = total_train_batches / val_check_batch
17441744
1745-
# Find the total number of validation batches
1746-
total_val_batches = total_val_samples // (val_batch_size * world_size)
1745+
# Find the total number of validation batches
1746+
total_val_batches = total_val_samples // (val_batch_size * world_size)
17471747
1748-
# Total number of batches run
1749-
total_fit_batches = total_train_batches + total_val_batchesizer], [scheduler]
1748+
# Total number of batches run
1749+
total_fit_batches = total_train_batches + total_val_batchesizer], [scheduler]
17501750
1751-
Raises:
1752-
MisconfigurationException:
1753-
If estimated stepping batches cannot be computed due to different `accumulate_grad_batches`
1754-
at different epochs.
1751+
Raises:
1752+
MisconfigurationException:
1753+
If estimated stepping batches cannot be computed due to different `accumulate_grad_batches`
1754+
at different epochs.
17551755
17561756
"""
17571757
# infinite training

0 commit comments

Comments
 (0)