Skip to content
This repository was archived by the owner on Jan 20, 2026. It is now read-only.

Commit 40b6f78

Browse files
Made val & test loss like train loss (#664)
Co-authored-by: O'Donnell, Garry (DLSLtd,RAL,LSCI) <garry.o'donnell@diamond.ac.uk> Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
1 parent aad3d24 commit 40b6f78

File tree

1 file changed: +4 additions, -4 deletions

1 file changed: +4 additions, -4 deletions

pl_bolts/models/regression/logistic_regression.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,8 @@ def training_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[st
7878
def validation_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
7979
x, y = batch
8080
x = x.view(x.size(0), -1)
81-
y_hat = self(x)
82-
acc = accuracy(y_hat, y)
81+
y_hat = self.linear(x)
82+
acc = accuracy(F.softmax(y_hat, -1), y)
8383
return {'val_loss': F.cross_entropy(y_hat, y), 'acc': acc}
8484

8585
def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:
@@ -91,8 +91,8 @@ def validation_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Te
9191

9292
def test_step(self, batch: Tuple[Tensor, Tensor], batch_idx: int) -> Dict[str, Tensor]:
9393
x = x.view(x.size(0), -1)
94-
y_hat = self(x)
95-
acc = accuracy(y_hat, y)
94+
y_hat = self.linear(x)
95+
acc = accuracy(F.softmax(y_hat, -1), y)
9696
return {'test_loss': F.cross_entropy(y_hat, y), 'acc': acc}
9797

9898
def test_epoch_end(self, outputs: List[Dict[str, Tensor]]) -> Dict[str, Tensor]:

0 commit comments

Comments (0)