Skip to content

Commit 83da7de

Browse files
committed
Updated graph tracing function.
1 parent ebad143 commit 83da7de

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

pytorch2keras/converter.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def pytorch_to_keras(
8282
orig_state_dict_keys = _unique_state_dict(model).keys()
8383

8484
with set_training(model, training):
85-
trace, torch_out = torch.jit.trace(model, args)
85+
trace, torch_out = torch.jit.get_trace_graph(model, args)
8686

8787
if orig_state_dict_keys != _unique_state_dict(model).keys():
8888
raise RuntimeError("state_dict changed after running the tracer; "

pytorch2keras/layers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -660,7 +660,8 @@ def convert_reduce_sum(params, w_name, scope_name, inputs, layers, weights):
660660
print('Converting reduce_sum ...')
661661

662662
keepdims = params['keepdims'] > 0
663-
target_layer = lambda x: keras.backend.sum(x, keepdims=keepdims, axis=params['axes'])
663+
axis = np.array(params['axes'])
664+
target_layer = lambda x: keras.backend.sum(x, keepdims=keepdims, axis=axis)
664665

665666
lambda_layer = keras.layers.Lambda(target_layer)
666667
layers[scope_name] = lambda_layer(layers[inputs[0]])

0 commit comments

Comments
 (0)