Skip to content

Commit e5218f5

Browse files
author
Max Keller
committed
Add pyTorch mc_dropout example
1 parent 0f978be commit e5218f5

File tree

2 files changed

+126
-4
lines changed

2 files changed

+126
-4
lines changed

examples/pytorch_mc_dropout.py

Lines changed: 126 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
"""
2+
In this file the basic ModAL PyTorch DeepActiveLearner workflow is explained
3+
through an example on the MNIST dataset and the MC-Dropout-Bald query strategy.
4+
"""
5+
import sys
6+
import os
7+
import torch
8+
from torch import nn
9+
from skorch import NeuralNetClassifier
10+
11+
from modAL.models import DeepActiveLearner
12+
13+
# import of query strategies
14+
from modAL.dropout import mc_dropout_bald, mc_dropout_mean_st
15+
16+
import numpy as np
17+
from torch.utils.data import DataLoader
18+
from torchvision.transforms import ToTensor
19+
from torchvision.datasets import MNIST
20+
21+
# Standard Pytorch Model (Visit the PyTorch documentation for more details)
22+
class Torch_Model(nn.Module):
23+
def __init__(self,):
24+
super(Torch_Model, self).__init__()
25+
self.convs = nn.Sequential(
26+
nn.Conv2d(1, 32, 3),
27+
nn.ReLU(),
28+
nn.Conv2d(32, 64, 3),
29+
nn.ReLU(),
30+
nn.MaxPool2d(2),
31+
nn.Dropout(0.25)
32+
)
33+
self.fcs = nn.Sequential(
34+
nn.Linear(12*12*64, 128),
35+
nn.ReLU(),
36+
nn.Dropout(0.5),
37+
nn.Linear(128, 10),
38+
)
39+
40+
def forward(self, x):
41+
out = x
42+
out = self.convs(out)
43+
out = out.view(-1, 12*12*64)
44+
out = self.fcs(out)
45+
return out
46+
47+
48+
torch_model = Torch_Model()
49+
"""
50+
You can acquire from the layer_list the dropout_layer_indexes, which can then be passed on
51+
to the query strategies to decide which dropout layers should be active for the predictions.
52+
When no dropout_layer_indexes are passed, all dropout layers will be activated on default.
53+
"""
54+
layer_list = list(torch_model.modules())
55+
56+
device = "cuda" if torch.cuda.is_available() else "cpu"
57+
58+
# Use the NeuralNetClassifier from skorch to wrap the Pytorch model to the scikit-learn API
59+
classifier = NeuralNetClassifier(Torch_Model,
60+
criterion=torch.nn.CrossEntropyLoss,
61+
optimizer=torch.optim.Adam,
62+
train_split=None,
63+
verbose=1,
64+
device=device)
65+
66+
67+
# Load the Dataset
68+
mnist_data = MNIST('.', download=True, transform=ToTensor())
69+
dataloader = DataLoader(mnist_data, shuffle=True, batch_size=60000)
70+
X, y = next(iter(dataloader))
71+
72+
# read training data
73+
X_train, X_test, y_train, y_test = X[:50000], X[50000:], y[:50000], y[50000:]
74+
X_train = X_train.reshape(50000, 1, 28, 28)
75+
X_test = X_test.reshape(10000, 1, 28, 28)
76+
77+
# assemble initial data
78+
n_initial = 1000
79+
initial_idx = np.random.choice(
80+
range(len(X_train)), size=n_initial, replace=False)
81+
X_initial = X_train[initial_idx]
82+
y_initial = y_train[initial_idx]
83+
84+
85+
# generate the pool
86+
# remove the initial data from the training dataset
87+
X_pool = np.delete(X_train, initial_idx, axis=0)[:5000]
88+
y_pool = np.delete(y_train, initial_idx, axis=0)[:5000]
89+
90+
91+
# initialize ActiveLearner (Pass to him the skorch wrapped PyTorch model & the Query strategy)
92+
learner = DeepActiveLearner(
93+
estimator=classifier,
94+
query_strategy=mc_dropout_bald,
95+
)
96+
# initial teaching if desired (not necessary)
97+
learner.teach(X_initial, y_initial)
98+
99+
print("Score from sklearn: {}".format(learner.score(X_pool, y_pool)))
100+
101+
102+
# the active learning loop
103+
n_queries = 10
104+
X_teach = X_initial
105+
y_teach = y_initial
106+
107+
108+
for idx in range(n_queries):
109+
print('Query no. %d' % (idx + 1))
110+
"""
111+
Query new data (num_cycles are the number of dropout forward passes that should be performed)
112+
--> check the documentation of mc_dropout_bald in modAL/dropout.py to see all available parameters
113+
"""
114+
query_idx, metric_values = learner.query(
115+
X_pool, n_instances=100, dropout_layer_indexes=[7, 11], num_cycles=10)
116+
# Add queried instances
117+
X_teach = torch.cat((X_teach, X_pool[query_idx]))
118+
y_teach = torch.cat((y_teach, y_pool[query_idx]))
119+
learner.teach(X_teach, y_teach)
120+
121+
# remove queried instance from pool
122+
X_pool = np.delete(X_pool, query_idx, axis=0)
123+
y_pool = np.delete(y_pool, query_idx, axis=0)
124+
125+
# give us the model performance
126+
print("Model score: {}".format(learner.score(X_test, y_test)))

modAL/models/base.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -250,10 +250,6 @@ def _fit_on_new(self, X: modALinput, y: modALinput, bootstrap: bool = False, **f
250250
def predict(self, X: modALinput) -> Any:
251251
pass
252252

253-
@abc.abstractmethod
254-
def teach(self, X: modALinput, y: modALinput, bootstrap: bool = False, **fit_kwargs) -> Any:
255-
pass
256-
257253
def transform_without_estimating(self, X: modALinput) -> Union[np.ndarray, sp.csr_matrix]:
258254
"""
259255
Transforms the data as supplied to each learner's estimator and concatenates transformations.

0 commit comments

Comments
 (0)