Skip to content

Commit e766a0a

Browse files
committed
reorder
1 parent aac4824 commit e766a0a

File tree

1 file changed

+33
-32
lines changed

1 file changed

+33
-32
lines changed

docs/source-pytorch/common/trainer.rst

Lines changed: 33 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -563,6 +563,39 @@ Gradient clipping value
563563
# default used by the Trainer
564564
trainer = Trainer(gradient_clip_val=None)
565565

566+
567+
inference_mode
568+
^^^^^^^^^^^^^^
569+
570+
Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` mode during evaluation
571+
(``validate``/``test``/``predict``)
572+
573+
.. testcode::
574+
575+
# default used by the Trainer
576+
trainer = Trainer(inference_mode=True)
577+
578+
# Use `torch.no_grad` instead
579+
trainer = Trainer(inference_mode=False)
580+
581+
582+
With :func:`torch.inference_mode` disabled, you can enable the grad of your model layers if required.
583+
584+
.. code-block:: python
585+
586+
class LitModel(LightningModule):
587+
def validation_step(self, batch, batch_idx):
588+
preds = self.layer1(batch)
589+
with torch.enable_grad():
590+
grad_preds = preds.requires_grad_()
591+
preds2 = self.layer2(grad_preds)
592+
593+
594+
model = LitModel()
595+
trainer = Trainer(inference_mode=False)
596+
trainer.validate(model)
597+
598+
566599
limit_train_batches
567600
^^^^^^^^^^^^^^^^^^^
568601

@@ -1110,38 +1143,6 @@ Can specify as float, int, or a time-based duration.
11101143
# Total number of batches run
11111144
total_fit_batches = total_train_batches + total_val_batches
11121145
1113-
1114-
inference_mode
1115-
^^^^^^^^^^^^^^
1116-
1117-
Whether to use :func:`torch.inference_mode` or :func:`torch.no_grad` mode during evaluation
1118-
(``validate``/``test``/``predict``)
1119-
1120-
.. testcode::
1121-
1122-
# default used by the Trainer
1123-
trainer = Trainer(inference_mode=True)
1124-
1125-
# Use `torch.no_grad` instead
1126-
trainer = Trainer(inference_mode=False)
1127-
1128-
1129-
With :func:`torch.inference_mode` disabled, you can enable the grad of your model layers if required.
1130-
1131-
.. code-block:: python
1132-
1133-
class LitModel(LightningModule):
1134-
def validation_step(self, batch, batch_idx):
1135-
preds = self.layer1(batch)
1136-
with torch.enable_grad():
1137-
grad_preds = preds.requires_grad_()
1138-
preds2 = self.layer2(grad_preds)
1139-
1140-
1141-
model = LitModel()
1142-
trainer = Trainer(inference_mode=False)
1143-
trainer.validate(model)
1144-
11451146
-----
11461147

11471148
Trainer class API

0 commit comments

Comments
 (0)