Skip to content

Commit 511483f

Browse files
committed
Updated docstring.
1 parent 35ae095 commit 511483f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

pytorch2keras/converter.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ def get_node_id(node):
5858

5959

6060
def pytorch_to_keras(
61-
model, args, input_shape,
61+
model, args, input_shapes,
6262
change_ordering=False, training=False, verbose=False, short_names=False,
6363
):
6464
"""
@@ -67,7 +67,7 @@ def pytorch_to_keras(
6767
Args:
6868
model: pytorch model
6969
args: pytorch model arguments
70-
input_shape: keras input shape (using for InputLayer creation)
70+
input_shapes: keras input shapes (using for each InputLayer)
7171
change_ordering: change CHW to HWC
7272
training: switch model to training mode
7373
verbose: verbose output
@@ -120,7 +120,7 @@ def pytorch_to_keras(
120120
keras_inputs = []
121121
for i in range(len(args)):
122122
layers['input{0}'.format(i)] = keras.layers.InputLayer(
123-
input_shape=input_shape[i], name='input{0}'.format(i)
123+
input_shape=input_shapes[i], name='input{0}'.format(i)
124124
).output
125125
keras_inputs.append(layers['input{0}'.format(i)])
126126

0 commit comments

Comments
 (0)