Skip to content

Commit 9c48699

Browse files
Docs for logging multiple val and test dataloaders (#21054)
* docs for logging multiple val and test dataloaders * Apply suggestions from code review --------- Co-authored-by: Bhimraj Yadav <[email protected]>
1 parent a777069 commit 9c48699

File tree

1 file changed

+32
-2
lines changed

1 file changed

+32
-2
lines changed

src/lightning/pytorch/core/module.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -808,7 +808,22 @@ def validation_step(self, batch, batch_idx):
808808
# CASE 2: multiple validation dataloaders
809809
def validation_step(self, batch, batch_idx, dataloader_idx=0):
810810
# dataloader_idx tells you which dataset this is.
811-
...
811+
x, y = batch
812+
813+
# implement your own
814+
out = self(x)
815+
816+
if dataloader_idx == 0:
817+
loss = self.loss0(out, y)
818+
else:
819+
loss = self.loss1(out, y)
820+
821+
# calculate acc
822+
labels_hat = torch.argmax(out, dim=1)
823+
acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
824+
825+
# log the outputs separately for each dataloader
826+
self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
812827
813828
Note:
814829
If you don't need to validate you don't need to implement this method.
@@ -875,7 +890,22 @@ def test_step(self, batch, batch_idx):
875890
# CASE 2: multiple test dataloaders
876891
def test_step(self, batch, batch_idx, dataloader_idx=0):
877892
# dataloader_idx tells you which dataset this is.
878-
...
893+
x, y = batch
894+
895+
# implement your own
896+
out = self(x)
897+
898+
if dataloader_idx == 0:
899+
loss = self.loss0(out, y)
900+
else:
901+
loss = self.loss1(out, y)
902+
903+
# calculate acc
904+
labels_hat = torch.argmax(out, dim=1)
905+
acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
906+
907+
# log the outputs separately for each dataloader
908+
self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})
879909
880910
Note:
881911
If you don't need to test you don't need to implement this method.

0 commit comments

Comments
 (0)