Skip to content

Commit f409358

Browse files
committed
Merge branch 'damienlancry-pytorch-integration' into dev
2 parents fc8e2e0 + 16dbc3f commit f409358

File tree

4 files changed

+385
-1
lines changed

4 files changed

+385
-1
lines changed
Lines changed: 277 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,277 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"Pytorch models in modAL workflows\n",
8+
"=============================\n",
9+
"\n",
10+
"Thanks to Skorch API, you can seamlessly integrate Pytorch models into your modAL workflow. In this tutorial, we shall quickly introduce how to use Skorch API of Keras and we are going to see how to do active learning with it. More details on the Keras scikit-learn API [can be found here](https://skorch.readthedocs.io/en/stable/).\n",
11+
"\n",
12+
"The executable script for this example can be [found here](https://github.com/cosmic-cortex/modAL/blob/master/examples/pytorch_integration.py)!"
13+
]
14+
},
15+
{
16+
"cell_type": "markdown",
17+
"metadata": {},
18+
"source": [
19+
"Skorch API\n",
20+
"-----------------------\n",
21+
"\n",
22+
"By default, a Pytorch model's interface differs from what is used for scikit-learn estimators. However, with the use of Skorch wrapper, it is possible to adapt your model."
23+
]
24+
},
25+
{
26+
"cell_type": "code",
27+
"execution_count": 1,
28+
"metadata": {},
29+
"outputs": [],
30+
"source": [
31+
"import torch\n",
32+
"from torch import nn\n",
33+
"from skorch import NeuralNetClassifier\n",
34+
"\n",
35+
"# build class for the skorch API\n",
36+
"class Torch_Model(nn.Module):\n",
37+
" def __init__(self,):\n",
38+
" super(Torch_Model, self).__init__()\n",
39+
" self.convs = nn.Sequential(\n",
40+
" nn.Conv2d(1,32,3),\n",
41+
" nn.ReLU(),\n",
42+
" nn.Conv2d(32,64,3),\n",
43+
" nn.ReLU(),\n",
44+
" nn.MaxPool2d(2),\n",
45+
" nn.Dropout(0.25)\n",
46+
" )\n",
47+
" self.fcs = nn.Sequential(\n",
48+
" nn.Linear(12*12*64,128),\n",
49+
" nn.ReLU(),\n",
50+
" nn.Dropout(0.5),\n",
51+
" nn.Linear(128,10),\n",
52+
" )\n",
53+
"\n",
54+
" def forward(self, x):\n",
55+
" out = x\n",
56+
" out = self.convs(out)\n",
57+
" out = out.view(-1,12*12*64)\n",
58+
" out = self.fcs(out)\n",
59+
" return out"
60+
]
61+
},
62+
{
63+
"cell_type": "markdown",
64+
"metadata": {},
65+
"source": [
66+
"For our purposes, the ``classifier`` which we will initialize now acts just like any scikit-learn estimator."
67+
]
68+
},
69+
{
70+
"cell_type": "code",
71+
"execution_count": 2,
72+
"metadata": {},
73+
"outputs": [],
74+
"source": [
75+
"# create the classifier\n",
76+
"device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
77+
"classifier = NeuralNetClassifier(Torch_Model,\n",
78+
" criterion=nn.CrossEntropyLoss,\n",
79+
" optimizer=torch.optim.Adam,\n",
80+
" train_split=None,\n",
81+
" verbose=1,\n",
82+
" device=device)"
83+
]
84+
},
85+
{
86+
"cell_type": "markdown",
87+
"metadata": {},
88+
"source": [
89+
"Active learning with Pytorch\n",
90+
"---------------------------------------\n",
91+
"\n",
92+
"In this example, we are going to use the famous MNIST dataset, which is available as a built-in for PyTorch."
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": null,
98+
"metadata": {},
99+
"outputs": [
100+
{
101+
"name": "stderr",
102+
"output_type": "stream",
103+
"text": [
104+
"\r",
105+
"0it [00:00, ?it/s]"
106+
]
107+
},
108+
{
109+
"name": "stdout",
110+
"output_type": "stream",
111+
"text": [
112+
"Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to ./MNIST/raw/train-images-idx3-ubyte.gz\n"
113+
]
114+
},
115+
{
116+
"name": "stderr",
117+
"output_type": "stream",
118+
"text": [
119+
" 97%|█████████▋| 9584640/9912422 [00:15<00:00, 1777143.52it/s]"
120+
]
121+
},
122+
{
123+
"name": "stdout",
124+
"output_type": "stream",
125+
"text": [
126+
"Extracting ./MNIST/raw/train-images-idx3-ubyte.gz\n"
127+
]
128+
},
129+
{
130+
"name": "stderr",
131+
"output_type": "stream",
132+
"text": [
133+
"\n",
134+
"0it [00:00, ?it/s]\u001b[A"
135+
]
136+
},
137+
{
138+
"name": "stdout",
139+
"output_type": "stream",
140+
"text": [
141+
"Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to ./MNIST/raw/train-labels-idx1-ubyte.gz\n"
142+
]
143+
},
144+
{
145+
"name": "stderr",
146+
"output_type": "stream",
147+
"text": [
148+
"\n",
149+
" 0%| | 0/28881 [00:00<?, ?it/s]\u001b[A\n",
150+
" 57%|█████▋ | 16384/28881 [00:00<00:00, 62622.03it/s]\u001b[A\n",
151+
"32768it [00:00, 41627.01it/s] \u001b[A\n",
152+
"0it [00:00, ?it/s]\u001b[A"
153+
]
154+
},
155+
{
156+
"name": "stdout",
157+
"output_type": "stream",
158+
"text": [
159+
"Extracting ./MNIST/raw/train-labels-idx1-ubyte.gz\n",
160+
"Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to ./MNIST/raw/t10k-images-idx3-ubyte.gz\n"
161+
]
162+
},
163+
{
164+
"name": "stderr",
165+
"output_type": "stream",
166+
"text": [
167+
"\n",
168+
" 0%| | 0/1648877 [00:00<?, ?it/s]\u001b[A"
169+
]
170+
}
171+
],
172+
"source": [
173+
"import numpy as np\n",
174+
"from torch.utils.data import DataLoader\n",
175+
"from torchvision.transforms import ToTensor\n",
176+
"from torchvision.datasets import MNIST\n",
177+
"\n",
178+
"\n",
179+
"mnist_data = MNIST('.', download=True, transform=ToTensor())\n",
180+
"dataloader = DataLoader(mnist_data, shuffle=True, batch_size=60000)\n",
181+
"X, y = next(iter(dataloader))\n",
182+
"\n",
183+
"# read training data\n",
184+
"X_train, X_test, y_train, y_test = X[:50000], X[50000:], y[:50000], y[50000:]\n",
185+
"X_train = X_train.reshape(50000, 1, 28, 28)\n",
186+
"X_test = X_test.reshape(10000, 1, 28, 28)\n",
187+
"\n",
188+
"# assemble initial data\n",
189+
"n_initial = 1000\n",
190+
"initial_idx = np.random.choice(range(len(X_train)), size=n_initial, replace=False)\n",
191+
"X_initial = X_train[initial_idx]\n",
192+
"y_initial = y_train[initial_idx]\n",
193+
"\n",
194+
"# generate the pool\n",
195+
"# remove the initial data from the training dataset\n",
196+
"X_pool = np.delete(X_train, initial_idx, axis=0)[:5000]\n",
197+
"y_pool = np.delete(y_train, initial_idx, axis=0)[:5000]"
198+
]
199+
},
200+
{
201+
"cell_type": "markdown",
202+
"metadata": {},
203+
"source": [
204+
"Active learning with data and classifier ready is as easy as always. Because training is *very* expensive in large neural networks, this time we are going to query the best 200 instances each time we measure the uncertainty of the pool."
205+
]
206+
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": null,
210+
"metadata": {},
211+
"outputs": [],
212+
"source": [
213+
"from modAL.models import ActiveLearner\n",
214+
"\n",
215+
"# initialize ActiveLearner\n",
216+
"learner = ActiveLearner(\n",
217+
" estimator=classifier,\n",
218+
" X_training=X_initial, y_training=y_initial,\n",
219+
")"
220+
]
221+
},
222+
{
223+
"cell_type": "markdown",
224+
"metadata": {},
225+
"source": [
226+
"To make sure that you train only on newly queried labels, pass ``only_new=True`` to the ``.teach()`` method of the learner."
227+
]
228+
},
229+
{
230+
"cell_type": "code",
231+
"execution_count": null,
232+
"metadata": {},
233+
"outputs": [],
234+
"source": [
235+
"# the active learning loop\n",
236+
"n_queries = 10\n",
237+
"for idx in range(n_queries):\n",
238+
" print('Query no. %d' % (idx + 1))\n",
239+
" query_idx, query_instance = learner.query(X_pool, n_instances=100)\n",
240+
" learner.teach(\n",
241+
" X=X_pool[query_idx], y=y_pool[query_idx], only_new=True,\n",
242+
" )\n",
243+
" # remove queried instance from pool\n",
244+
" X_pool = np.delete(X_pool, query_idx, axis=0)\n",
245+
" y_pool = np.delete(y_pool, query_idx, axis=0)"
246+
]
247+
},
248+
{
249+
"cell_type": "code",
250+
"execution_count": null,
251+
"metadata": {},
252+
"outputs": [],
253+
"source": []
254+
}
255+
],
256+
"metadata": {
257+
"kernelspec": {
258+
"display_name": "Python [conda env:modAL] *",
259+
"language": "python",
260+
"name": "conda-env-modAL-py"
261+
},
262+
"language_info": {
263+
"codemirror_mode": {
264+
"name": "ipython",
265+
"version": 3
266+
},
267+
"file_extension": ".py",
268+
"mimetype": "text/x-python",
269+
"name": "python",
270+
"nbconvert_exporter": "python",
271+
"pygments_lexer": "ipython3",
272+
"version": "3.7.3"
273+
}
274+
},
275+
"nbformat": 4,
276+
"nbformat_minor": 2
277+
}

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ Currently supported active learning strategies are
6767
content/examples/query_by_committee
6868
content/examples/bootstrapping_and_bagging
6969
content/examples/Keras_integration
70+
content/examples/Pytorch_integration
7071

