@@ -109,6 +109,12 @@ def get_layer_config(self, layer):
109109
110110 return layer_config
111111
112+ def set_name_config (self , name , config ):
113+ """sets hls_config["LayerName"][name] = config"""
114+ hls_config = self .config ['HLSConfig' ]
115+ layer_config = hls_config .setdefault ('LayerName' , {})
116+ layer_config [name ] = config
117+
112118 def get_precision (self , layer , var = 'default' ):
113119 precision = self .layer_name_precision .get (layer .name .lower () + '_' + var )
114120 type_name = layer .name .lower () + '_' + var + '_t'
@@ -192,6 +198,35 @@ def get_compression(self, layer):
192198
193199 return compression
194200
201+ def parse_name_config (self , layer_name , layer_cfg ):
202+ """This is used by _parse_hls_config below, but also in optimizers when a new layer config is created"""
203+ precision_cfg = layer_cfg .get ('Precision' )
204+ if isinstance (precision_cfg , dict ):
205+ for var , precision in precision_cfg .items ():
206+ self .layer_name_precision [layer_name .lower () + '_' + var ] = precision
207+ else :
208+ self .layer_name_precision [layer_name .lower () + '_default' ] = precision_cfg
209+
210+ rf = layer_cfg .get ('ReuseFactor' )
211+ if rf is not None :
212+ self .layer_name_rf [layer_name .lower ()] = rf
213+
214+ targ_cycles = layer_cfg .get ('TargetCycles' )
215+ if targ_cycles is not None :
216+ self .layer_name_targ_cycles [layer_name .lower ()] = targ_cycles
217+
218+ strategy = layer_cfg .get ('Strategy' )
219+ if strategy is not None :
220+ self .layer_name_strategy [layer_name .lower ()] = strategy
221+
222+ conv_implementation = layer_cfg .get ('ConvImplementation' )
223+ if conv_implementation is not None :
224+ self .layer_name_conv_implementation [layer_name .lower ()] = conv_implementation
225+
226+ compression = layer_cfg .get ('Compression' )
227+ if compression is not None :
228+ self .layer_name_compression [layer_name .lower ()] = bool (compression )
229+
195230 def get_writer_config (self ):
196231 return self .writer_config
197232
@@ -267,32 +302,7 @@ def _parse_hls_config(self):
267302 layer_name_cfg = hls_config .get ('LayerName' )
268303 if layer_name_cfg is not None :
269304 for layer_name , layer_cfg in layer_name_cfg .items ():
270- precision_cfg = layer_cfg .get ('Precision' )
271- if isinstance (precision_cfg , dict ):
272- for var , precision in precision_cfg .items ():
273- self .layer_name_precision [layer_name .lower () + '_' + var ] = precision
274- else :
275- self .layer_name_precision [layer_name .lower () + '_default' ] = precision_cfg
276-
277- rf = layer_cfg .get ('ReuseFactor' )
278- if rf is not None :
279- self .layer_name_rf [layer_name .lower ()] = rf
280-
281- targ_cycles = layer_cfg .get ('TargetCycles' )
282- if targ_cycles is not None :
283- self .layer_name_targ_cycles [layer_name .lower ()] = targ_cycles
284-
285- strategy = layer_cfg .get ('Strategy' )
286- if strategy is not None :
287- self .layer_name_strategy [layer_name .lower ()] = strategy
288-
289- conv_implementation = layer_cfg .get ('ConvImplementation' )
290- if conv_implementation is not None :
291- self .layer_name_conv_implementation [layer_name .lower ()] = conv_implementation
292-
293- compression = layer_cfg .get ('Compression' )
294- if compression is not None :
295- self .layer_name_compression [layer_name .lower ()] = bool (compression )
305+ self .parse_name_config (layer_name , layer_cfg )
296306
297307 def _validate_hls_config (self ):
298308 use_dataflow = False
@@ -617,6 +627,44 @@ def replace_node(self, old_node, new_node):
617627 self .graph = OrderedDict ((new_node .name , new_node ) if k == old_node .name else (k , v ) for k , v in self .graph .items ())
618628 self ._update_model_outputs ()
619629
630+ def split_node (self , old_node , new_node1 , new_node2 ):
631+ """Replace an existing node in the graph with two nodes in sequence.
632+
633+ Args:
634+ old_node (Layer): The node to replace
635+ new_node1 (Layer): The first new node in sequence
636+ new_node2 (Layer): The second new node in sequence
637+
638+ """
639+
640+ # fmt: off
641+ assert len (new_node1 .inputs ) == len (old_node .inputs ), \
642+ f'{ new_node1 .name } and { old_node .name } have different number of inputs'
643+ assert len (new_node2 .outputs ) == len (old_node .outputs ), \
644+ f'{ new_node2 .name } and { old_node .name } have different number of outputs'
645+ # fmt: on
646+
647+ repl = {old_name : new_name for old_name , new_name in zip (old_node .outputs , new_node2 .outputs )}
648+ repl .update ({old_name : new_name for old_name , new_name in zip (old_node .inputs , new_node1 .inputs )})
649+
650+ for node in self .graph .values ():
651+ for i , n in enumerate (node .inputs ):
652+ if n in repl :
653+ node .inputs [i ] = repl [n ]
654+ for i , n in enumerate (node .outputs ):
655+ if n in repl :
656+ node .outputs [i ] = repl [n ]
657+
658+ new_graph = OrderedDict ()
659+ for key , value in self .graph .items ():
660+ if key == old_node .name :
661+ new_graph [new_node1 .name ] = new_node1
662+ new_graph [new_node2 .name ] = new_node2
663+ else :
664+ new_graph [key ] = value
665+ self .graph = new_graph
666+ self ._update_model_outputs ()
667+
620668 def _update_model_outputs (self ):
621669 '''Update the model outputs
622670
0 commit comments