Skip to content

Commit 78b11e8

Browse files
committed
Added ResNet34 test case.
1 parent 9ca22d8 commit 78b11e8

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

tests/resnet34.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
import numpy as np
2+
import torch
3+
from torch.autograd import Variable
4+
from pytorch2keras.converter import pytorch_to_keras
5+
import torchvision
6+
7+
8+
if __name__ == '__main__':
9+
max_error = 0
10+
for i in range(10):
11+
model = torchvision.models.resnet34()
12+
for m in model.modules():
13+
m.training = False
14+
15+
input_np = np.random.uniform(0, 1, (1, 3, 224, 224))
16+
input_var = Variable(torch.FloatTensor(input_np))
17+
output = model(input_var)
18+
19+
k_model = pytorch_to_keras(model, input_var, (3, 224, 224,), verbose=True)
20+
21+
pytorch_output = output.data.numpy()
22+
keras_output = k_model.predict(input_np)
23+
24+
error = np.max(pytorch_output - keras_output)
25+
print(error)
26+
if max_error < error:
27+
max_error = error
28+
29+
print('Max error: {0}'.format(max_error))

0 commit comments

Comments
 (0)