@@ -808,7 +808,22 @@ def validation_step(self, batch, batch_idx):
808808 # CASE 2: multiple validation dataloaders
809809 def validation_step(self, batch, batch_idx, dataloader_idx=0):
810810 # dataloader_idx tells you which dataset this is.
811- ...
811+ x, y = batch
812+
813+        # implement your own forward pass / loss logic
814+ out = self(x)
815+
816+ if dataloader_idx == 0:
817+ loss = self.loss0(out, y)
818+ else:
819+ loss = self.loss1(out, y)
820+
821+ # calculate acc
822+ labels_hat = torch.argmax(out, dim=1)
823+ acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
824+
825+ # log the outputs separately for each dataloader
826+ self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})
812827
813828 Note:
814829 If you don't need to validate you don't need to implement this method.
@@ -875,7 +890,22 @@ def test_step(self, batch, batch_idx):
875890 # CASE 2: multiple test dataloaders
876891 def test_step(self, batch, batch_idx, dataloader_idx=0):
877892 # dataloader_idx tells you which dataset this is.
878- ...
893+ x, y = batch
894+
895+        # implement your own forward pass / loss logic
896+ out = self(x)
897+
898+ if dataloader_idx == 0:
899+ loss = self.loss0(out, y)
900+ else:
901+ loss = self.loss1(out, y)
902+
903+ # calculate acc
904+ labels_hat = torch.argmax(out, dim=1)
905+ acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
906+
907+ # log the outputs separately for each dataloader
908+ self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})
879909
880910 Note:
881911 If you don't need to test you don't need to implement this method.
0 commit comments