Skip to content

Commit 6f65582

Browse files
committed
fix evaluation script for mnist example task
1 parent 810d515 commit 6f65582

File tree

1 file changed

+15
-7
lines changed

1 file changed

+15
-7
lines changed

tasks/selfclass_mnist/eval_selfclass_mnist.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,24 @@
1+
from pathlib import Path
12
import os
23
import sys
34

45
root_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "../.."))
56
sys.path.append(root_dir)
67

7-
from ncalab import ClassificationNCAModel, get_compute_device, pad_input, WEIGHTS_PATH
8+
from ncalab import ClassificationNCAModel, get_compute_device, pad_input
89

910
import click
1011

1112
import torch
13+
import numpy as np
1214

1315
from torchvision.datasets import MNIST # type: ignore[import-untyped]
1416
from torchvision import transforms # type: ignore[import-untyped]
1517

18+
TASK_PATH = Path(__file__).parent.resolve()
19+
WEIGHTS_PATH = TASK_PATH / "weights"
20+
WEIGHTS_PATH.mkdir(exist_ok=True)
21+
1622

1723
def print_MNIST_digit(image, prediction, downscale: int = 2):
1824
assert downscale >= 1
@@ -41,8 +47,8 @@ def print_MNIST_digit(image, prediction, downscale: int = 2):
4147
if image[y * downscale, x * downscale] < 0.3:
4248
click.secho(" ", nl=False, fg="black", bg="black")
4349
continue
44-
n = prediction[y * downscale, x * downscale].detach().cpu()
45-
n = int(torch.argmax(n))
50+
n = prediction[y * downscale, x * downscale]
51+
n = int(np.argmax(n))
4652
click.secho(f" {n}", nl=False, fg=FG.get(n, "black"), bg=BG.get(n, "white"))
4753
click.secho()
4854

@@ -69,7 +75,9 @@ def eval_selfclass_mnist(
6975
pixel_wise_loss=True,
7076
)
7177
nca.load_state_dict(
72-
torch.load(WEIGHTS_PATH / "selfclass_mnist", weights_only=True)
78+
torch.load(
79+
WEIGHTS_PATH / "selfclass_mnist" / "best_model.pth", weights_only=True
80+
)
7381
)
7482
nca.eval()
7583

@@ -81,9 +89,9 @@ def eval_selfclass_mnist(
8189
x = pad_input(x, nca, noise=False)
8290
x = x.to(device)
8391

84-
prediction = nca(x, steps=50)[0]
85-
prediction = prediction[..., nca.num_image_channels + nca.num_hidden_channels :]
86-
print_MNIST_digit(image[0, 0], prediction)
92+
prediction = nca(x, steps=50)
93+
out = prediction.output_channels_np[0].transpose(1, 2, 0)
94+
print_MNIST_digit(image[0, 0], out)
8795

8896
if i != 1:
8997
click.secho("-" * 28)

0 commit comments

Comments
 (0)