You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
@@ -111,7 +111,7 @@ The LightningModule interface is on the right. Each method corresponds to a part
111
111
### training_step
112
112
113
113
```{.python}
114
-
def training_step(self, data_batch, batch_nb)
114
+
def training_step(self, batch, batch_nb)
115
115
```
116
116
117
117
In this step you'd normally do the forward pass and calculate the loss for a batch. You can also do fancier things like multiple forward passes or something specific to your model.
@@ -120,7 +120,7 @@ In this step you'd normally do the forward pass and calculate the loss for a bat
120
120
121
121
| Param | description |
122
122
|---|---|
123
-
|data_batch| The output of your dataloader. A tensor, tuple or list |
123
+
|batch| The output of your dataloader. A tensor, tuple or list |
124
124
| batch_nb | Integer displaying which batch this is |
125
125
126
126
**Return**
@@ -130,22 +130,22 @@ Dictionary or OrderedDict
130
130
| key | value | is required |
131
131
|---|---|---|
132
132
| loss | tensor scalar | Y |
133
-
|prog| Dict for progress bar display. Must have only tensors | N |
133
+
|progress| Dict for progress bar display. Must have only tensors | N |
Called by lightning during training loop. Make sure to use the @pl.data_loader decorator, this ensures not calling this function until the data are needed.
If you don't need to validate you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest such as accuracy.
@@ -256,9 +256,9 @@ The dict you return here will be available in the `validation_end` method.
256
256
257
257
| Param | description |
258
258
|---|---|
259
-
|data_batch| The output of your dataloader. A tensor, tuple or list |
259
+
|batch| The output of your dataloader. A tensor, tuple or list |
260
260
| batch_nb | Integer displaying which batch this is |
261
-
|dataloader_i| Integer displaying which dataloader this is (only if multiple val datasets used) |
261
+
|dataloader_idx| Integer displaying which dataloader this is (only if multiple val datasets used) |
262
262
263
263
**Return**
264
264
@@ -270,8 +270,8 @@ The dict you return here will be available in the `validation_end` method.
270
270
271
271
```{.python}
272
272
# CASE 1: A single validation dataset
273
-
def validation_step(self, data_batch, batch_nb):
274
-
x, y = data_batch
273
+
def validation_step(self, batch, batch_nb):
274
+
x, y = batch
275
275
276
276
# implement your own
277
277
out = self.forward(x)
@@ -302,7 +302,7 @@ If you pass in multiple validation datasets, validation_step will have an additi
If you don't need to test you don't need to implement this method. In this step you'd normally generate examples or calculate anything of interest such as accuracy.
@@ -403,9 +403,9 @@ This function is used when you execute `trainer.test()`.
403
403
404
404
| Param | description |
405
405
|---|---|
406
-
|data_batch| The output of your dataloader. A tensor, tuple or list |
406
+
|batch| The output of your dataloader. A tensor, tuple or list |
407
407
| batch_nb | Integer displaying which batch this is |
408
-
|dataloader_i| Integer displaying which dataloader this is (only if multiple test datasets used) |
408
+
|dataloader_idx| Integer displaying which dataloader this is (only if multiple test datasets used) |
409
409
410
410
**Return**
411
411
@@ -417,8 +417,8 @@ This function is used when you execute `trainer.test()`.
417
417
418
418
```{.python}
419
419
# CASE 1: A single test dataset
420
-
def test_step(self, data_batch, batch_nb):
421
-
x, y = data_batch
420
+
def test_step(self, batch, batch_nb):
421
+
x, y = batch
422
422
423
423
# implement your own
424
424
out = self.forward(x)
@@ -443,7 +443,7 @@ If you pass in multiple test datasets, test_step will have an additional argumen
0 commit comments