Skip to content

Commit 5c9819e

Browse files
committed
[5545101]: AutoCast: Add options to force include node/op in low precision
Add options nodes_to_include, op_types_to_include that force-include nodes in the conversion, overriding NodeClassifier exclusion logic Signed-off-by: Gal Hubara Agam <[email protected]>
1 parent 615f3c0 commit 5c9819e

File tree

5 files changed

+131
-5
lines changed

5 files changed

+131
-5
lines changed

docs/source/guides/8_autocast.rst

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ AutoCast can also be used programmatically through its Python API:
3131
low_precision_type="fp16", # or "bf16"
3232
nodes_to_exclude=None, # optional list of node name patterns to keep in FP32
3333
op_types_to_exclude=None, # optional list of op types to keep in FP32
34+
nodes_to_include=None, # optional list of node name patterns to force-include in low precision
35+
op_types_to_include=None, # optional list of op types to force-include in low precision
3436
data_max=512, # threshold for node outputs
3537
init_max=65504, # threshold for initializers
3638
keep_io_types=False, # whether to preserve input/output types
@@ -60,6 +62,19 @@ AutoCast follows these steps to convert a model:
6062
- Analyzes each node in the graph
6163
- Determines which nodes should remain in FP32 based on input and output tensors magnitudes, operation types and node name patterns
6264
- If a calibration dataset is provided, it will be used to generate intermediate tensor magnitudes for more accurate node classification, otherwise random data will be used.
65+
- Use ``nodes_to_include`` and ``op_types_to_include`` to force-include nodes in low precision, even if they would otherwise be excluded.
66+
67+
- Default classification rules. Nodes that meet any of these rules will be kept in high precision:
68+
- Node I/O magnitudes are higher than ``data_max`` (default: 512). Due to precision limitations, compute of high magnitude tensors in low precision might not be accurate. The unit in the last place (ULP) for 512 is 0.5, for 1024 it is 1.0, etc.
69+
- Initializers magnitudes are higher than ``init_max`` (default: 65504). Initializers are often used for non-compute intensive operations and are more likely to be controlled by the user. However, values above ``init_max`` will cause overflow, therefore they are kept in high precision.
70+
71+
Additional classification rules (disabled by default):
72+
- ``max_depth_of_reduction``: Require nodes with a high depth of reduction (e.g., large matrix multiplications, convolutions with large kernels) to be kept in high precision.
73+
- ``nodes_to_exclude``: List of regex patterns for node names to keep in high precision.
74+
- ``op_types_to_exclude``: List of operation types to keep in high precision.
75+
- ``nodes_to_include``: List of regex patterns for node names to force-include in low precision.
76+
- ``op_types_to_include``: List of operation types to force-include in low precision.
77+
- ``custom_rule``: Optional custom rule for node classification (inherits from NodeRuleBase).
6378

6479
#. **Precision Conversion**:
6580

