Skip to content

Commit cc9dce0

Browse files
committed
Test now passing logits to metric
1 parent 09fa0d0 commit cc9dce0

File tree

1 file changed

+1
-2
lines changed

1 file changed

+1
-2
lines changed

main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,8 +187,7 @@ def main():
187187
loss = criterion(logits, y)
188188
testloss.append(loss.item())
189189

190-
preds = th.argmax(logits, dim=1)
191-
test_metrics(y, preds)
190+
test_metrics(y, logits)
192191

193192
wandb.log(
194193
{"Epoch": 1, "Test loss": np.mean(testloss)}

0 commit comments

Comments
 (0)