Skip to content

Commit e144fc2

Browse files
committed
create evaluation script for selfclass MNIST
1 parent 96cf5bc commit e144fc2

File tree

2 files changed

+53
-6
lines changed

2 files changed

+53
-6
lines changed

tasks/selfclass_mnist/eval_selfclass_mnist.py

Lines changed: 52 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,9 @@
55

66
from ncalab import (
77
ClassificationNCAModel,
8-
BasicNCATrainer,
9-
WEIGHTS_PATH,
10-
show_batch_binary_image_classification,
118
get_compute_device,
9+
pad_input,
10+
WEIGHTS_PATH
1211
)
1312

1413
import click
@@ -21,8 +20,37 @@
2120
from torchvision import transforms # type: ignore[import-untyped]
2221

2322

23+
def print_MNIST_digit(image, prediction):
24+
BG = {
25+
0: "black",
26+
1: "red",
27+
2: "cyan",
28+
3: "green",
29+
4: "magenta",
30+
5: "yellow",
31+
6: "green",
32+
7: "white",
33+
8: "blue",
34+
9: "red"
35+
}
36+
FG = {
37+
0: "white",
38+
6: "red",
39+
7: "black",
40+
}
41+
for y in range(14):
42+
for x in range(14):
43+
if image[y * 2, x * 2] < 0.3:
44+
click.secho(" ", nl=False, fg="black", bg="black")
45+
continue
46+
n = prediction[y * 2, x * 2].detach().cpu()
47+
n = int(torch.argmax(n))
48+
click.secho(f" {n}", nl=False, fg=FG.get(n, "black"), bg=BG.get(n, "white"))
49+
click.secho()
50+
51+
2452
def eval_selfclass_mnist(
25-
batch_size: int, hidden_channels: int, gpu: bool, gpu_index: int
53+
hidden_channels: int, gpu: bool, gpu_index: int
2654
):
2755
mnist_test = MNIST(
2856
"mnist",
@@ -32,7 +60,7 @@ def eval_selfclass_mnist(
3260
)
3361

3462
loader_test = torch.utils.data.DataLoader(
35-
mnist_test, shuffle=True, batch_size=batch_size
63+
mnist_test, shuffle=True, batch_size=1
3664
)
3765

3866
device = get_compute_device(f"cuda:{gpu_index}" if gpu else "cpu")
@@ -44,6 +72,24 @@ def eval_selfclass_mnist(
4472
num_classes=10,
4573
pixel_wise_loss=True,
4674
)
75+
nca.load_state_dict(torch.load(WEIGHTS_PATH / "selfclass_mnist.pth", weights_only=True))
76+
nca.eval()
77+
78+
i = 1
79+
for image, _ in loader_test:
80+
if i == 0:
81+
break
82+
x = image.clone()
83+
x = pad_input(x, nca, noise=False)
84+
x = x.permute(0, 2, 3, 1).to(device)
85+
86+
prediction = nca(x, steps=50)[0]
87+
prediction = prediction[..., nca.num_image_channels + nca.num_hidden_channels :]
88+
print_MNIST_digit(image[0, 0], prediction)
89+
90+
if i != 1:
91+
click.secho("-" * 28 * 2)
92+
i -= 1
4793

4894

4995
@click.command()
@@ -54,7 +100,7 @@ def eval_selfclass_mnist(
54100
@click.option(
55101
"--gpu-index", type=int, default=0, help="Index of GPU to use, if --gpu in use."
56102
)
57-
def main(batch_size, hidden_channels, gpu, gpu_index):
103+
def main(hidden_channels, gpu, gpu_index):
58104
eval_selfclass_mnist(
59105
hidden_channels=hidden_channels,
60106
gpu=gpu,

tasks/selfclass_mnist/train_selfclass_mnist.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ def train_selfclass_mnist(
8484
)
8585
trainer.train_basic_nca(
8686
loader_train,
87+
# Validation is broken here. We're working on it!
8788
# loader_val,
8889
summary_writer=writer,
8990
plot_function=show_batch_binary_image_classification,

0 commit comments

Comments
 (0)