Skip to content

Commit 5c85e9d

Browse files
authored
Merge pull request #1123 from JanFSchulte/constant
Support Constant nodes in pytorch parser
2 parents 3fa2902 + 15cb4ad commit 5c85e9d

File tree

2 files changed

+56
-6
lines changed

2 files changed

+56
-6
lines changed

hls4ml/converters/pytorch/core.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
1+
import numpy as np
2+
13
from hls4ml.converters.pytorch_to_hls import pytorch_handler
24

35

6+
@pytorch_handler('Constant')
7+
def parse_constant_layer(operation, layer_name, node):
8+
assert 'Constant' in operation
9+
10+
layer = {}
11+
layer['inputs'] = []
12+
13+
layer['class_name'] = 'Constant'
14+
layer['name'] = layer_name
15+
16+
constant = np.array(node._args)
17+
layer['value'] = constant
18+
output_shape = constant.shape
19+
20+
return layer, output_shape
21+
22+
423
@pytorch_handler('Linear')
524
def parse_linear_layer(operation, layer_name, input_names, input_shapes, node, class_object, data_reader, config):
625
assert 'Linear' in operation

hls4ml/converters/pytorch_to_hls.py

Lines changed: 37 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
import torch
23

34
from hls4ml.model import ModelGraph
@@ -159,6 +160,23 @@ def parse_pytorch_model(config, verbose=True):
159160

160161
n_inputs = 0
161162

163+
# check for constant nodes
164+
merge_layers = ['add', 'mul', 'sub', 'fmin', 'fmax']
165+
i = 0 # count number of consts and use it in the name
166+
for node in traced_model.graph.nodes:
167+
if node.name.split('_')[0] in merge_layers:
168+
for arg in node.args:
169+
if np.isscalar(arg):
170+
# add an input node with the constant value
171+
new_node = traced_model.graph.placeholder(
172+
name='const_' + str(i), type_expr=torch.Tensor, default_value=arg
173+
)
174+
node.prepend(new_node)
175+
node.update_arg(1, new_node)
176+
i += 1
177+
178+
traced_model.graph.lint()
179+
162180
for node in traced_model.graph.nodes:
163181
if node.op == 'call_module':
164182
# modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x',
@@ -249,13 +267,26 @@ def parse_pytorch_model(config, verbose=True):
249267

250268
input_layer = {}
251269
input_layer['name'] = node.name
252-
input_layer['class_name'] = 'InputLayer'
253-
input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
254-
layer_list.insert(n_inputs, input_layer)
255270

256-
output_shapes[input_layer['name']] = list(input_shapes[n_inputs])
257-
input_layers.append(input_layer['name'])
258-
n_inputs += 1
271+
if 'const' in node.name:
272+
pytorch_class = 'Constant'
273+
layer, output_shape = layer_handlers[pytorch_class](pytorch_class, node.name, node)
274+
275+
layer_list.append(layer)
276+
277+
assert output_shape is not None
278+
output_shapes[layer['name']] = output_shape
279+
280+
else:
281+
282+
input_layer['class_name'] = 'InputLayer'
283+
input_layer['input_shape'] = list(input_shapes[n_inputs][1:])
284+
layer_list.insert(n_inputs, input_layer)
285+
286+
output_shapes[input_layer['name']] = list(input_shapes[n_inputs])
287+
288+
input_layers.append(input_layer['name'])
289+
n_inputs += 1
259290

260291
layer_counter += 1
261292

0 commit comments

Comments
 (0)