@@ -102,7 +102,7 @@ def decorator(function):
102102# ----------------------------------------------------------------
103103
104104
105- def pytorch_to_hls (config ):
105+ def parse_pytorch_model (config , verbose = True ):
106106 """Convert PyTorch model to hls4ml ModelGraph.
107107
108108 Args:
@@ -118,14 +118,15 @@ def pytorch_to_hls(config):
118118 # This is a list of dictionaries to hold all the layer info we need to generate HLS
119119 layer_list = []
120120
121- print ( 'Interpreting Model ...' )
122-
121+ if verbose :
122+ print ( 'Interpreting Model ...' )
123123 reader = PyTorchFileReader (config ) if isinstance (config ['PytorchModel' ], str ) else PyTorchModelReader (config )
124124 if type (reader .input_shape ) is tuple :
125125 input_shapes = [list (reader .input_shape )]
126126 else :
127127 input_shapes = list (reader .input_shape )
128- input_shapes = [list (shape ) for shape in input_shapes ]
128+ # first element needs to 'None' as placeholder for the batch size, insert it if not present
129+ input_shapes = [[None ] + list (shape ) if shape [0 ] is not None else list (shape ) for shape in input_shapes ]
129130
130131 model = reader .torch_model
131132
@@ -151,7 +152,8 @@ def pytorch_to_hls(config):
151152 output_shape = None
152153
153154 # Loop through layers
154- print ('Topology:' )
155+ if verbose :
156+ print ('Topology:' )
155157 layer_counter = 0
156158
157159 n_inputs = 0
@@ -226,13 +228,14 @@ def pytorch_to_hls(config):
226228 pytorch_class , layer_name , input_names , input_shapes , node , class_object , reader , config
227229 )
228230
229- print (
230- 'Layer name: {}, layer type: {}, input shape: {}' .format (
231- layer ['name' ],
232- layer ['class_name' ],
233- input_shapes ,
231+ if verbose :
232+ print (
233+ 'Layer name: {}, layer type: {}, input shape: {}' .format (
234+ layer ['name' ],
235+ layer ['class_name' ],
236+ input_shapes ,
237+ )
234238 )
235- )
236239 layer_list .append (layer )
237240
238241 assert output_shape is not None
@@ -288,7 +291,12 @@ def pytorch_to_hls(config):
288291 operation , layer_name , input_names , input_shapes , node , None , reader , config
289292 )
290293
291- print ('Layer name: {}, layer type: {}, input shape: {}' .format (layer ['name' ], layer ['class_name' ], input_shapes ))
294+ if verbose :
295+ print (
296+ 'Layer name: {}, layer type: {}, input shape: {}' .format (
297+ layer ['name' ], layer ['class_name' ], input_shapes
298+ )
299+ )
292300 layer_list .append (layer )
293301
294302 assert output_shape is not None
@@ -342,7 +350,12 @@ def pytorch_to_hls(config):
342350 operation , layer_name , input_names , input_shapes , node , None , reader , config
343351 )
344352
345- print ('Layer name: {}, layer type: {}, input shape: {}' .format (layer ['name' ], layer ['class_name' ], input_shapes ))
353+ if verbose :
354+ print (
355+ 'Layer name: {}, layer type: {}, input shape: {}' .format (
356+ layer ['name' ], layer ['class_name' ], input_shapes
357+ )
358+ )
346359 layer_list .append (layer )
347360
348361 assert output_shape is not None
@@ -351,6 +364,11 @@ def pytorch_to_hls(config):
351364 if len (input_layers ) == 0 :
352365 input_layers = None
353366
367+ return layer_list , input_layers
368+
369+
370+ def pytorch_to_hls (config ):
371+ layer_list , input_layers = parse_pytorch_model (config )
354372 print ('Creating HLS model' )
355373 hls_model = ModelGraph (config , layer_list , inputs = input_layers )
356374 return hls_model
0 commit comments