1+ from pathlib import Path
12import os
23import sys
34
45root_dir = os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../.." ))
56sys .path .append (root_dir )
67
7- from ncalab import ClassificationNCAModel , get_compute_device , pad_input , WEIGHTS_PATH
8+ from ncalab import ClassificationNCAModel , get_compute_device , pad_input
89
910import click
1011
1112import torch
13+ import numpy as np
1214
1315from torchvision .datasets import MNIST # type: ignore[import-untyped]
1416from torchvision import transforms # type: ignore[import-untyped]
1517
18+ TASK_PATH = Path (__file__ ).parent .resolve ()
19+ WEIGHTS_PATH = TASK_PATH / "weights"
20+ WEIGHTS_PATH .mkdir (exist_ok = True )
21+
1622
1723def print_MNIST_digit (image , prediction , downscale : int = 2 ):
1824 assert downscale >= 1
@@ -41,8 +47,8 @@ def print_MNIST_digit(image, prediction, downscale: int = 2):
4147 if image [y * downscale , x * downscale ] < 0.3 :
4248 click .secho (" " , nl = False , fg = "black" , bg = "black" )
4349 continue
44- n = prediction [y * downscale , x * downscale ]. detach (). cpu ()
45- n = int (torch .argmax (n ))
50+ n = prediction [y * downscale , x * downscale ]
51+ n = int (np .argmax (n ))
4652 click .secho (f" { n } " , nl = False , fg = FG .get (n , "black" ), bg = BG .get (n , "white" ))
4753 click .secho ()
4854
@@ -69,7 +75,9 @@ def eval_selfclass_mnist(
6975 pixel_wise_loss = True ,
7076 )
7177 nca .load_state_dict (
72- torch .load (WEIGHTS_PATH / "selfclass_mnist" , weights_only = True )
78+ torch .load (
79+ WEIGHTS_PATH / "selfclass_mnist" / "best_model.pth" , weights_only = True
80+ )
7381 )
7482 nca .eval ()
7583
@@ -81,9 +89,9 @@ def eval_selfclass_mnist(
8189 x = pad_input (x , nca , noise = False )
8290 x = x .to (device )
8391
84- prediction = nca (x , steps = 50 )[ 0 ]
85- prediction = prediction [..., nca . num_image_channels + nca . num_hidden_channels :]
86- print_MNIST_digit (image [0 , 0 ], prediction )
92+ prediction = nca (x , steps = 50 )
93+ out = prediction . output_channels_np [ 0 ]. transpose ( 1 , 2 , 0 )
94+ print_MNIST_digit (image [0 , 0 ], out )
8795
8896 if i != 1 :
8997 click .secho ("-" * 28 )
0 commit comments