|
| 1 | +import numpy as np |
1 | 2 | import torch |
2 | 3 |
|
3 | 4 | from hls4ml.model import ModelGraph |
@@ -159,6 +160,23 @@ def parse_pytorch_model(config, verbose=True): |
159 | 160 |
|
160 | 161 | n_inputs = 0 |
161 | 162 |
|
| 163 | + # check for constant nodes |
| 164 | + merge_layers = ['add', 'mul', 'sub', 'fmin', 'fmax'] |
| 165 | + i = 0 # count number of consts and use it in the name |
| 166 | + for node in traced_model.graph.nodes: |
| 167 | + if node.name.split('_')[0] in merge_layers: |
| 168 | + for arg in node.args: |
| 169 | + if np.isscalar(arg): |
| 170 | + # add an input node with the constant value |
| 171 | + new_node = traced_model.graph.placeholder( |
| 172 | + name='const_' + str(i), type_expr=torch.Tensor, default_value=arg |
| 173 | + ) |
| 174 | + node.prepend(new_node) |
| 175 | + node.update_arg(1, new_node) |
| 176 | + i += 1 |
| 177 | + |
| 178 | + traced_model.graph.lint() |
| 179 | + |
162 | 180 | for node in traced_model.graph.nodes: |
163 | 181 | if node.op == 'call_module': |
164 | 182 | # modules that are part of a torch.nn.Sequential with name 'name' have target names 'name.x', |
@@ -249,13 +267,26 @@ def parse_pytorch_model(config, verbose=True): |
249 | 267 |
|
250 | 268 | input_layer = {} |
251 | 269 | input_layer['name'] = node.name |
252 | | - input_layer['class_name'] = 'InputLayer' |
253 | | - input_layer['input_shape'] = list(input_shapes[n_inputs][1:]) |
254 | | - layer_list.insert(n_inputs, input_layer) |
255 | 270 |
|
256 | | - output_shapes[input_layer['name']] = list(input_shapes[n_inputs]) |
257 | | - input_layers.append(input_layer['name']) |
258 | | - n_inputs += 1 |
| 271 | + if 'const' in node.name: |
| 272 | + pytorch_class = 'Constant' |
| 273 | + layer, output_shape = layer_handlers[pytorch_class](pytorch_class, node.name, node) |
| 274 | + |
| 275 | + layer_list.append(layer) |
| 276 | + |
| 277 | + assert output_shape is not None |
| 278 | + output_shapes[layer['name']] = output_shape |
| 279 | + |
| 280 | + else: |
| 281 | + |
| 282 | + input_layer['class_name'] = 'InputLayer' |
| 283 | + input_layer['input_shape'] = list(input_shapes[n_inputs][1:]) |
| 284 | + layer_list.insert(n_inputs, input_layer) |
| 285 | + |
| 286 | + output_shapes[input_layer['name']] = list(input_shapes[n_inputs]) |
| 287 | + |
| 288 | + input_layers.append(input_layer['name']) |
| 289 | + n_inputs += 1 |
259 | 290 |
|
260 | 291 | layer_counter += 1 |
261 | 292 |
|
|
0 commit comments