Skip to content

Commit 8161829

Browse files
authored
Merge branch 'main' into initialRecurr
2 parents 838bf37 + 3b7e595 commit 8161829

File tree

3 files changed

+17
-3
lines changed

3 files changed

+17
-3
lines changed

hls4ml/converters/pytorch/reshape.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def parse_squeeze_layer(operation, layer_name, input_names, input_shapes, node,
3838
layer = {}
3939
layer['class_name'] = 'Reshape'
4040
layer['name'] = layer_name
41+
layer['inputs'] = input_names
4142

4243
if len(node.args) > 1 or len(node.kwargs) > 0: # 'dim' argument is specified
4344
output_shape = [i for i in input_shapes[0]]

hls4ml/converters/pytorch_to_hls.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,7 @@ def parse_pytorch_model(config, verbose=True):
151151
inputs_map = {}
152152

153153
input_layers = []
154+
output_layers = []
154155

155156
# Output shape tracking
156157
output_shapes = {}
@@ -407,12 +408,23 @@ def parse_pytorch_model(config, verbose=True):
407408
if len(input_layers) == 0:
408409
input_layers = None
409410

410-
return layer_list, input_layers
411+
for layer in layer_list:
412+
if layer['class_name'] == 'InputLayer':
413+
continue
414+
is_input = False
415+
for lay in layer_list:
416+
if 'inputs' not in lay.keys():
417+
continue
418+
if layer['name'] in lay['inputs']:
419+
is_input = True
420+
if not is_input:
421+
output_layers.append(layer['name'])
422+
return layer_list, input_layers, output_layers
411423

412424

413425
@requires('_torch')
414426
def pytorch_to_hls(config):
415-
layer_list, input_layers = parse_pytorch_model(config)
427+
layer_list, input_layers, output_layers = parse_pytorch_model(config)
416428
print('Creating HLS model')
417-
hls_model = ModelGraph(config, layer_list, inputs=input_layers)
429+
hls_model = ModelGraph(config, layer_list, inputs=input_layers, outputs=output_layers)
418430
return hls_model

hls4ml/utils/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -368,6 +368,7 @@ def config_from_pytorch_model(
368368
(
369369
layer_list,
370370
_,
371+
_,
371372
) = parse_pytorch_model(config, verbose=False)
372373

373374
def make_layer_config(layer):

0 commit comments

Comments
 (0)