Skip to content

Commit c122704

Browse files
committed
Added workaround for capability with previous API. Updated docstring.
1 parent 511483f commit c122704

File tree

2 files changed

+8
-1
lines changed

2 files changed

+8
-1
lines changed

README.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,9 @@ To use the converter properly, please, make changes in your `~/.keras/keras.json
2424
...
2525
```
2626

27+
From the latest releases, multiple inputs is also supported.
28+
29+
2730
## Tensorflow.js
2831

2932
For the proper convertion to the tensorflow.js format, please use a new flag `short_names=True`.
@@ -70,7 +73,7 @@ We're using dummy-variable in order to trace the model.
7073
```
7174
from converter import pytorch_to_keras
7275
# we should specify shape of the input tensor
73-
k_model = pytorch_to_keras(model, input_var, (10, 32, 32,), verbose=True)
76+
k_model = pytorch_to_keras(model, input_var, [(10, 32, 32,)], verbose=True)
7477
```
7578

7679
That's all! If all is ok, the Keras model is stores into the `k_model` variable.

pytorch2keras/converter.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,10 @@ def pytorch_to_keras(
8181
if isinstance(args, torch.autograd.Variable):
8282
args = (args, )
8383

84+
# Workaround for previous versions
85+
if isinstance(input_shapes, tuple):
86+
input_shapes = [input_shapes]
87+
8488
orig_state_dict_keys = _unique_state_dict(model).keys()
8589

8690
with set_training(model, training):

0 commit comments

Comments
 (0)