Skip to content

Commit 9e827a9

Browse files
aboubezari authored and kevalmorabia97 committed
[Autocast] Optimize _add_cast runtime (#469)
Signed-off-by: Ali Boubezari <[email protected]>
1 parent 619ddbb commit 9e827a9

File tree

1 file changed

+40
-4
lines changed

1 file changed

+40
-4
lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 40 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -184,6 +184,17 @@ def convert(
184184
logger.debug(f"cast down (to {self.low_precision_type.str_full}): {cast_down_tensors}")
185185
logger.debug(f"cast up (to {self.high_precision_type.str_full}): {cast_up_tensors}")
186186

187+
# Since we have removed all casts, we can pre-compute the tensor_to_consumers and
188+
# tensor_to_producers maps since they will not change for the duration of the conversion.
189+
tensor_to_consumers = defaultdict(list)
190+
tensor_to_producers = defaultdict(list)
191+
192+
for node in self.model.graph.node:
193+
for input in node.input:
194+
tensor_to_consumers[input].append(node)
195+
for output in node.output:
196+
tensor_to_producers[output].append(node)
197+
187198
# Add cast nodes for "cast_up" tensors
188199
for tensor_name in cast_up_tensors:
189200
exclude_consumers = low_precision_nodes
@@ -194,7 +205,11 @@ def convert(
194205
set(low_precision_nodes) - {fp32_input_to_low_precision_node[tensor_name].name}
195206
)
196207
self._add_cast(
197-
tensor_name, self.high_precision_type, exclude_consumers=exclude_consumers
208+
tensor_name,
209+
self.high_precision_type,
210+
exclude_consumers=exclude_consumers,
211+
tensor_to_consumers=tensor_to_consumers,
212+
tensor_to_producers=tensor_to_producers,
198213
)
199214

200215
# Add cast nodes for "cast_down" tensors
@@ -203,6 +218,8 @@ def convert(
203218
tensor_name,
204219
self.low_precision_type,
205220
exclude_consumers=high_precision_nodes,
221+
tensor_to_consumers=tensor_to_consumers,
222+
tensor_to_producers=tensor_to_producers,
206223
)
207224

208225
# Convert initializers to correct precision according to the consumer nodes
@@ -803,14 +820,27 @@ def _remove_preexisting_casts(self) -> None:
803820
self.model.graph.node.remove(node)
804821

805822
def _add_cast(
806-
self, tensor_name: str, cast_to: PrecisionTypes, exclude_consumers: list[str] = []
823+
self,
824+
tensor_name: str,
825+
cast_to: PrecisionTypes,
826+
exclude_consumers: list[str] = [],
827+
tensor_to_consumers: dict[str, list[onnx.NodeProto]] | None = None,
828+
tensor_to_producers: dict[str, list[onnx.NodeProto]] | None = None,
807829
) -> None:
808830
"""Adds a cast operation on a tensor and reconnects its consumers.
809831
810832
Args:
811833
tensor_name: Name of the tensor to cast.
812834
cast_to: Target precision type to cast to.
813835
exclude_consumers: List of consumer nodes to exclude from reconnection.
836+
tensor_to_consumers: Optional pre-computed map of tensor names to their consumer nodes.
837+
If not provided, the map will be computed on the fly.
838+
tensor_to_producers: Optional pre-computed map of tensor names to their producer nodes.
839+
If not provided, the map will be computed on the fly.
840+
841+
NOTE: It is up to the user to ensure that the tensor_to_consumers and tensor_to_producers
842+
maps are up to date before calling this function. Consecutive casts in the graph will break
843+
this assumption and the maps must be recomputed.
814844
"""
815845
# Empty tensors may have special handling in ONNX (e.g. for Resize scales) which can break when redundant casts
816846
# are injected. Since there's no data, it's safe to only update the metadata.
@@ -848,7 +878,10 @@ def _add_cast(
848878
name=f"{tensor_name}_cast_to_{cast_to.str_short}",
849879
)
850880

851-
consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
881+
if tensor_to_consumers is None:
882+
consumer_nodes = utils.get_consumer_nodes(self.model, tensor_name)
883+
else:
884+
consumer_nodes = tensor_to_consumers.get(tensor_name, [])
852885
consumer_nodes = [n for n in consumer_nodes if n.name not in exclude_consumers]
853886
for node in consumer_nodes:
854887
for i, input_name in enumerate(node.input):
@@ -868,7 +901,10 @@ def _add_cast(
868901
break
869902

870903
# Find producer node to insert cast after it
871-
producer_nodes = utils.get_producer_nodes(self.model, tensor_name)
904+
if tensor_to_producers is None:
905+
producer_nodes = utils.get_producer_nodes(self.model, tensor_name)
906+
else:
907+
producer_nodes = tensor_to_producers.get(tensor_name, [])
872908
if producer_nodes:
873909
# Insert after the producer node
874910
# Find index by iterating since RepeatedCompositeContainer doesn't support index()

0 commit comments

Comments
 (0)