
Commit d639e74

aboubezari authored and kevalmorabia97 committed
[Autocast] Optimize _convert_initializers runtime (#459)
Signed-off-by: Ali Boubezari <[email protected]>
1 parent fe213fd commit d639e74

1 file changed: +250 -118 lines changed

modelopt/onnx/autocast/precisionconverter.py

Lines changed: 250 additions & 118 deletions
@@ -21,8 +21,9 @@
 through type checking and cleanup of redundant operations.
 """
 
-from collections import namedtuple
+from collections import defaultdict, namedtuple
 from copy import deepcopy
+from dataclasses import dataclass, field
 
 import ml_dtypes
 import numpy as np
@@ -39,6 +40,23 @@
 
 PrecisionTypes = namedtuple("PrecisionTypes", ["onnx_type", "numpy_type", "str_short", "str_full"])
 
+
+@dataclass
+class InputIndexTracker:
+    """A class that tracks the index of an input to a node."""
+
+    node: onnx.NodeProto
+    node_index: int
+
+
+@dataclass
+class InitializerConsumerTracker:
+    """A class that tracks the nodes that consume an initializer."""
+
+    low_precision_nodes: list[InputIndexTracker] = field(default_factory=list)
+    high_precision_nodes: list[InputIndexTracker] = field(default_factory=list)
+
+
 PRECISION_MAP = {
     "fp32": PrecisionTypes(TensorProto.FLOAT, np.float32, "fp32", "float32"),
     "fp16": PrecisionTypes(TensorProto.FLOAT16, np.float16, "fp16", "float16"),
@@ -472,133 +490,247 @@ def _get_tensors_to_cast(
     def _convert_initializers(
         self, low_precision_nodes: list[str], high_precision_nodes: list[str]
     ) -> onnx.ModelProto:
-        def convert_initializer(
-            init: onnx.TensorProto,
-            node: onnx.NodeProto,
-            from_type: PrecisionTypes,
-            to_type: PrecisionTypes,
-        ):
-            if init.data_type != from_type.onnx_type:
+        """Convert model initializers to appropriate precision based on their consumer nodes.
+
+        This method analyzes how each initializer is used by different precision nodes and converts
+        or duplicates initializers as needed to ensure type compatibility:
+
+        1. Maps each initializer to the high/low precision nodes that consume it
+        2. For each initializer, applies one of these strategies:
+           - If only used by low precision nodes: convert to low precision
+           - If only used by high precision nodes: convert to high precision
+           - If used by both precision types: duplicate the initializer, creating separate
+             copies for each precision type and updating node references accordingly
+        3. Skips conversion for non-float initializers or those already at correct precision
+
+        The method handles special cases like bfloat16 conversion and provides warnings when
+        values are clamped or replaced due to precision limits.
+
+        Args:
+            low_precision_nodes: List of node names that should use low precision initializers.
+            high_precision_nodes: List of node names that should use high precision initializers.
+        """
+        # 1. Compute a mapping from initializers to high precision nodes & low precision nodes that use them.
+        low_precision_nodes_set: set[str] = set(low_precision_nodes)
+        high_precision_nodes_set: set[str] = set(high_precision_nodes)
+        initializer_to_nodes: dict[str, InitializerConsumerTracker] = defaultdict(
+            lambda: InitializerConsumerTracker()
+        )
+        for node in self.model.graph.node:
+            # Compute the mapping from initializers to low precision nodes that use them.
+            if node.name in low_precision_nodes_set:
+                for idx, input_name in enumerate(node.input):
+                    if input_name in self.initializer_map:
+                        if self._should_skip_low_precision_input_conversion(node, input_name):
+                            # Handle low precision nodes that require certain high precision inputs.
+                            initializer_to_nodes[input_name].high_precision_nodes.append(
+                                InputIndexTracker(node=node, node_index=idx)
+                            )
+                        else:
+                            initializer_to_nodes[input_name].low_precision_nodes.append(
+                                InputIndexTracker(node=node, node_index=idx)
+                            )
+            # Compute the mapping from initializers to high precision nodes that use them.
+            elif node.name in high_precision_nodes_set:
+                for idx, input_name in enumerate(node.input):
+                    if input_name in self.initializer_map:
+                        initializer_to_nodes[input_name].high_precision_nodes.append(
+                            InputIndexTracker(node=node, node_index=idx)
+                        )
+
+        onnx_float_types = set(ONNX_TYPES)
+        # 2. Convert initializers to appropriate precision based on their consumer nodes.
+        for init_name, tracker in initializer_to_nodes.items():
+            # Get the initializer.
+            init = self.initializer_map[init_name]
+            # If not used, just skip.
+            if len(tracker.low_precision_nodes) == 0 and len(tracker.high_precision_nodes) == 0:
+                logger.debug(f"Initializer {init_name} is not used by any nodes, skipping")
+                continue
+            # If the initializer is not a float, then just skip.
+            if init.data_type not in onnx_float_types:
+                logger.debug(f"Initializer {init_name} is not a float, skipping")
+                continue
+            # If the initializer is only used by high precision nodes and is high precision, then just skip.
+            if (
+                len(tracker.low_precision_nodes) == 0
+                and init.data_type == self.high_precision_type.onnx_type
+            ):
                 logger.debug(
-                    f"Initializer {init.name} has data type {init.data_type}, and size {len(init.raw_data)},"
-                    "skipping conversion"
+                    f"Initializer {init_name} is already high precision and only used "
+                    "by high precision nodes, skipping"
                 )
-                return False
+                continue
+            # If the initializer is only used by low precision nodes and is low precision, then just skip.
+            if (
+                len(tracker.high_precision_nodes) == 0
+                and init.data_type == self.low_precision_type.onnx_type
+            ):
+                logger.debug(
+                    f"Initializer {init_name} is already low precision and only used "
+                    "by low precision nodes, skipping"
+                )
+                continue
+
+            # If the initializer is used by only one precision type, then convert it to the other precision type.
+            if len(tracker.high_precision_nodes) == 0 or len(tracker.low_precision_nodes) == 0:
+                if len(tracker.low_precision_nodes) > 0:
+                    logger.debug(
+                        f"Convert initializer {init_name} to "
+                        f"{self.low_precision_type.str_short}, only used by low precision nodes"
+                    )
+                    from_type = self.high_precision_type
+                    to_type = self.low_precision_type
+                elif len(tracker.high_precision_nodes) > 0:
+                    logger.debug(
+                        f"Convert initializer {init_name} to "
+                        f"{self.high_precision_type.str_short}, "
+                        "only used by high precision nodes"
+                    )
+                    from_type = self.low_precision_type
+                    to_type = self.high_precision_type
+                else:
+                    raise ValueError(
+                        f"Unexpected: initializer {init_name} is not used by any "
+                        "nodes and is not a float"
+                    )
+
+                new_init = self._cast_initializer(
+                    init=init,
+                    from_type=from_type,
+                    to_type=to_type,
+                    low_precision_nodes=tracker.low_precision_nodes,
+                    high_precision_nodes=tracker.high_precision_nodes,
+                )
+                if new_init is not None:
+                    self.model.graph.initializer.remove(init)
+                    self.model.graph.initializer.extend([new_init])
+                continue
 
-            # If initializer is too large, skip conversion, perform cast instead
-            if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
+            # This initializer is used by both high precision and low precision nodes, so we need
+            # to duplicate it for low precision nodes.
+            assert len(tracker.low_precision_nodes) > 0 and len(tracker.high_precision_nodes) > 0
+            if init.data_type == self.low_precision_type.onnx_type:
                 logger.debug(
-                    f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
-                    "runtime instead"
+                    f"Convert initializer {init_name} to "
+                    f"{self.high_precision_type.str_short}, "
+                    "used by both high precision and low precision nodes"
                 )
-                exclude_consumers = (
-                    low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
+                from_type = self.low_precision_type
+                to_type = self.high_precision_type
+                nodes_to_update = tracker.high_precision_nodes
+            elif init.data_type == self.high_precision_type.onnx_type:
+                logger.debug(
+                    f"Convert initializer {init_name} to "
+                    f"{self.low_precision_type.str_short}, "
+                    "used by both high precision and low precision nodes"
                 )
-                self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers)
-                return True
-            try:
-                np_array = numpy_helper.to_array(init)
-                assert from_type.str_short in PRECISION_MAP
-                assert to_type.str_short in PRECISION_MAP
-                assert from_type.str_short != to_type.str_short
-
-                if np_array.dtype == from_type.numpy_type:
-                    consumers = [n.name for n in utils.get_consumer_nodes(self.model, init.name)]
-                    should_duplicate = len(consumers) > 1 and set(consumers) & set(
-                        high_precision_nodes
-                    )
+                from_type = self.high_precision_type
+                to_type = self.low_precision_type
+                nodes_to_update = tracker.low_precision_nodes
+            else:
+                raise ValueError(f"Unexpected: initializer {init_name} is not a float")
+
+            new_init = self._cast_initializer(
+                init=init,
+                from_type=from_type,
+                to_type=to_type,
+                low_precision_nodes=tracker.low_precision_nodes,
+                high_precision_nodes=tracker.high_precision_nodes,
+            )
+            if new_init is not None:
+                new_init_name = f"{init_name}_{to_type.str_short}"
+                new_init.name = new_init_name
+                for node in nodes_to_update:
+                    node.node.input[node.node_index] = new_init_name
+                self.model.graph.initializer.extend([new_init])
+
+    def _cast_initializer(
+        self,
+        init: onnx.TensorProto,
+        from_type: PrecisionTypes,
+        to_type: PrecisionTypes,
+        low_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
+        high_precision_nodes: list[InputIndexTracker] | list[onnx.NodeProto],
+    ) -> onnx.TensorProto | None:
+        """Cast an initializer to a new precision based on its consumer nodes.
 
-                    if should_duplicate:
-                        # Create a new low precision copy with a different name
-                        new_name = f"{init.name}_{to_type.str_short}"
-                        logger.debug(
-                            f"Initializer {init.name} is shared, creating {to_type.str_short} copy as {new_name} due "
-                            f"to node {node.name}"
-                        )
+        This method converts an initializer to a new precision while handling special cases like bfloat16 conversion
+        and providing warnings when values are clamped or replaced due to precision limits.
 
-                        # Update the node to use the new initializer
-                        for i, input_name in enumerate(node.input):
-                            if input_name == init.name:
-                                node.input[i] = new_name
-                                break
+        Args:
+            init: The initializer to cast.
+            from_type: The original precision of the initializer.
+            to_type: The new precision to cast the initializer to.
 
-                        if init.name in initializer_converted_dup:
-                            return False
-                        initializer_converted_dup.append(init.name)
-                    else:
-                        if init.name in initializer_converted:
-                            return False
-                        new_name = init.name
-                        logger.debug(
-                            f"Converting initializer {new_name} to {to_type.str_short} due to node {node.name}"
-                        )
-                        initializer_converted.append(init.name)
-                    self.model.graph.initializer.remove(init)
-
-                    # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
-                    if self._is_bf16(to_type) and self._is_fp32(from_type):
-                        new_init = onnx.TensorProto()
-                        new_init.dims.extend(np_array.shape)
-                        new_init.name = new_name
-                        new_init.data_type = onnx.TensorProto.BFLOAT16
-                        bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
-                        new_init.raw_data = bf16_bytes.tobytes()
-                    else:
-                        assert to_type.numpy_type is not None
-                        data_max, data_lowest = (
-                            np.finfo(to_type.numpy_type).max,
-                            np.finfo(to_type.numpy_type).smallest_subnormal,
-                        )
-                        if np.any(np.abs(np_array) > data_max):
-                            logger.warning(
-                                f"Initializer {init.name} used by node {node.name} contains values larger than "
-                                f"largest {to_type.str_short} value, values will be clamped to {data_max}."
-                            )
-                            np_array = np.clip(np_array, -1 * data_max, data_max)
-                        if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
-                            logger.warning(
-                                f"Initializer {init.name} used by node {node.name} contains values smaller than "
-                                f"smallest {to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
-                            )
-                            np_array = np.where(
-                                (np_array != 0.0) & (np.abs(np_array) < data_lowest),
-                                data_lowest,
-                                np_array,
-                            )
-                        new_array = np_array.astype(to_type.numpy_type)
-                        new_init = numpy_helper.from_array(new_array, new_name)
-                    self.model.graph.initializer.extend([new_init])
-                    return True
-                return False
-            except Exception as e:
-                logger.error(f"Error converting initializer {init.name}: {e}")
-                return False
+        Returns:
+            onnx.TensorProto: The cast initializer, or None if a runtime Cast was added instead.
+        """
 
-        initializer_converted = []
-        initializer_converted_dup = []
-        modified = False
-        for node in self.model.graph.node:
-            if node.name in low_precision_nodes:
-                for init in self.node_to_init_map[node.name]:
-                    if self._should_skip_low_precision_input_conversion(node, init.name):
-                        continue
-                    modified |= convert_initializer(
-                        init,
-                        node,
-                        from_type=self.high_precision_type,
-                        to_type=self.low_precision_type,
-                    )
-                    if modified:
-                        _, _, self.node_to_init_map = utils.setup_mappings(self.model)
-
-            if node.name in high_precision_nodes:
-                for init in self.node_to_init_map[node.name]:
-                    convert_initializer(
-                        init,
-                        node,
-                        from_type=self.low_precision_type,
-                        to_type=self.high_precision_type,
-                    )
+        def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str:
+            """Get the name of a node or input index tracker."""
+            if isinstance(node, onnx.NodeProto):
+                return node.name
+            elif isinstance(node, InputIndexTracker):
+                return node.node.name
+            else:
+                raise ValueError(f"Unexpected: {type(node)}")
+
+        # Ensure the initializer is of the expected type
+        assert init.data_type == from_type.onnx_type, (
+            f"Initializer {init.name} is not of type {from_type.str_short}"
+        )
+
+        if init.raw_data and len(init.raw_data) > self.init_conversion_max_bytes:
+            # The initializer is too large, so we need to convert it at runtime.
+            logger.debug(
+                f"Initializer {init.name} is too large, skipping initializer conversion, cast in "
+                "runtime instead"
+            )
+            exclude_consumers = (
+                low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes
+            )
+            exclude_consumers_names: list[str] = []
+
+            exclude_consumers_names = [_get_name(node) for node in exclude_consumers]
+            self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names)
+            return None
+
+        np_array = numpy_helper.to_array(init)
+        # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead
+        if self._is_bf16(to_type) and self._is_fp32(from_type):
+            new_init = onnx.TensorProto()
+            new_init.dims.extend(np_array.shape)
+            new_init.name = init.name
+            new_init.data_type = onnx.TensorProto.BFLOAT16
+            bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
+            new_init.raw_data = bf16_bytes.tobytes()
+        else:
+            assert to_type.numpy_type is not None
+            data_max, data_lowest = (
+                np.finfo(to_type.numpy_type).max,
+                np.finfo(to_type.numpy_type).smallest_subnormal,
+            )
+            if np.any(np.abs(np_array) > data_max):
+                logger.warning(
+                    f"Initializer {init.name} contains values larger than largest "
+                    f"{to_type.str_short} value, values will be clamped to {data_max}."
+                )
+                np_array = np.clip(np_array, -1 * data_max, data_max)
+            if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)):
+                logger.warning(
+                    f"Initializer {init.name} contains values smaller than smallest "
+                    f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}."
+                )
+                np_array = np.where(
+                    (np_array != 0.0) & (np.abs(np_array) < data_lowest),
+                    data_lowest,
+                    np_array,
+                )
+            new_array = np_array.astype(to_type.numpy_type)
+            new_init = numpy_helper.from_array(new_array, init.name)
+
+        return new_init
 
     def _replace_tensor_name(
         self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str
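The bfloat16 branch carried over into _cast_initializer builds the tensor by hand because numpy has no native bfloat16: ml_dtypes does the rounding and the resulting 16-bit pattern is stored as raw_data. A standalone sketch of that round trip, with made-up array values and tensor name:

import ml_dtypes
import numpy as np
import onnx

np_array = np.array([[1.5, -3.0e38], [2.0, 0.25]], dtype=np.float32)

new_init = onnx.TensorProto()
new_init.dims.extend(np_array.shape)
new_init.name = "W_bf16"
new_init.data_type = onnx.TensorProto.BFLOAT16
# Round to bfloat16, then reinterpret the 16-bit pattern for serialization.
bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16)
new_init.raw_data = bf16_bytes.tobytes()

# Decode the stored bits through ml_dtypes to confirm the round trip.
decoded = np.frombuffer(new_init.raw_data, dtype=ml_dtypes.bfloat16).reshape(np_array.shape)
print(decoded)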

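The overflow and underflow guards in the non-bfloat16 branch warn and then rewrite values before narrowing, so out-of-range weights saturate instead of becoming inf and tiny nonzero weights keep a representable magnitude instead of flushing to zero. The same arithmetic in a standalone numpy sketch for an fp32 to fp16 conversion, with made-up sample values:

import numpy as np

np_array = np.array([1.0e5, 6.0e4, 1.0e-9, 0.0], dtype=np.float32)

data_max = np.finfo(np.float16).max                    # 65504.0
data_lowest = np.finfo(np.float16).smallest_subnormal  # ~5.96e-08

# Clamp magnitudes that would overflow to inf in fp16.
np_array = np.clip(np_array, -1 * data_max, data_max)
# Replace nonzero values that would underflow to zero in fp16.
tiny = (np_array != 0.0) & (np.abs(np_array) < data_lowest)
np_array = np.where(tiny, data_lowest, np_array)

print(np_array.astype(np.float16))  # [65504.0, 60000.0, ~5.96e-08, 0.0]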