Commit 9447272
committed

Update code structure.
1. Fixed linter errors; 2. Added .travis.yml; 3. Added Sum layer; 4. Minor tests cleanup; 5. Changed imports to package-relative.

1 parent b664499 commit 9447272

35 files changed (+147 / -610 lines)

.travis.yml

Lines changed: 27 additions & 0 deletions
@@ -0,0 +1,27 @@
+group: travis_latest
+language: python
+cache: pip
+python:
+  - 2.7
+  - 3.6
+  #- nightly
+  #- pypy
+  #- pypy3
+matrix:
+  allow_failures:
+    - python: nightly
+    - python: pypy
+    - python: pypy3
+install:
+  #- pip install -r requirements.txt
+  - pip install flake8 # pytest # add another testing frameworks later
+before_script:
+  # stop the build if there are Python syntax errors or undefined names
+  - flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics
+  # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
+  - flake8 . --count --exit-zero --max-complexity=32 --max-line-length=127 --statistics
+script:
+  - true # pytest --capture=sys # add other tests here
+notifications:
+  on_success: change
+  on_failure: change # `always` will be the setting once code changes slow down
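
The two flake8 passes act as the CI gate: the first run fails the build only on syntax errors and undefined names (E901, E999, F821, F822, F823), while the second reports everything but exits zero. A minimal sketch of reproducing the strict gate locally, assuming flake8 is installed; the flags are copied from the config above:

import subprocess

# Strict gate from .travis.yml: fail only on syntax errors / undefined names.
subprocess.check_call([
    'flake8', '.', '--count',
    '--select=E901,E999,F821,F822,F823',
    '--show-source', '--statistics',
])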

pytorch2keras/converter.py

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@
 import contextlib
 from torch.jit import _unique_state_dict
 
-from layers import AVAILABLE_CONVERTERS
+from .layers import AVAILABLE_CONVERTERS
 
 
 @contextlib.contextmanager

pytorch2keras/layers.py

Lines changed: 32 additions & 4 deletions
@@ -436,6 +436,29 @@ def convert_elementwise_sub(
     layers[scope_name] = sub([model0, model1])
 
 
+def convert_sum(
+    params, w_name, scope_name, inputs, layers, weights
+):
+    """
+    Convert sum.
+
+    Args:
+        params: dictionary with layer parameters
+        w_name: name prefix in state_dict
+        scope_name: pytorch scope name
+        inputs: pytorch node inputs
+        layers: dictionary with keras tensors
+        weights: pytorch state_dict
+    """
+    print('Converting Sum ...')
+
+    def target_layer(x):
+        return keras.backend.sum(x)
+
+    lambda_layer = keras.layers.Lambda(target_layer)
+    layers[scope_name] = lambda_layer(layers[inputs[0]])
+
+
 def convert_concat(params, w_name, scope_name, inputs, layers, weights):
     """
     Convert concatenation.
@@ -469,6 +492,7 @@ def convert_relu(params, w_name, scope_name, inputs, layers, weights):
     """
     print('Converting relu ...')
 
+    print(w_name, scope_name)
     tf_name = w_name + str(random.random())
     relu = keras.layers.Activation('relu', name=tf_name)
     layers[scope_name] = relu(layers[inputs[0]])
@@ -570,7 +594,6 @@ def convert_selu(params, w_name, scope_name, inputs, layers, weights):
     layers[scope_name] = selu(layers[inputs[0]])
 
 
-
 def convert_transpose(params, w_name, scope_name, inputs, layers, weights):
     """
     Convert transpose layer.
@@ -705,7 +728,9 @@ def convert_reduce_sum(params, w_name, scope_name, inputs, layers, weights):
 
     keepdims = params['keepdims'] > 0
     axis = np.array(params['axes'])
-    target_layer = lambda x: keras.backend.sum(x, keepdims=keepdims, axis=axis)
+
+    def target_layer(x, keepdims=keepdims, axis=axis):
+        return keras.backend.sum(x, keepdims=keepdims, axis=axis)
 
     lambda_layer = keras.layers.Lambda(target_layer)
     layers[scope_name] = lambda_layer(layers[inputs[0]])
@@ -725,7 +750,9 @@ def convert_constant(params, w_name, scope_name, inputs, layers, weights):
     """
     print('Converting constant ...')
 
-    target_layer = lambda x: keras.backend.constant(np.float32(params['value']))
+    def target_layer(params=params):
+        return keras.backend.constant(np.float32(params['value']))
+
     lambda_layer = keras.layers.Lambda(target_layer)
     layers[scope_name] = lambda_layer(layers[inputs[0]])
 
@@ -782,7 +809,7 @@ def convert_padding(params, w_name, scope_name, inputs, layers, weights):
     padding_name = tf_name + '_pad'
     padding_layer = keras.layers.ZeroPadding2D(
         padding=((params['pads'][2], params['pads'][6]), (params['pads'][3], params['pads'][7])),
-        name=tf_name
+        name=padding_name
     )
 
     layers[scope_name] = padding_layer(layers[inputs[0]])
@@ -801,6 +828,7 @@ def convert_padding(params, w_name, scope_name, inputs, layers, weights):
     'onnx::Add': convert_elementwise_add,
     'onnx::Mul': convert_elementwise_mul,
     'onnx::Sub': convert_elementwise_sub,
+    'onnx::Sum': convert_sum,
     'onnx::Concat': convert_concat,
     'onnx::Relu': convert_relu,
     'onnx::LeakyRelu': convert_lrelu,
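
For context, the new 'onnx::Sum' entry dispatches to convert_sum, which wraps keras.backend.sum in a Lambda layer with no axis argument, so the layer reduces over every axis of its input, batch dimension included. A minimal sketch of that behaviour (not part of the commit, assuming a TF1-style Keras backend):

import numpy as np
import keras
import keras.backend as K

# Same construction as convert_sum: a Lambda layer around keras.backend.sum.
x = K.constant(np.arange(6, dtype=np.float32).reshape(2, 3))
sum_layer = keras.layers.Lambda(lambda t: K.sum(t))
print(K.eval(sum_layer(x)))  # 15.0 -- reduced over all axes, batch included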

setup.py

Lines changed: 5 additions & 4 deletions
@@ -1,20 +1,21 @@
 from setuptools import setup, find_packages
-from setuptools.command.develop import develop
-from setuptools.command.install import install
 
 
-try: # for pip >= 10
+try:  # for pip >= 10
     from pip._internal.req import parse_requirements
-except ImportError: # for pip <= 9.0.3
+except ImportError:  # for pip <= 9.0.3
     from pip.req import parse_requirements
 
+
 # parse_requirements() returns generator of pip.req.InstallRequirement objects
 install_reqs = parse_requirements('requirements.txt', session='null')
 
+
 # reqs is a list of requirement
 # e.g. ['django==1.5.1', 'mezzanine==1.4.6']
 reqs = [str(ir.req) for ir in install_reqs]
 
+
 setup(name='pytorch2keras',
       version='0.1',
       description='The model convertor',
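
The try/except above keeps setup.py working across pip 10, which moved parse_requirements into pip._internal. A pip-version-independent alternative is to read requirements.txt directly; a minimal sketch, not part of this commit:

def read_requirements(path='requirements.txt'):
    # One requirement per line; skip blanks and comments.
    with open(path) as f:
        return [line.strip() for line in f
                if line.strip() and not line.strip().startswith('#')]

# reqs = read_requirements()  # would replace the parse_requirements() call above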

tests/alexnet.py

Lines changed: 2 additions & 8 deletions
@@ -1,14 +1,8 @@
-import keras # work around segfault
-import sys
 import numpy as np
-
 import torch
-import torchvision
 from torch.autograd import Variable
-
-sys.path.append('../pytorch2keras')
-from converter import pytorch_to_keras
-
+from pytorch2keras.converter import pytorch_to_keras
+import torchvision
 
 if __name__ == '__main__':
     max_error = 0

tests/avg_pool.py

Lines changed: 1 addition & 6 deletions
@@ -1,13 +1,8 @@
-import keras # work around segfault
-import sys
 import numpy as np
-
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
-
-sys.path.append('../pytorch2keras')
-from converter import pytorch_to_keras
+from pytorch2keras.converter import pytorch_to_keras
 
 
 class AvgPool(nn.Module):

tests/bn.py

Lines changed: 1 addition & 6 deletions
@@ -1,13 +1,8 @@
-import keras # work around segfault
-import sys
 import numpy as np
-
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
-
-sys.path.append('../pytorch2keras')
-from converter import pytorch_to_keras
+from pytorch2keras.converter import pytorch_to_keras
 
 
 class TestConv2d(nn.Module):

tests/concat_many.py

Lines changed: 1 addition & 6 deletions
@@ -1,13 +1,8 @@
-import keras # work around segfault
-import sys
 import numpy as np
-
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
-
-sys.path.append('../pytorch2keras')
-from converter import pytorch_to_keras
+from pytorch2keras.converter import pytorch_to_keras
 
 
 class TestConcatMany(nn.Module):

tests/const.py

Lines changed: 1 addition & 6 deletions
@@ -1,13 +1,8 @@
-import keras # work around segfault
-import sys
 import numpy as np
-
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
-
-sys.path.append('../pytorch2keras')
-from converter import pytorch_to_keras
+from pytorch2keras.converter import pytorch_to_keras
 
 
 class TestConst(nn.Module):

tests/conv2d.py

Lines changed: 1 addition & 6 deletions
@@ -1,13 +1,8 @@
-import keras # work around segfault
-import sys
 import numpy as np
-
 import torch
 import torch.nn as nn
 from torch.autograd import Variable
-
-sys.path.append('../pytorch2keras')
-from converter import pytorch_to_keras
+from pytorch2keras.converter import pytorch_to_keras
 
 
 class TestConv2d(nn.Module):
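
With the sys.path hack removed, the tests now expect pytorch2keras to be installed (for example via pip install -e .) and import the converter as a regular package. A minimal sketch of the new pattern; TinyNet is an illustrative toy module, and the pytorch_to_keras call signature is assumed from the tests in this commit:

import numpy as np
import torch
import torch.nn as nn
from torch.autograd import Variable
from pytorch2keras.converter import pytorch_to_keras  # package import, no sys.path hack


class TinyNet(nn.Module):
    # Hypothetical toy model, only for illustrating the conversion call.
    def __init__(self):
        super(TinyNet, self).__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3)

    def forward(self, x):
        return self.conv(x)


if __name__ == '__main__':
    model = TinyNet()
    input_var = Variable(torch.FloatTensor(np.random.uniform(0, 1, (1, 3, 32, 32))))
    k_model = pytorch_to_keras(model, input_var)  # assumed signature: (model, input variable)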
