Commit c971d7a

Author: rbodo
Implemented parser for pytorch models.
1 parent 4b143a4 commit c971d7a
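The new parser loads a PyTorch model definition plus its saved state_dict, exports the traced network to ONNX with torch.onnx.export, converts the ONNX graph to a Keras model via onnx2keras, and verifies numerical agreement at each step with onnxruntime and np.testing.assert_allclose. The resulting Keras model is then handled by the existing Keras parsing and evaluation machinery. A minimal usage sketch follows; the config section and key names ('paths', 'input'/'model_lib', 'tools'), the file layout, and the snntoolbox.bin.run.main entry point are assumed to follow the toolbox's existing keras/lasagne/caffe conventions and are not shown in this commit.

# Hedged usage sketch (not part of this commit). Paths are hypothetical;
# section/key names and the main() entry point are assumed from the
# existing keras/lasagne/caffe workflow.
import configparser
import os

path_wd = '/path/to/working_dir'  # hypothetical working directory

config = configparser.ConfigParser()
config['paths'] = {
    'path_wd': path_wd,
    'dataset_path': '/path/to/dataset',
    'filename_ann': 'my_pytorch_net'}  # expects my_pytorch_net.py and .pkl here
config['input'] = {'model_lib': 'pytorch'}  # value newly allowed by [restrictions]
config['tools'] = {'parse': 'True', 'evaluate_ann': 'True'}

config_filepath = os.path.join(path_wd, 'config')
with open(config_filepath, 'w') as f:
    config.write(f)

from snntoolbox.bin.run import main  # assumed entry point
main(config_filepath)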

8 files changed: +396 additions, −11 deletions

snntoolbox/bin/utils.py

Lines changed: 10 additions & 6 deletions
@@ -65,6 +65,7 @@ def run_pipeline(config, queue=None):
 
     normset, testset = get_dataset(config)
 
+    results = None
     parsed_model = None
     if config.getboolean('tools', 'parse') and not is_stop(queue):
 
@@ -80,9 +81,10 @@ def run_pipeline(config, queue=None):
     if config.getboolean('tools', 'evaluate_ann') and not is_stop(queue):
        print("Evaluating input model on {} samples...".format(
            num_to_test))
-        model_lib.evaluate(input_model['val_fn'],
-                           config.getint('simulation', 'batch_size'),
-                           num_to_test, **testset)
+        acc = model_lib.evaluate(input_model['val_fn'],
+                                 config.getint('simulation', 'batch_size'),
+                                 num_to_test, **testset)
+        results = [acc]
 
     # ____________________________ PARSE ________________________________ #
 
@@ -100,8 +102,10 @@ def run_pipeline(config, queue=None):
     if config.getboolean('tools', 'evaluate_ann') and not is_stop(queue):
        print("Evaluating parsed model on {} samples...".format(
            num_to_test))
-        model_parser.evaluate(config.getint(
-            'simulation', 'batch_size'), num_to_test, **testset)
+        score = model_parser.evaluate(
+            config.getint('simulation', 'batch_size'),
+            num_to_test, **testset)
+        results = [score[1]]
 
     # Write parsed model to disk
     parsed_model.save(str(
@@ -153,7 +157,7 @@ def run(snn, **test_set):
     if queue:
         queue.put(results)
 
-    return results
+    return results
 
 
 def is_stop(queue):
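The changes above stop discarding the evaluation results: the accuracy of the input and parsed models is now kept in results. Since the parsed model is a Keras model compiled with the metrics ['accuracy', top_k_categorical_accuracy] (see the new parser below), Keras' evaluate returns [loss, accuracy, top_k_accuracy], which is why score[1] is stored. A minimal self-contained sketch with a dummy model and random data (not from this commit) illustrates the indexing:

# Minimal sketch with a dummy model and random data (not from this commit),
# showing why score[1] is the top-1 accuracy: with the metrics list used by
# the new parser, Keras' evaluate() returns [loss, accuracy, top_k_accuracy].
import numpy as np
import keras

model = keras.models.Sequential(
    [keras.layers.Dense(10, activation='softmax', input_shape=(20,))])
model.compile('sgd', 'categorical_crossentropy',
              ['accuracy', keras.metrics.top_k_categorical_accuracy])

x_test = np.random.random((100, 20)).astype(np.float32)
y_test = keras.utils.to_categorical(np.random.randint(10, size=100), 10)

score = model.evaluate(x_test, y_test, batch_size=50, verbose=0)
results = [score[1]]  # [loss, accuracy, top_k_accuracy] -> keep top-1 accuracy
print("Top-1 accuracy: {:.2%}".format(results[0]))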

snntoolbox/config_defaults

Lines changed: 1 addition & 1 deletion
@@ -116,7 +116,7 @@ plotproperties = {
 # validity of config.
 
 [restrictions]
-model_libs = {'keras', 'lasagne', 'caffe'}
+model_libs = {'keras', 'lasagne', 'caffe', 'pytorch'}
 dataset_formats = {'npz', 'jpg', 'aedat'}
 frame_gen_method = {'signed_sum', 'rectified_sum'}
 maxpool_types = {'fir_max', 'exp_max', 'avg_max'}
Lines changed: 120 additions & 0 deletions
@@ -0,0 +1,120 @@
+# -*- coding: utf-8 -*-
+"""Pytorch model parser.
+
+@author: rbodo
+"""
+
+import os
+import numpy as np
+
+import keras
+import torch
+import onnx
+from onnx2keras import onnx_to_keras
+import onnxruntime
+
+from snntoolbox.parsing.model_libs import keras_input_lib
+from snntoolbox.utils.utils import import_script
+
+
+def to_numpy(tensor):
+    return tensor.detach().cpu().numpy() if tensor.requires_grad \
+        else tensor.cpu().numpy()
+
+
+class ModelParser(keras_input_lib.ModelParser):
+
+    def try_insert_flatten(self, layer, idx, name_map):
+        return False
+
+
+def load(path, filename):
+    """Load network from file.
+
+    Parameters
+    ----------
+
+    path: str
+        Path to directory where to load model from.
+
+    filename: str
+        Name of file to load model from.
+
+    Returns
+    -------
+
+    : dict[str, Union[keras.models.Sequential, function]]
+        A dictionary of objects that constitute the input model. It must
+        contain the following two keys:
+
+        - 'model': keras.models.Sequential
+            Keras model instance of the network.
+        - 'val_fn': function
+            Function that allows evaluating the original model.
+    """
+
+    filepath = str(os.path.join(path, filename))
+
+    # Load the Pytorch model.
+    mod = import_script(path, filename)
+    model_pytorch = mod.Model()
+    model_pytorch.load_state_dict(torch.load(filepath + '.pkl'))
+
+    # Switch from train to eval mode to ensure Dropout / BatchNorm is handled
+    # correctly.
+    model_pytorch.eval()
+
+    # Run on dummy input with correct shape to trace the Pytorch model.
+    input_shape = [1] + list(model_pytorch.input_shape)
+    input_numpy = np.random.random_sample(input_shape).astype(np.float32)
+    input_torch = torch.from_numpy(input_numpy).float()
+    output_torch = model_pytorch(input_torch)
+    output_numpy = to_numpy(output_torch)
+
+    # Export as onnx model, and then reload.
+    input_names = ['input_0']
+    output_names = ['output_{}'.format(i) for i in range(len(output_torch))]
+    dynamic_axes = {'input_0': {0: 'batch_size'}}
+    dynamic_axes.update({name: {0: 'batch_size'} for name in output_names})
+    torch.onnx.export(model_pytorch, input_torch, filepath + '.onnx',
+                      input_names=input_names,
+                      output_names=output_names,
+                      dynamic_axes=dynamic_axes)
+    model_onnx = onnx.load(filepath + '.onnx')
+    # onnx.checker.check_model(model_onnx)  # Crashes with segmentation fault.
+
+    # Compute ONNX Runtime output prediction.
+    ort_session = onnxruntime.InferenceSession(filepath + '.onnx')
+    input_onnx = {ort_session.get_inputs()[0].name: input_numpy}
+    output_onnx = ort_session.run(None, input_onnx)
+
+    # Compare ONNX Runtime and PyTorch results.
+    err_msg = "Pytorch model could not be ported to ONNX. Output difference: "
+    np.testing.assert_allclose(output_numpy, output_onnx[0],
+                               rtol=1e-03, atol=1e-05, err_msg=err_msg)
+    print("Pytorch model was successfully ported to ONNX.")
+
+    # Port ONNX model to Keras.
+    model_keras = onnx_to_keras(model_onnx, input_names, [input_shape[1:]])
+
+    # Save the keras model.
+    keras.models.save_model(model_keras, filepath + '.h5')
+
+    # Loading the model here is a workaround for version conflicts with
+    # TF > 2.0.1 and keras > 2.2.5. Should be able to remove this later.
+    model_keras = keras.models.load_model(filepath + '.h5')
+    model_keras.compile('sgd', 'categorical_crossentropy',
+                        ['accuracy', keras.metrics.top_k_categorical_accuracy])
+
+    # Compute Keras output and compare against ONNX.
+    output_keras = model_keras.predict(input_numpy)
+    err_msg = "ONNX model could not be ported to Keras. Output difference: "
+    np.testing.assert_allclose(output_numpy, output_keras,
+                               rtol=1e-03, atol=1e-05, err_msg=err_msg)
+    print("ONNX model was successfully ported to Keras.")
+
+    return {'model': model_keras, 'val_fn': model_keras.evaluate}
+
+
+def evaluate(*args, **kwargs):
+    return keras_input_lib.evaluate(*args, **kwargs)
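For reference, a sketch of calling the new load function directly, outside run_pipeline. The import path is an assumption (the file's location is not named in this extract), and the model/file names are hypothetical; load expects <path>/<filename>.py defining a Model class with an input_shape attribute, and <path>/<filename>.pkl holding its state_dict.

# Hedged sketch of direct use (not part of this commit). The module path is
# assumed; replace it with the actual location of the new parser file.
from snntoolbox.parsing.model_libs import pytorch_input_lib  # assumed name

input_model = pytorch_input_lib.load('/path/to/models', 'my_net')  # hypothetical
model_keras = input_model['model']  # converted Keras model
model_keras.summary()

# 'val_fn' is the compiled Keras model's evaluate(), consumed by run_pipeline as
# acc = model_lib.evaluate(input_model['val_fn'], batch_size, num_to_test, **testset)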

tests/conftest.py

Lines changed: 79 additions & 0 deletions
@@ -227,6 +227,79 @@ def _model_3(_dataset):
     return model
 
 
+@pytest.fixture(scope='session')
+def _model_4(_dataset):
+
+    if not is_module_installed('torch'):
+        return
+
+    import torch
+    import torch.nn as nn
+    from tests.parsing.models import pytorch
+
+    x_train, y_train, x_test, y_test = _dataset
+
+    # Pytorch doesn't support one-hot labels.
+    y_train = np.argmax(y_train, 1)
+    y_test = np.argmax(y_test, 1)
+
+    # Pytorch needs channel dimension first.
+    if keras.backend.image_data_format() == 'channels_last':
+        x_train = np.moveaxis(x_train, 3, 1)
+        x_test = np.moveaxis(x_test, 3, 1)
+
+    class PytorchDataset(torch.utils.data.Dataset):
+        def __init__(self, data, target, transform=None):
+            self.data = torch.from_numpy(data).float()
+            self.target = torch.from_numpy(target).long()
+            self.transform = transform
+
+        def __getitem__(self, index):
+            x = self.data[index]
+
+            if self.transform:
+                x = self.transform(x)
+
+            return x, self.target[index]
+
+        def __len__(self):
+            return len(self.data)
+
+    trainset = torch.utils.data.DataLoader(PytorchDataset(x_train, y_train),
+                                           batch_size=64)
+    testset = torch.utils.data.DataLoader(PytorchDataset(x_test, y_test),
+                                          batch_size=64)
+
+    model = pytorch.Model()
+
+    criterion = nn.CrossEntropyLoss()
+    optimizer = torch.optim.Adam(model.parameters())
+
+    acc = 0
+    for epoch in range(3):
+        for i, (xx, y) in enumerate(trainset):
+            optimizer.zero_grad()
+            outputs = model(xx)
+            loss = criterion(outputs, y)
+            loss.backward()
+            optimizer.step()
+
+        total = 0
+        correct = 0
+        with torch.no_grad():
+            for xx, y in testset:
+                outputs = model(xx)
+                _, predicted = torch.max(outputs.data, 1)
+                total += y.size(0)
+                correct += (predicted == y).sum().item()
+        acc = correct / total
+
+        print("Test accuracy: {:.2%}".format(acc))
+    # assert acc > 0.96, "Test accuracy after training not high enough."
+
+    return model
+
+
 spinnaker_conditions = (is_module_installed('keras_rewiring') and
                         is_module_installed('pynn_object_serialisation') and
                         (is_module_installed('pyNN.spiNNaker') or
@@ -243,6 +316,12 @@ def _model_3(_dataset):
 brian2_skip_if_dependency_missing = pytest.mark.skipif(
     not brian2_conditions, reason="Brian2 dependency missing.")
 
+pytorch_conditions = (is_module_installed('torch') and
+                      is_module_installed('onnx') and
+                      is_module_installed('onnx2keras'))
+pytorch_skip_if_dependency_missing = pytest.mark.skipif(
+    not pytorch_conditions, reason='Pytorch dependencies missing')
+
 
 def get_examples():
     path = os.path.abspath(os.path.join(os.path.dirname(__file__), '..',
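A hypothetical test sketch (not part of this commit) showing how the new _model_4 fixture and pytorch_skip_if_dependency_missing marker would be combined; importing the marker from conftest mirrors how the other skip markers appear to be consumed and is an assumption.

# Hypothetical test module sketch (not in this commit). _model_4 resolves to
# None when torch is missing, so the skipif marker guards the whole test.
from tests.conftest import pytorch_skip_if_dependency_missing


@pytorch_skip_if_dependency_missing
def test_pytorch_model_trains(_model_4):
    assert _model_4 is not None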

tests/core/test_models.py

Lines changed: 3 additions & 4 deletions
@@ -112,12 +112,11 @@ def test_inisim(self, _model_2, _config):
             os.path.join(path_wd, model_name + '.h5'))
 
         updates = {
-            'tools': {'evaluate_ann': False}, 'simulation': {
-                'simulator': 'INI',
+            'tools': {'evaluate_ann': False},
+            'simulation': {
                 'duration': 100,
                 'num_to_test': 100,
-                'batch_size': 50,
-                'keras_backend': 'tensorflow'},
+                'batch_size': 50},
             'output': {
                 'log_vars': {'activations_n_b_l', 'spiketrains_n_b_l_t'}}}
 

tests/parsing/models/__init__.py

Whitespace-only changes.

tests/parsing/models/pytorch.py

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+import torch
+import torch.nn as nn
+
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+        self.input_shape = (1, 28, 28)
+
+        layers_trunk = [
+            nn.Conv2d(1, 16, kernel_size=5, stride=2),
+            # BatchNorm doesn't work with Keras==2.3.1 because for some reason
+            # they put the batch-norm axis in a list.
+            # nn.BatchNorm2d(16),
+            nn.ReLU(),
+            nn.AvgPool2d(kernel_size=2, stride=2)]
+        layers_branch1 = [
+            nn.Conv2d(16, 32, kernel_size=3, padding=1),
+            nn.ReLU()]
+        layers_branch2 = [
+            nn.Conv2d(16, 8, kernel_size=1),
+            nn.ReLU()]
+        layers_head = [
+            nn.Conv2d(40, 8, kernel_size=1),
+            nn.ReLU()]
+        layers_classifier = [
+            nn.Dropout(1e-5),
+            nn.Linear(288, 10),
+            nn.Softmax(1)]
+        self.trunk = nn.Sequential(*layers_trunk)
+        self.branch1 = nn.Sequential(*layers_branch1)
+        self.branch2 = nn.Sequential(*layers_branch2)
+        self.head = nn.Sequential(*layers_head)
+        self.classifier = nn.Sequential(*layers_classifier)
+
+    def forward(self, x):
+        x = self.trunk(x)
+        x1 = self.branch1(x)
+        x2 = self.branch2(x)
+        x = torch.cat([x1, x2], 1)
+        x = self.head(x)
+        x = x.view(-1, 288)  # Flatten
+        x = self.classifier(x)
+        return x
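The parser's load (above) re-imports <filename>.py via import_script, instantiates Model(), and restores the weights from <filename>.pkl. A sketch of producing those two files from this test model, with hypothetical paths:

# Sketch (not part of this commit) of laying out the files load() expects:
# <path_wd>/pytorch.py (this module) and <path_wd>/pytorch.pkl (state_dict).
import shutil
import torch
from tests.parsing.models import pytorch

model = pytorch.Model()
# ... training would go here (compare the _model_4 fixture above) ...

path_wd = '/path/to/working_dir'  # hypothetical
shutil.copy('tests/parsing/models/pytorch.py', path_wd + '/pytorch.py')
torch.save(model.state_dict(), path_wd + '/pytorch.pkl')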
