Skip to content

Commit caec2c7

Browse files
committed
keras mnist replaced with torchvision mnist
1 parent e79b171 commit caec2c7

File tree

1 file changed

+19
-9
lines changed

1 file changed

+19
-9
lines changed

examples/pytorch_integration.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,16 @@
66

77
import torch
88
import numpy as np
9-
from keras.datasets import mnist
9+
1010
from torch import nn
11+
from torch.utils.data import DataLoader
12+
from torchvision.transforms import ToTensor
13+
from torchvision.datasets import MNIST
1114
from skorch import NeuralNetClassifier
15+
1216
from modAL.models import ActiveLearner
1317

18+
1419
# build class for the skorch API
1520
class Torch_Model(nn.Module):
1621
def __init__(self,):
@@ -37,28 +42,34 @@ def forward(self, x):
3742
out = self.fcs(out)
3843
return out
3944

45+
4046
# create the classifier
47+
device = "cuda" if torch.cuda.is_available() else "cpu"
4148
classifier = NeuralNetClassifier(Torch_Model,
4249
# max_epochs=100,
4350
criterion=nn.CrossEntropyLoss,
4451
optimizer=torch.optim.Adam,
4552
train_split=None,
4653
verbose=1,
47-
device="cuda")
54+
device=device)
4855

4956
"""
5057
Data wrangling
51-
1. Reading data from Keras
58+
1. Reading data from torchvision
5259
2. Assembling initial training data for ActiveLearner
5360
3. Generating the pool
5461
"""
5562

63+
mnist_data = MNIST('.', download=True, transform=ToTensor())
64+
dataloader = DataLoader(mnist_data, shuffle=True, batch_size=60000)
65+
X, y = next(iter(dataloader))
66+
5667
# read training data
57-
(X_train, y_train), (X_test, y_test) = mnist.load_data()
58-
X_train = X_train.reshape(60000, 1, 28, 28).astype('float32') / 255.
59-
X_test = X_test.reshape(10000, 1, 28, 28).astype('float32') / 255.
60-
y_train = y_train.astype('long')
61-
y_test = y_test.astype('long')
68+
X_train, X_test, y_train, y_test = X[:50000], X[50000:], y[:50000], y[50000:]
69+
X_train = X_train.reshape(50000, 1, 28, 28)
70+
X_test = X_test.reshape(10000, 1, 28, 28)
71+
y_train = y_train
72+
y_test = y_test
6273

6374
# assemble initial data
6475
n_initial = 1000
@@ -85,7 +96,6 @@ def forward(self, x):
8596
n_queries = 10
8697
for idx in range(n_queries):
8798
query_idx, query_instance = learner.query(X_pool, n_instances=100)
88-
print(query_idx)
8999
learner.teach(X_pool[query_idx], y_pool[query_idx], only_new=True)
90100
# remove queried instance from pool
91101
X_pool = np.delete(X_pool, query_idx, axis=0)

0 commit comments

Comments
 (0)