@@ -808,7 +808,22 @@ def validation_step(self, batch, batch_idx):

     # CASE 2: multiple validation dataloaders
     def validation_step(self, batch, batch_idx, dataloader_idx=0):
         # dataloader_idx tells you which dataset this is.
-        ...
+        x, y = batch
+
+        # implement your own
+        out = self(x)
+
+        if dataloader_idx == 0:
+            loss = self.loss0(out, y)
+        else:
+            loss = self.loss1(out, y)
+
+        # calculate acc
+        labels_hat = torch.argmax(out, dim=1)
+        acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+
+        # log the outputs separately for each dataloader
+        self.log_dict({f"val_loss_{dataloader_idx}": loss, f"val_acc_{dataloader_idx}": acc})

 Note:
     If you don't need to validate you don't need to implement this method.
@@ -875,7 +890,22 @@ def test_step(self, batch, batch_idx):

     # CASE 2: multiple test dataloaders
     def test_step(self, batch, batch_idx, dataloader_idx=0):
         # dataloader_idx tells you which dataset this is.
-        ...
+        x, y = batch
+
+        # implement your own
+        out = self(x)
+
+        if dataloader_idx == 0:
+            loss = self.loss0(out, y)
+        else:
+            loss = self.loss1(out, y)
+
+        # calculate acc
+        labels_hat = torch.argmax(out, dim=1)
+        acc = torch.sum(y == labels_hat).item() / (len(y) * 1.0)
+
+        # log the outputs separately for each dataloader
+        self.log_dict({f"test_loss_{dataloader_idx}": loss, f"test_acc_{dataloader_idx}": acc})

 Note:
     If you don't need to test you don't need to implement this method.
0 commit comments