@@ -563,6 +563,39 @@ Gradient clipping value
563
563
# default used by the Trainer
564
564
trainer = Trainer(gradient_clip_val=None)
565
565
566
+
567
+ inference_mode
568
+ ^^^^^^^^^^^^^^
569
+
570
+ Whether to use :func: `torch.inference_mode ` or :func: `torch.no_grad ` mode during evaluation
571
+ (``validate ``/``test ``/``predict ``)
572
+
573
+ .. testcode ::
574
+
575
+ # default used by the Trainer
576
+ trainer = Trainer(inference_mode=True)
577
+
578
+ # Use `torch.no_grad ` instead
579
+ trainer = Trainer(inference_mode=False)
580
+
581
+
582
+ With :func: `torch.inference_mode ` disabled, you can enable the grad of your model layers if required.
583
+
584
+ .. code-block :: python
585
+
586
+ class LitModel (LightningModule ):
587
+ def validation_step (self , batch , batch_idx ):
588
+ preds = self .layer1(batch)
589
+ with torch.enable_grad():
590
+ grad_preds = preds.requires_grad_()
591
+ preds2 = self .layer2(grad_preds)
592
+
593
+
594
+ model = LitModel()
595
+ trainer = Trainer(inference_mode = False )
596
+ trainer.validate(model)
597
+
598
+
566
599
limit_train_batches
567
600
^^^^^^^^^^^^^^^^^^^
568
601
@@ -1110,38 +1143,6 @@ Can specify as float, int, or a time-based duration.
1110
1143
# Total number of batches run
1111
1144
total_fit_batches = total_train_batches + total_val_batches
1112
1145
1113
-
1114
- inference_mode
1115
- ^^^^^^^^^^^^^^
1116
-
1117
- Whether to use :func: `torch.inference_mode ` or :func: `torch.no_grad ` mode during evaluation
1118
- (``validate ``/``test ``/``predict ``)
1119
-
1120
- .. testcode ::
1121
-
1122
- # default used by the Trainer
1123
- trainer = Trainer(inference_mode=True)
1124
-
1125
- # Use `torch.no_grad ` instead
1126
- trainer = Trainer(inference_mode=False)
1127
-
1128
-
1129
- With :func: `torch.inference_mode ` disabled, you can enable the grad of your model layers if required.
1130
-
1131
- .. code-block :: python
1132
-
1133
- class LitModel (LightningModule ):
1134
- def validation_step (self , batch , batch_idx ):
1135
- preds = self .layer1(batch)
1136
- with torch.enable_grad():
1137
- grad_preds = preds.requires_grad_()
1138
- preds2 = self .layer2(grad_preds)
1139
-
1140
-
1141
- model = LitModel()
1142
- trainer = Trainer(inference_mode = False )
1143
- trainer.validate(model)
1144
-
1145
1146
-----
1146
1147
1147
1148
Trainer class API
0 commit comments