66
77import torch
88import numpy as np
9- from keras . datasets import mnist
9+
1010from torch import nn
11+ from torch .utils .data import DataLoader
12+ from torchvision .transforms import ToTensor
13+ from torchvision .datasets import MNIST
1114from skorch import NeuralNetClassifier
15+
1216from modAL .models import ActiveLearner
1317
18+
1419# build class for the skorch API
1520class Torch_Model (nn .Module ):
1621 def __init__ (self ,):
@@ -37,28 +42,34 @@ def forward(self, x):
3742 out = self .fcs (out )
3843 return out
3944
45+
4046# create the classifier
47+ device = "cuda" if torch .cuda .is_available () else "cpu"
4148classifier = NeuralNetClassifier (Torch_Model ,
4249 # max_epochs=100,
4350 criterion = nn .CrossEntropyLoss ,
4451 optimizer = torch .optim .Adam ,
4552 train_split = None ,
4653 verbose = 1 ,
47- device = "cuda" )
54+ device = device )
4855
4956"""
5057Data wrangling
51- 1. Reading data from Keras
58+ 1. Reading data from torchvision
52592. Assembling initial training data for ActiveLearner
53603. Generating the pool
5461"""
5562
63+ mnist_data = MNIST ('.' , download = True , transform = ToTensor ())
64+ dataloader = DataLoader (mnist_data , shuffle = True , batch_size = 60000 )
65+ X , y = next (iter (dataloader ))
66+
5667# read training data
57- ( X_train , y_train ), ( X_test , y_test ) = mnist . load_data ()
58- X_train = X_train .reshape (60000 , 1 , 28 , 28 ). astype ( 'float32' ) / 255.
59- X_test = X_test .reshape (10000 , 1 , 28 , 28 ). astype ( 'float32' ) / 255.
60- y_train = y_train . astype ( 'long' )
61- y_test = y_test . astype ( 'long' )
68+ X_train , X_test , y_train , y_test = X [: 50000 ], X [ 50000 :], y [: 50000 ], y [ 50000 :]
69+ X_train = X_train .reshape (50000 , 1 , 28 , 28 )
70+ X_test = X_test .reshape (10000 , 1 , 28 , 28 )
71+ y_train = y_train
72+ y_test = y_test
6273
6374# assemble initial data
6475n_initial = 1000
@@ -85,7 +96,6 @@ def forward(self, x):
8596n_queries = 10
8697for idx in range (n_queries ):
8798 query_idx , query_instance = learner .query (X_pool , n_instances = 100 )
88- print (query_idx )
8999 learner .teach (X_pool [query_idx ], y_pool [query_idx ], only_new = True )
90100 # remove queried instance from pool
91101 X_pool = np .delete (X_pool , query_idx , axis = 0 )
0 commit comments