55
66from ncalab import (
77 ClassificationNCAModel ,
8- BasicNCATrainer ,
9- WEIGHTS_PATH ,
10- show_batch_binary_image_classification ,
118 get_compute_device ,
9+ pad_input ,
10+ WEIGHTS_PATH
1211)
1312
1413import click
2120from torchvision import transforms # type: ignore[import-untyped]
2221
2322
23+ def print_MNIST_digit (image , prediction ):
24+ BG = {
25+ 0 : "black" ,
26+ 1 : "red" ,
27+ 2 : "cyan" ,
28+ 3 : "green" ,
29+ 4 : "magenta" ,
30+ 5 : "yellow" ,
31+ 6 : "green" ,
32+ 7 : "white" ,
33+ 8 : "blue" ,
34+ 9 : "red"
35+ }
36+ FG = {
37+ 0 : "white" ,
38+ 6 : "red" ,
39+ 7 : "black" ,
40+ }
41+ for y in range (14 ):
42+ for x in range (14 ):
43+ if image [y * 2 , x * 2 ] < 0.3 :
44+ click .secho (" " , nl = False , fg = "black" , bg = "black" )
45+ continue
46+ n = prediction [y * 2 , x * 2 ].detach ().cpu ()
47+ n = int (torch .argmax (n ))
48+ click .secho (f" { n } " , nl = False , fg = FG .get (n , "black" ), bg = BG .get (n , "white" ))
49+ click .secho ()
50+
51+
2452def eval_selfclass_mnist (
25- batch_size : int , hidden_channels : int , gpu : bool , gpu_index : int
53+ hidden_channels : int , gpu : bool , gpu_index : int
2654):
2755 mnist_test = MNIST (
2856 "mnist" ,
@@ -32,7 +60,7 @@ def eval_selfclass_mnist(
3260 )
3361
3462 loader_test = torch .utils .data .DataLoader (
35- mnist_test , shuffle = True , batch_size = batch_size
63+ mnist_test , shuffle = True , batch_size = 1
3664 )
3765
3866 device = get_compute_device (f"cuda:{ gpu_index } " if gpu else "cpu" )
@@ -44,6 +72,24 @@ def eval_selfclass_mnist(
4472 num_classes = 10 ,
4573 pixel_wise_loss = True ,
4674 )
75+ nca .load_state_dict (torch .load (WEIGHTS_PATH / "selfclass_mnist.pth" , weights_only = True ))
76+ nca .eval ()
77+
78+ i = 1
79+ for image , _ in loader_test :
80+ if i == 0 :
81+ break
82+ x = image .clone ()
83+ x = pad_input (x , nca , noise = False )
84+ x = x .permute (0 , 2 , 3 , 1 ).to (device )
85+
86+ prediction = nca (x , steps = 50 )[0 ]
87+ prediction = prediction [..., nca .num_image_channels + nca .num_hidden_channels :]
88+ print_MNIST_digit (image [0 , 0 ], prediction )
89+
90+ if i != 1 :
91+ click .secho ("-" * 28 * 2 )
92+ i -= 1
4793
4894
4995@click .command ()
@@ -54,7 +100,7 @@ def eval_selfclass_mnist(
54100@click .option (
55101 "--gpu-index" , type = int , default = 0 , help = "Index of GPU to use, if --gpu in use."
56102)
57- def main (batch_size , hidden_channels , gpu , gpu_index ):
103+ def main (hidden_channels , gpu , gpu_index ):
58104 eval_selfclass_mnist (
59105 hidden_channels = hidden_channels ,
60106 gpu = gpu ,
0 commit comments