Skip to content

Commit b291e0e

Browse files
committed
Updated the converter to support the latest PyTorch master.
1 parent 946ad65 commit b291e0e

File tree

2 files changed

+41
-39
lines changed

2 files changed

+41
-39
lines changed

pytorch2keras/converter.py

Lines changed: 17 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -32,22 +32,23 @@ def set_training(model, mode):
3232
model.train(old_mode)
3333

3434

35-
def _optimize_trace(trace, aten):
35+
def _optimize_graph(graph, aten):
3636
# run dce first to eliminate dead parts of the graph that might have been
3737
# left behind by things like symbolic_override
38-
torch._C._jit_pass_dce(trace)
39-
torch._C._jit_pass_lint(trace)
40-
41-
torch._C._jit_pass_peephole(trace)
42-
torch._C._jit_pass_lint(trace)
43-
torch._C._jit_pass_onnx(trace, aten)
44-
torch._C._jit_pass_lint(trace)
45-
torch._C._jit_pass_onnx_peephole(trace)
46-
torch._C._jit_pass_lint(trace)
47-
torch._C._jit_pass_dce(trace)
48-
torch._C._jit_pass_lint(trace)
49-
torch._C._jit_pass_canonicalize(trace)
50-
torch._C._jit_pass_lint(trace)
38+
torch._C._jit_pass_dce(graph)
39+
torch._C._jit_pass_lint(graph)
40+
41+
torch._C._jit_pass_peephole(graph)
42+
torch._C._jit_pass_lint(graph)
43+
graph = torch._C._jit_pass_onnx(graph, aten)
44+
torch._C._jit_pass_lint(graph)
45+
torch._C._jit_pass_onnx_peephole(graph)
46+
torch._C._jit_pass_lint(graph)
47+
torch._C._jit_pass_dce(graph)
48+
torch._C._jit_pass_lint(graph)
49+
graph = torch._C._jit_pass_canonicalize(graph)
50+
torch._C._jit_pass_lint(graph)
51+
return graph
5152

5253

5354
def get_node_id(node):
@@ -88,7 +89,8 @@ def pytorch_to_keras(
8889
raise RuntimeError("state_dict changed after running the tracer; "
8990
"something weird is happening in your model!")
9091

91-
_optimize_trace(trace, False)
92+
# _optimize_trace(trace, False)
93+
trace.set_graph(_optimize_graph(trace.graph(), False))
9294

9395
if verbose:
9496
print(trace.graph())

pytorch2keras/layers.py

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -738,29 +738,29 @@ def convert_upsample(params, w_name, scope_name, inputs, layers, weights):
738738

739739

740740
AVAILABLE_CONVERTERS = {
741-
'Conv': convert_conv,
742-
'ConvTranspose': convert_convtranspose,
743-
'Flatten': convert_flatten,
744-
'Gemm': convert_gemm,
745-
'MaxPool': convert_maxpool,
741+
'onnx::Conv': convert_conv,
742+
'onnx::ConvTranspose': convert_convtranspose,
743+
'onnx::Flatten': convert_flatten,
744+
'onnx::Gemm': convert_gemm,
745+
'onnx::MaxPool': convert_maxpool,
746746
'max_pool2d': convert_maxpool,
747-
'AveragePool': convert_avgpool,
748-
'Dropout': convert_dropout,
749-
'BatchNormalization': convert_batchnorm,
750-
'Add': convert_elementwise_add,
751-
'Mul': convert_elementwise_mul,
752-
'Sub': convert_elementwise_sub,
753-
'Concat': convert_concat,
754-
'Relu': convert_relu,
755-
'LeakyRelu': convert_lrelu,
756-
'Sigmoid': convert_sigmoid,
757-
'Softmax': convert_softmax,
758-
'Tanh': convert_tanh,
759-
'Transpose': convert_transpose,
760-
'Reshape': convert_reshape,
761-
'MatMul': convert_matmul,
762-
'Gather': convert_gather,
763-
'ReduceSum': convert_reduce_sum,
764-
'Constant': convert_constant,
765-
'Upsample': convert_upsample,
747+
'onnx::AveragePool': convert_avgpool,
748+
'onnx::Dropout': convert_dropout,
749+
'onnx::BatchNormalization': convert_batchnorm,
750+
'onnx::Add': convert_elementwise_add,
751+
'onnx::Mul': convert_elementwise_mul,
752+
'onnx::Sub': convert_elementwise_sub,
753+
'onnx::Concat': convert_concat,
754+
'onnx::Relu': convert_relu,
755+
'onnx::LeakyRelu': convert_lrelu,
756+
'onnx::Sigmoid': convert_sigmoid,
757+
'onnx::Softmax': convert_softmax,
758+
'onnx::Tanh': convert_tanh,
759+
'onnx::Transpose': convert_transpose,
760+
'onnx::Reshape': convert_reshape,
761+
'onnx::MatMul': convert_matmul,
762+
'onnx::Gather': convert_gather,
763+
'onnx::ReduceSum': convert_reduce_sum,
764+
'onnx::Constant': convert_constant,
765+
'onnx::Upsample': convert_upsample,
766766
}

0 commit comments

Comments
 (0)