@@ -184,6 +184,17 @@ def convert(
184184 logger .debug (f"cast down (to { self .low_precision_type .str_full } ): { cast_down_tensors } " )
185185 logger .debug (f"cast up (to { self .high_precision_type .str_full } ): { cast_up_tensors } " )
186186
187+ # Since we have removed all casts, we can pre-compute the tensor_to_consumers and
188+ # tensor_to_producers maps since they will not change for the duration of the conversion.
189+ tensor_to_consumers = defaultdict (list )
190+ tensor_to_producers = defaultdict (list )
191+
192+ for node in self .model .graph .node :
193+ for input in node .input :
194+ tensor_to_consumers [input ].append (node )
195+ for output in node .output :
196+ tensor_to_producers [output ].append (node )
197+
187198 # Add cast nodes for "cast_up" tensors
188199 for tensor_name in cast_up_tensors :
189200 exclude_consumers = low_precision_nodes
@@ -194,7 +205,11 @@ def convert(
194205 set (low_precision_nodes ) - {fp32_input_to_low_precision_node [tensor_name ].name }
195206 )
196207 self ._add_cast (
197- tensor_name , self .high_precision_type , exclude_consumers = exclude_consumers
208+ tensor_name ,
209+ self .high_precision_type ,
210+ exclude_consumers = exclude_consumers ,
211+ tensor_to_consumers = tensor_to_consumers ,
212+ tensor_to_producers = tensor_to_producers ,
198213 )
199214
200215 # Add cast nodes for "cast_down" tensors
@@ -203,6 +218,8 @@ def convert(
203218 tensor_name ,
204219 self .low_precision_type ,
205220 exclude_consumers = high_precision_nodes ,
221+ tensor_to_consumers = tensor_to_consumers ,
222+ tensor_to_producers = tensor_to_producers ,
206223 )
207224
208225 # Convert initializers to correct precision according to the consumer nodes
@@ -803,14 +820,27 @@ def _remove_preexisting_casts(self) -> None:
803820 self .model .graph .node .remove (node )
804821
805822 def _add_cast (
806- self , tensor_name : str , cast_to : PrecisionTypes , exclude_consumers : list [str ] = []
823+ self ,
824+ tensor_name : str ,
825+ cast_to : PrecisionTypes ,
826+ exclude_consumers : list [str ] = [],
827+ tensor_to_consumers : dict [str , list [onnx .NodeProto ]] | None = None ,
828+ tensor_to_producers : dict [str , list [onnx .NodeProto ]] | None = None ,
807829 ) -> None :
808830 """Adds a cast operation on a tensor and reconnects its consumers.
809831
810832 Args:
811833 tensor_name: Name of the tensor to cast.
812834 cast_to: Target precision type to cast to.
813835 exclude_consumers: List of consumer nodes to exclude from reconnection.
836+ tensor_to_consumers: Optional pre-computed map of tensor names to their consumer nodes.
837+ If not provided, the map will be computed on the fly.
838+ tensor_to_producers: Optional pre-computed map of tensor names to their producer nodes.
839+ If not provided, the map will be computed on the fly.
840+
841+ NOTE: It is up to the user to ensure that the tensor_to_consumers and tensor_to_producers
842+ maps are up to date before calling this function. Consecutive casts in the graph will break
843+ this assumption and the maps must be recomputed.
814844 """
815845 # Empty tensors may have special handling in ONNX (e.g. for Resize scales) which can break when redundant casts
816846 # are injected. Since there's no data, it's safe to only update the metadata.
@@ -848,7 +878,10 @@ def _add_cast(
848878 name = f"{ tensor_name } _cast_to_{ cast_to .str_short } " ,
849879 )
850880
851- consumer_nodes = utils .get_consumer_nodes (self .model , tensor_name )
881+ if tensor_to_consumers is None :
882+ consumer_nodes = utils .get_consumer_nodes (self .model , tensor_name )
883+ else :
884+ consumer_nodes = tensor_to_consumers .get (tensor_name , [])
852885 consumer_nodes = [n for n in consumer_nodes if n .name not in exclude_consumers ]
853886 for node in consumer_nodes :
854887 for i , input_name in enumerate (node .input ):
@@ -868,7 +901,10 @@ def _add_cast(
868901 break
869902
870903 # Find producer node to insert cast after it
871- producer_nodes = utils .get_producer_nodes (self .model , tensor_name )
904+ if tensor_to_producers is None :
905+ producer_nodes = utils .get_producer_nodes (self .model , tensor_name )
906+ else :
907+ producer_nodes = tensor_to_producers .get (tensor_name , [])
872908 if producer_nodes :
873909 # Insert after the producer node
874910 # Find index by iterating since RepeatedCompositeContainer doesn't support index()
0 commit comments