modelopt/onnx/autocast/__main__.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,24 @@ def get_parser() -> argparse.ArgumentParser:
8585
default=[],
8686
help="List of op types that should remain in FP32",
8787
)
88+
parser.add_argument(
89+
"--nodes_to_include",
90+
"-ni",
91+
type=str,
92+
nargs="*",
93+
default=[],
94+
help="List of regex patterns to match node names that should be force-included in low precision, even if they "
95+
"would otherwise be excluded",
96+
)
97+
parser.add_argument(
98+
"--op_types_to_include",
99+
"-opi",
100+
type=str,
101+
nargs="*",
102+
default=[],
103+
help="List of op types that should be force-included in low precision, even if they would otherwise be "
104+
"excluded",
105+
)
88106
parser.add_argument(
89107
"--data_max",
90108
type=float,
@@ -112,7 +130,8 @@ def get_parser() -> argparse.ArgumentParser:
112130
parser.add_argument(
113131
"--keep_io_types",
114132
action="store_true",
115-
help="Keep the input and output types of the model, otherwise they will be converted to FP16",
133+
help="Keep the input and output types of the model; otherwise they will be converted to reduced precision "
134+
"(FP16/BF16)",
116135
)
117136
parser.add_argument(
118137
"--log_level",
@@ -164,6 +183,8 @@ def main(argv=None):
164183
low_precision_type=args.low_precision_type,
165184
nodes_to_exclude=args.nodes_to_exclude,
166185
op_types_to_exclude=args.op_types_to_exclude,
186+
nodes_to_include=args.nodes_to_include,
187+
op_types_to_include=args.op_types_to_include,
167188
data_max=args.data_max,
168189
init_max=args.init_max,
169190
keep_io_types=args.keep_io_types,

modelopt/onnx/autocast/convert.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
"""AutoCast module for converting ONNX models to mixed precision.
1818
1919
AutoCast is a tool for converting FP32 ONNX models to mixed precision FP32-FP16 or FP32-BF16 models.
20-
While casting FP32 to FP6/BF16, some nodes might be more sensitive to effecting accuracy.
20+
While casting FP32 to FP16/BF16, some nodes are more sensitive to reduced precision than others, which can affect accuracy.
2121
AutoCast intelligently selects nodes to keep in FP32 precision to maintain model accuracy while benefiting from
2222
reduced precision on the rest of the nodes. AutoCast automatically injects cast operations around the selected
2323
nodes.
@@ -48,6 +48,8 @@ def convert_to_mixed_precision(
4848
low_precision_type: str = "fp16",
4949
nodes_to_exclude: list[str] | None = None,
5050
op_types_to_exclude: list[str] | None = None,
51+
nodes_to_include: list[str] | None = None,
52+
op_types_to_include: list[str] | None = None,
5153
data_max: float = DEFAULT_DATA_MAX,
5254
init_max: float = DEFAULT_INIT_MAX,
5355
keep_io_types: bool = False,
@@ -65,6 +67,8 @@ def convert_to_mixed_precision(
6567
low_precision_type: Target precision to reduce to ('fp16' or 'bf16').
6668
nodes_to_exclude: List of regex patterns to match node names that should remain in FP32.
6769
op_types_to_exclude: List of operation types that should remain in FP32.
70+
nodes_to_include: List of regex patterns to match node names that should be included in low precision.
71+
op_types_to_include: List of operation types that should be included in low precision.
6872
data_max: Maximum absolute value for node input and output values.
6973
init_max: Maximum absolute value for initializers.
7074
keep_io_types: Whether to preserve input/output types.
@@ -108,6 +112,8 @@ def convert_to_mixed_precision(
108112
initializer_map,
109113
nodes_to_exclude=nodes_to_exclude or [],
110114
op_types_to_exclude=op_types_to_exclude or [],
115+
nodes_to_include=nodes_to_include or [],
116+
op_types_to_include=op_types_to_include or [],
111117
data_max=data_max,
112118
init_max=init_max,
113119
custom_rule=custom_rule,

modelopt/onnx/autocast/nodeclassifier.py

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,28 @@ def _check_inner(self, node):
9393
return node.op_type in self.op_types_to_exclude
9494

9595

96+
class IncludeNodeNameRegexRule(DisabledNodeNameRegexRule):
97+
"""Rule for force-including nodes with matching names in low precision.
98+
99+
Inherits matching behavior from DisabledNodeNameRegexRule but overrides logging.
100+
"""
101+
102+
def _log_skipped(self, node, **kwargs):
103+
# For include rules, a positive match means we will force-include the node in low precision
104+
logger.info(f"Force-including node {node.name}: {self.__class__.__name__}")
105+
106+
107+
class IncludeOpTypes(DisabledOpTypes):
108+
"""Rule for force-including specific operation types in low precision.
109+
110+
Inherits matching behavior from DisabledOpTypes but overrides logging.
111+
"""
112+
113+
def _log_skipped(self, node, **kwargs):
114+
# For include rules, a positive match means we will force-include the node in low precision
115+
logger.info(f"Force-including node {node.name}: {self.__class__.__name__}")
116+
117+
96118
class InitializerRangeRule(NodeRuleBase):
97119
"""Rule for keeping nodes with out-of-range initializers in high precision."""
98120

@@ -332,6 +354,8 @@ def __init__(
332354
initializer_map: dict[str, onnx.TensorProto] | None = None,
333355
nodes_to_exclude: list[str] | None = None,
334356
op_types_to_exclude: list[str] | None = None,
357+
nodes_to_include: list[str] | None = None,
358+
op_types_to_include: list[str] | None = None,
335359
custom_rule: NodeRuleBase | None = None,
336360
data_max: float | None = 1000.0,
337361
init_max: float | None = np.finfo(np.float16).max,
@@ -345,6 +369,8 @@ def __init__(
345369
initializer_map: Mapping from initializer names to their tensors.
346370
nodes_to_exclude: List of regex patterns for node names to keep in high precision.
347371
op_types_to_exclude: List of operation types to keep in high precision.
372+
nodes_to_include: List of regex patterns for node names to force-include in low precision.
373+
op_types_to_include: List of operation types to force-include in low precision.
348374
custom_rule: Optional custom classification rule.
349375
data_max: Maximum absolute value allowed for node I/O.
350376
init_max: Maximum absolute value allowed for initializers.
@@ -355,12 +381,14 @@ def __init__(
355381
self.initializer_map = initializer_map
356382
self.nodes_to_exclude = nodes_to_exclude
357383
self.op_types_to_exclude = op_types_to_exclude
384+
self.nodes_to_include = nodes_to_include
385+
self.op_types_to_include = op_types_to_include
358386
self.custom_rule = custom_rule
359387
self.data_max = data_max
360388
self.init_max = init_max
361389
self.max_depth_of_reduction = max_depth_of_reduction
362390

363-
def _gen_block_node_rules(self, reference_data):
391+
def _gen_exclude_node_rules(self, reference_data):
364392
"""Generate list of rules for blocking nodes from precision conversion.
365393
366394
Args:
@@ -393,6 +421,20 @@ def _gen_block_node_rules(self, reference_data):
393421
block_node_rules.append(self.custom_rule)
394422
return block_node_rules
395423

424+
def _gen_include_node_rules(self):
425+
"""Generate list of rules for force-including nodes in low precision.
426+
427+
Returns:
428+
list[NodeRuleBase]: List of rules to apply.
429+
"""
430+
include_node_rules: list[NodeRuleBase] = []
431+
if self.nodes_to_include:
432+
include_node_rules.append(IncludeNodeNameRegexRule(self.nodes_to_include))
433+
if self.op_types_to_include:
434+
include_node_rules.append(IncludeOpTypes(self.op_types_to_include))
435+
436+
return include_node_rules
437+
396438
def run(self, ref_outputs_dict=None):
397439
"""Run node classification.
398440
@@ -402,12 +444,15 @@ def run(self, ref_outputs_dict=None):
402444
Returns:
403445
tuple: Lists of node names (low_precision_nodes, high_precision_nodes).
404446
"""
405-
block_node_rules = self._gen_block_node_rules(ref_outputs_dict)
447+
exclude_node_rules = self._gen_exclude_node_rules(ref_outputs_dict)
448+
include_node_rules = self._gen_include_node_rules()
406449
low_precision_nodes = []
407450
high_precision_nodes = []
408451
for node in self.model.graph.node:
409452
# If any condition is met - node will be executed in high precision
410-
if any(rule.check(node) for rule in block_node_rules):
453+
if any(rule.check(node) for rule in exclude_node_rules) and not any(
454+
rule.check(node) for rule in include_node_rules
455+
):
411456
high_precision_nodes.append(node.name)
412457
else:
413458
low_precision_nodes.append(node.name)

tests/unit/onnx/autocast/test_nodeclassifier.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -330,3 +330,42 @@ def test_node_classifier_op_types_to_exclude(test_model):
330330
assert len(fp16_nodes) + len(fp32_nodes) == 2
331331
# Test that no node is in both fp16 and fp32 lists
332332
assert not set(fp16_nodes).intersection(set(fp32_nodes))
333+
334+
335+
# Test that nodes_to_include and op_types_to_include force nodes into low precision,
336+
# even if they would otherwise be excluded by other rules.
337+
def test_node_classifier_force_include(test_model):
338+
node_to_init_map = {
339+
"add_node": [
340+
numpy_helper.from_array(np.array([[10.0, 20.0], [30.0, 40.0]], dtype=np.float32))
341+
],
342+
"mul_node": [numpy_helper.from_array(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32))],
343+
}
344+
345+
# Set init_max low so both nodes would normally be excluded (kept in FP32)
346+
# Force add_node to low precision, despite exceeding init_max
347+
classifier = NodeClassifier(
348+
model=test_model,
349+
node_to_init_map=node_to_init_map,
350+
init_max=1.0,
351+
nodes_to_include=["add_node"],
352+
)
353+
fp16_nodes, fp32_nodes = classifier.run()
354+
# add_node should be in fp16_nodes due to nodes_to_include, despite exceeding init_max
355+
assert "add_node" in fp16_nodes
356+
assert "mul_node" in fp32_nodes
357+
assert "add_node" not in fp32_nodes
358+
assert len(fp16_nodes) + len(fp32_nodes) == 2
359+
assert not set(fp16_nodes).intersection(set(fp32_nodes))
360+
361+
# Test that the include node rule overrides the exclude op-type rule
362+
classifier2 = NodeClassifier(
363+
model=test_model,
364+
node_to_init_map=node_to_init_map,
365+
op_types_to_exclude=["Add"],
366+
nodes_to_include=["add_node"], # Should override op_types_to_exclude
367+
)
368+
fp16_nodes, fp32_nodes = classifier2.run()
369+
assert "add_node" in fp16_nodes
370+
assert "add_node" not in fp32_nodes
371+
assert not set(fp16_nodes).intersection(set(fp32_nodes))

0 commit comments

Comments
 (0)