Skip to content

Commit f377fe0

Browse files
authored
Merge pull request #1143 from JanFSchulte/parsing_fixes
Fixes to parsing of pytorch models when using torch functionals
2 parents c8e1857 + a0a573e commit f377fe0

File tree

2 files changed: +28 additions, -14 deletions

hls4ml/converters/pytorch/pooling.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -90,15 +90,19 @@ def parse_pooling_layer(operation, layer_name, input_names, input_shapes, node,
9090
layer['stride_height'] = node.kwargs['stride'][0]
9191
layer['stride_width'] = node.kwargs['stride'][1]
9292
else:
93-
layer['stride_height'] = node.kwargs['stride']
94-
layer['stride_width'] = node.kwargs['stride']
95-
if type(node.kwargs['kernel_size']) is tuple:
96-
layer['pool_height'] = node.kwargs['kernel_size'][0]
97-
layer['pool_width'] = node.kwargs['kernel_size'][1]
93+
if node.kwargs['stride'] is None:
94+
# if stride is not set it is supposed to default to the kernel size
95+
layer['stride_height'] = node.args[1]
96+
layer['stride_width'] = node.args[1]
97+
else:
98+
layer['stride_height'] = node.kwargs['stride']
99+
layer['stride_width'] = node.kwargs['stride']
100+
if type(node.args[1]) is tuple:
101+
layer['pool_height'] = node.args[1][0]
102+
layer['pool_width'] = node.args[1][1]
98103
else:
99-
layer['pool_height'] = node.kwargs['kernel_size']
100-
layer['pool_width'] = node.kwargs['kernel_size']
101-
104+
layer['pool_height'] = node.args[1]
105+
layer['pool_width'] = node.args[1]
102106
if type(node.kwargs['padding']) is tuple:
103107
padding = node.kwargs['padding']
104108
else:

hls4ml/converters/pytorch/reshape.py

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,23 @@ def parse_flatten_layer(operation, layer_name, input_names, input_shapes, node,
9393
layer['class_name'] = 'Reshape'
9494
layer['name'] = layer_name
9595
layer['inputs'] = input_names
96-
97-
start_dim = class_object.start_dim
98-
end_dim = class_object.end_dim
99-
if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]):
100-
end_dim = len(input_shapes[0])
96+
if node.op == 'call_module':
97+
start_dim = class_object.start_dim
98+
end_dim = class_object.end_dim
99+
if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]):
100+
end_dim = len(input_shapes[0])
101+
else:
102+
end_dim = end_dim + 1
101103
else:
102-
end_dim = end_dim + 1
104+
start_dim = node.args[1]
105+
if len(node.args) == 3:
106+
end_dim = node.args[2]
107+
else:
108+
end_dim = -1
109+
if end_dim + 1 == 0 or end_dim + 1 > len(input_shapes[0]):
110+
end_dim = len(input_shapes[0])
111+
else:
112+
end_dim = end_dim + 1
103113

104114
layer['target_shape'] = (
105115
input_shapes[0][0:start_dim] + [np.prod(input_shapes[0][start_dim:end_dim])] + input_shapes[0][end_dim:]

0 commit comments (Comments: 0)