7172
.. toctree::
7273
:glob:

examples/pytorch_integration.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
"""
2+
This example demonstrates how to use the active learning interface with Pytorch.
3+
The example uses Skorch, a scikit learn wrapper of Pytorch.
4+
For more info, see https://skorch.readthedocs.io/en/stable/
5+
"""
6+
7+
import torch
8+
import numpy as np
9+
10+
from torch import nn
11+
from torch.utils.data import DataLoader
12+
from torchvision.transforms import ToTensor
13+
from torchvision.datasets import MNIST
14+
from skorch import NeuralNetClassifier
15+
16+
from modAL.models import ActiveLearner
17+
18+
19+
# build class for the skorch API
20+
class Torch_Model(nn.Module):
21+
def __init__(self,):
22+
super(Torch_Model, self).__init__()
23+
self.convs = nn.Sequential(
24+
nn.Conv2d(1,32,3),
25+
nn.ReLU(),
26+
nn.Conv2d(32,64,3),
27+
nn.ReLU(),
28+
nn.MaxPool2d(2),
29+
nn.Dropout(0.25)
30+
)
31+
self.fcs = nn.Sequential(
32+
nn.Linear(12*12*64,128),
33+
nn.ReLU(),
34+
nn.Dropout(0.5),
35+
nn.Linear(128,10),
36+
)
37+
38+
def forward(self, x):
39+
out = x
40+
out = self.convs(out)
41+
out = out.view(-1,12*12*64)
42+
out = self.fcs(out)
43+
return out
44+
45+
46+
# create the classifier
47+
device = "cuda" if torch.cuda.is_available() else "cpu"
48+
classifier = NeuralNetClassifier(Torch_Model,
49+
# max_epochs=100,
50+
criterion=nn.CrossEntropyLoss,
51+
optimizer=torch.optim.Adam,
52+
train_split=None,
53+
verbose=1,
54+
device=device)
55+
56+
"""
57+
Data wrangling
58+
1. Reading data from torchvision
59+
2. Assembling initial training data for ActiveLearner
60+
3. Generating the pool
61+
"""
62+
63+
mnist_data = MNIST('.', download=True, transform=ToTensor())
64+
dataloader = DataLoader(mnist_data, shuffle=True, batch_size=60000)
65+
X, y = next(iter(dataloader))
66+
67+
# read training data
68+
X_train, X_test, y_train, y_test = X[:50000], X[50000:], y[:50000], y[50000:]
69+
X_train = X_train.reshape(50000, 1, 28, 28)
70+
X_test = X_test.reshape(10000, 1, 28, 28)
71+
72+
# assemble initial data
73+
n_initial = 1000
74+
initial_idx = np.random.choice(range(len(X_train)), size=n_initial, replace=False)
75+
X_initial = X_train[initial_idx]
76+
y_initial = y_train[initial_idx]
77+
78+
# generate the pool
79+
# remove the initial data from the training dataset
80+
X_pool = np.delete(X_train, initial_idx, axis=0)
81+
y_pool = np.delete(y_train, initial_idx, axis=0)
82+
83+
"""
84+
Training the ActiveLearner
85+
"""
86+
87+
# initialize ActiveLearner
88+
learner = ActiveLearner(
89+
estimator=classifier,
90+
X_training=X_initial, y_training=y_initial,
91+
)
92+
93+
# the active learning loop
94+
n_queries = 10
95+
for idx in range(n_queries):
96+
query_idx, query_instance = learner.query(X_pool, n_instances=100)
97+
learner.teach(X_pool[query_idx], y_pool[query_idx], only_new=True)
98+
# remove queried instance from pool
99+
X_pool = np.delete(X_pool, query_idx, axis=0)
100+
y_pool = np.delete(y_pool, query_idx, axis=0)
101+
102+
# the final accuracy score
103+
print(learner.score(X_test, y_test))

modAL/utils/data.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,7 @@ def data_vstack(blocks: Container) -> modALinput:
2525
elif sp.issparse(blocks[0]):
2626
return sp.vstack(blocks)
2727
else:
28-
raise TypeError('%s datatype is not supported' % type(blocks[0]))
28+
try:
29+
return np.concatenate(blocks)
30+
except:
31+
raise TypeError('%s datatype is not supported' % type(blocks[0]))

0 commit comments

Comments
 (0)