diff --git a/modelopt/onnx/autocast/precisionconverter.py b/modelopt/onnx/autocast/precisionconverter.py index 4f380e1d2..5abf7fdc5 100644 --- a/modelopt/onnx/autocast/precisionconverter.py +++ b/modelopt/onnx/autocast/precisionconverter.py @@ -67,9 +67,6 @@ class InitializerConsumerTracker: OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION = ["Upsample", "NonMaxSuppression", "Celu"] -# Temporarily block these ops in low precision, as they are not supported yet -OP_TYPES_NOT_SUPPORTED_IN_LOW_PRECISION.extend(["Scan", "If", "Loop"]) - # Mapping of op types to indices of inputs that should not be converted to low precision. SKIP_LOW_PRECISION_MAPPING_FP16 = {"Resize": {2}} SKIP_LOW_PRECISION_MAPPING_BF16 = {"Resize": {1, 2}} @@ -240,8 +237,8 @@ def convert( tensor_to_producers=tensor_to_producers, ) - # Convert initializers to correct precision according to the consumer nodes - self._convert_initializers( + # Convert initializers to correct precision according to the consumer nodes (main graph + subgraphs) + self._convert_initializers_recursive( low_precision_nodes=low_precision_nodes, high_precision_nodes=high_precision_nodes ) @@ -250,17 +247,8 @@ def convert( # Populate type information with inferred types self.model = self._propagate_types_shapes_custom_ops(self.model) else: - # Clear type/shape information for intermediates and outputs - for vi in self.model.graph.value_info: - vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(vi.type.tensor_type.shape.dim): - if d.dim_value: - vi.type.tensor_type.shape.dim[idx].dim_param = "unk" - for out in self.model.graph.output: - out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED - for idx, d in enumerate(out.type.tensor_type.shape.dim): - if d.dim_value: - out.type.tensor_type.shape.dim[idx].dim_param = "unk" + # Clear type/shape information for intermediates and outputs (including subgraphs) + self._clear_types_and_shapes_recursive(self.model.graph) # Populate type information with inferred types self.model = onnx_utils.infer_shapes(self.model, strict_mode=True, check_type=False) self._ensure_types_are_defined() @@ -285,6 +273,47 @@ def _ensure_types_are_defined(self): if vi.type.tensor_type.elem_type == onnx.TensorProto.UNDEFINED: vi.type.tensor_type.elem_type = self.low_precision_type.onnx_type + def _clear_types_and_shapes_recursive( + self, graph: onnx.GraphProto, is_subgraph: bool = False + ) -> None: + """Recursively clear type/shape information for a graph and all its subgraphs. + + This is necessary for control flow operators (Scan, If, Loop) which have subgraphs. + + Args: + graph: The ONNX graph to clear types and shapes for. + is_subgraph: Whether this is a subgraph (True) or the main graph (False). + """ + + def _clear_callback(g: onnx.GraphProto, parent: onnx.NodeProto, is_sub: bool) -> None: + logger.debug( + f"Clearing types/shapes in {'subgraph' if is_sub else 'main graph'}: {g.name}" + ) + + # Clear type/shape information for inputs (only for subgraphs, not main graph inputs) + if is_sub: + for inp in g.input: + if inp.type.HasField("tensor_type"): + inp.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED + for idx, d in enumerate(inp.type.tensor_type.shape.dim): + if d.dim_value: + inp.type.tensor_type.shape.dim[idx].dim_param = "unk" + + # Clear type/shape information for intermediates and outputs + for vi in g.value_info: + vi.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED + for idx, d in enumerate(vi.type.tensor_type.shape.dim): + if d.dim_value: + vi.type.tensor_type.shape.dim[idx].dim_param = "unk" + + for out in g.output: + out.type.tensor_type.elem_type = onnx.TensorProto.UNDEFINED + for idx, d in enumerate(out.type.tensor_type.shape.dim): + if d.dim_value: + out.type.tensor_type.shape.dim[idx].dim_param = "unk" + + utils.walk_subgraphs_recursive(graph, _clear_callback, is_subgraph=is_subgraph) + def _propagate_types_shapes_custom_ops(self, model): """Propagate types and shapes after insertion of 'Cast' nodes or other graph modifications.""" logger.info("Propagating tensor shapes and types in model with custom ops.") @@ -682,6 +711,118 @@ def _convert_initializers( node.node.input[node.node_index] = new_init_name self.model.graph.initializer.extend([new_init]) + def _convert_initializers_recursive( + self, low_precision_nodes: list[str], high_precision_nodes: list[str] + ) -> None: + """Convert initializers in main graph and all subgraphs to appropriate precision. + + For the main graph, uses sophisticated consumer tracking to determine precision. + For subgraphs, inherits precision from the parent control flow node and converts + all initializers to that precision (no runtime casts). + + Args: + low_precision_nodes: List of node names in main graph that are low precision. + high_precision_nodes: List of node names in main graph that are high precision. + """ + # Convert main graph initializers with full consumer tracking + self._convert_initializers(low_precision_nodes, high_precision_nodes) + + # Convert subgraph initializers - walk all subgraphs and convert based on parent node precision + low_precision_nodes_set = set(low_precision_nodes) + + def _convert_subgraph_callback( + graph: onnx.GraphProto, parent: onnx.NodeProto, is_subgraph: bool + ) -> None: + if not is_subgraph or parent is None: + return + + # Inherit precision from parent control flow node + target_type = ( + self.low_precision_type + if parent.name in low_precision_nodes_set + else self.high_precision_type + ) + + # Convert all float initializers to target precision + for init in graph.initializer: + if init.data_type not in ONNX_TYPES or init.data_type == target_type.onnx_type: + continue + + from_type = ( + self.high_precision_type + if init.data_type == self.high_precision_type.onnx_type + else self.low_precision_type + if init.data_type == self.low_precision_type.onnx_type + else None + ) + + if from_type is None: + logger.debug( + f"Skipping subgraph initializer {init.name} with unsupported type {init.data_type}" + ) + continue + + new_init = self._convert_initializer_data(init, from_type, target_type) + init.CopyFrom(new_init) + + utils.walk_subgraphs_recursive(self.model.graph, _convert_subgraph_callback) + + def _convert_initializer_data( + self, + init: onnx.TensorProto, + from_type: PrecisionTypes, + to_type: PrecisionTypes, + ) -> onnx.TensorProto: + """Convert initializer data to a new precision. + + This is the core conversion logic extracted for reuse. Handles bfloat16 conversion + and provides warnings when values are clamped or replaced due to precision limits. + + Args: + init: The initializer to convert. + from_type: The original precision of the initializer. + to_type: The new precision to cast the initializer to. + + Returns: + onnx.TensorProto: The converted initializer. + """ + np_array = numpy_helper.to_array(init) + + # Handle bfloat16 conversion + if self._is_bf16(to_type) and self._is_fp32(from_type): + new_init = onnx.TensorProto() + new_init.dims.extend(np_array.shape) + new_init.name = init.name + new_init.data_type = onnx.TensorProto.BFLOAT16 + bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16) + new_init.raw_data = bf16_bytes.tobytes() + else: + assert to_type.numpy_type is not None + data_max, data_lowest = ( + np.finfo(to_type.numpy_type).max, + np.finfo(to_type.numpy_type).smallest_subnormal, + ) + if np.any(np.abs(np_array) > data_max): + logger.warning( + f"Initializer '{init.name}' contains values larger than largest " + f"{to_type.str_short} value, values will be clamped to {data_max}." + ) + np_array = np.clip(np_array, -1 * data_max, data_max) + if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)): + logger.warning( + f"Initializer '{init.name}' contains values smaller than smallest " + f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}." + ) + np_array = np.where( + (np_array != 0.0) & (np.abs(np_array) < data_lowest), + data_lowest, + np_array, + ) + new_array = np_array.astype(to_type.numpy_type) + new_init = numpy_helper.from_array(new_array, init.name) + + return new_init + def _cast_initializer( self, init: onnx.TensorProto, @@ -699,9 +840,11 @@ def _cast_initializer( init: The initializer to cast. from_type: The original precision of the initializer. to_type: The new precision to cast the initializer to. + low_precision_nodes: Low precision nodes that consume this initializer. + high_precision_nodes: High precision nodes that consume this initializer. Returns: - onnx.TensorProto: The casted initializer. + onnx.TensorProto | None: The casted initializer, or None if a runtime cast was inserted instead. """ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str: @@ -727,47 +870,11 @@ def _get_name(node: onnx.NodeProto | InputIndexTracker) -> str: exclude_consumers = ( low_precision_nodes if self._is_fp32(to_type) else high_precision_nodes ) - exclude_consumers_names: list[str] = [] - exclude_consumers_names = [_get_name(node) for node in exclude_consumers] self._add_cast(init.name, to_type, exclude_consumers=exclude_consumers_names) return None - np_array = numpy_helper.to_array(init) - # Numpy does not support bfloat16, use ml_dtypes to create the raw data instead - if self._is_bf16(to_type) and self._is_fp32(from_type): - new_init = onnx.TensorProto() - new_init.dims.extend(np_array.shape) - new_init.name = init.name - new_init.data_type = onnx.TensorProto.BFLOAT16 - bf16_bytes = np_array.astype(ml_dtypes.bfloat16).view(np.uint16) - new_init.raw_data = bf16_bytes.tobytes() - else: - assert to_type.numpy_type is not None - data_max, data_lowest = ( - np.finfo(to_type.numpy_type).max, - np.finfo(to_type.numpy_type).smallest_subnormal, - ) - if np.any(np.abs(np_array) > data_max): - logger.warning( - f"Initializer {init.name} contains values larger than largest " - f"{to_type.str_short} value, values will be clamped to {data_max}." - ) - np_array = np.clip(np_array, -1 * data_max, data_max) - if np.any((np_array != 0.0) & (np.abs(np_array) < data_lowest)): - logger.warning( - f"Initializer {init.name} contains values smaller than smallest " - f"{to_type.str_short} value, values will be replaced with {data_lowest:.1e}." - ) - np_array = np.where( - (np_array != 0.0) & (np.abs(np_array) < data_lowest), - data_lowest, - np_array, - ) - new_array = np_array.astype(to_type.numpy_type) - new_init = numpy_helper.from_array(new_array, init.name) - - return new_init + return self._convert_initializer_data(init, from_type, to_type) def _replace_tensor_name( self, consumers: list[onnx.NodeProto], original_tensor_name: str, new_tensor_name: str diff --git a/modelopt/onnx/autocast/utils.py b/modelopt/onnx/autocast/utils.py index d9dc3a1f1..629fab089 100644 --- a/modelopt/onnx/autocast/utils.py +++ b/modelopt/onnx/autocast/utils.py @@ -23,6 +23,7 @@ import logging from collections import defaultdict +from collections.abc import Callable import onnx @@ -122,6 +123,41 @@ def get_cast_to_type(cast_node: onnx.NodeProto) -> int: raise ValueError("Cast node does not have 'to' attribute") +def walk_subgraphs_recursive( + graph: onnx.GraphProto, + callback: Callable, + parent_node: onnx.NodeProto = None, + is_subgraph: bool = False, +) -> None: + """Recursively walk through a graph and all its subgraphs, applying a callback. + + This utility function traverses an ONNX graph and all nested subgraphs by examining + graph attributes in nodes. It works with standard control flow operators (Scan, If, Loop) + as well as custom operators that define subgraphs using ONNX graph attributes. + + Args: + graph: The graph to walk. + callback: Function to call for each graph. Signature: callback(graph, parent_node, is_subgraph). + parent_node: The parent node containing this subgraph (None for main graph). + is_subgraph: Whether this is a subgraph (True) or the main graph (False). + + Note: + Works with any node that has attributes of type AttributeProto.GRAPH or + AttributeProto.GRAPHS, including custom operators. + """ + # Apply callback to current graph + callback(graph, parent_node, is_subgraph) + + # Recursively process subgraphs in control flow nodes + for node in graph.node: + for attr in node.attribute: + if attr.type == onnx.AttributeProto.GRAPH: + walk_subgraphs_recursive(attr.g, callback, parent_node=node, is_subgraph=True) + elif attr.type == onnx.AttributeProto.GRAPHS: + for subgraph in attr.graphs: + walk_subgraphs_recursive(subgraph, callback, parent_node=node, is_subgraph=True) + + def get_op_types_not_supported_in_low_precision( model: onnx.ModelProto, min_opset: int, diff --git a/tests/unit/onnx/autocast/test_precisionconverter.py b/tests/unit/onnx/autocast/test_precisionconverter.py index b460c5842..bf87d8058 100644 --- a/tests/unit/onnx/autocast/test_precisionconverter.py +++ b/tests/unit/onnx/autocast/test_precisionconverter.py @@ -1324,3 +1324,199 @@ def test_resize_op_tensor_scales_conversion( high_precision_nodes=[], low_precision_nodes=[node.name for node in model.graph.node] ) onnx.checker.check_model(converted_model) + + +#################################################################################################### +# Testing subgraph support, using If - Then Else subgraphs with initializers +#################################################################################################### +@pytest.fixture +def model_with_if_subgraph(): + """Create a model with an If operation containing subgraphs with initializers. + + The model has a preprocessing Add on X, then If branches use initializers. + This tests both external inputs (X flows through Add) and subgraph initializers. + """ + # Main graph inputs/outputs + x = helper.make_tensor_value_info("X", TensorProto.FLOAT, [2, 3]) + condition = helper.make_tensor_value_info("condition", TensorProto.BOOL, []) + y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [2, 4]) + + # Add a preprocessing node in main graph to use external input X + preprocess_weight = numpy_helper.from_array( + np.ones((2, 3), dtype=np.float32), name="preprocess_weight" + ) + preprocess_node = helper.make_node( + "Add", ["X", "preprocess_weight"], ["X_processed"], name="preprocess" + ) + + # Create "then" branch subgraph with initializers + then_y = helper.make_tensor_value_info("then_y", TensorProto.FLOAT, [2, 4]) + + w_true = np.random.randn(2, 4).astype(np.float32) + b_true = np.random.randn(2, 4).astype(np.float32) + w_true_init = numpy_helper.from_array(w_true, name="W_true") + b_true_init = numpy_helper.from_array(b_true, name="b_true") + + then_add = helper.make_node("Add", ["W_true", "b_true"], ["then_y"], name="then_add") + + then_graph = helper.make_graph( + [then_add], + "then_branch", + [], + [then_y], + [w_true_init, b_true_init], + ) + + # Create "else" branch subgraph with different initializers + else_y = helper.make_tensor_value_info("else_y", TensorProto.FLOAT, [2, 4]) + + w_false = np.random.randn(2, 4).astype(np.float32) * 2 # Different values + b_false = np.random.randn(2, 4).astype(np.float32) * 2 + w_false_init = numpy_helper.from_array(w_false, name="W_false") + b_false_init = numpy_helper.from_array(b_false, name="b_false") + + else_add = helper.make_node("Add", ["W_false", "b_false"], ["else_y"], name="else_add") + + else_graph = helper.make_graph( + [else_add], + "else_branch", + [], + [else_y], + [w_false_init, b_false_init], + ) + + # Create If node + if_node = helper.make_node( + "If", + inputs=["condition"], + outputs=["Y"], + name="if_node", + then_branch=then_graph, + else_branch=else_graph, + ) + + # Create main graph with preprocessing using external input + main_graph = helper.make_graph( + [preprocess_node, if_node], + "model_with_if", + [x, condition], + [y], + [preprocess_weight], + ) + + model = helper.make_model(main_graph, producer_name="model_with_if") + model.opset_import[0].version = 20 + model.ir_version = 10 + onnx.checker.check_model(model) + + model = onnx_utils.infer_shapes(model) + value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + return model, value_info_map, initializer_map, node_to_init_map + + +@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +@pytest.mark.parametrize("if_precision", ["low", "high"]) +def test_if_subgraph_initializer_conversion( + model_with_if_subgraph, low_precision_type, if_precision +): + """Test that initializers in If subgraphs are converted based on parent node precision.""" + model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph + + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=True, + low_precision_type=low_precision_type, + ) + + # Classify the If node based on test parameter + if if_precision == "low": + high_precision_nodes = [] + low_precision_nodes = ["if_node"] + expected_init_type = low_precision_onnx_type(low_precision_type) + else: + high_precision_nodes = ["if_node"] + low_precision_nodes = [] + expected_init_type = TensorProto.FLOAT + + converted_model = converter.convert(high_precision_nodes, low_precision_nodes) + + # Verify the model is valid + onnx.checker.check_model(converted_model) + + # Find the If node and check its subgraph initializers + if_node = next(n for n in converted_model.graph.node if n.op_type == "If") + + then_branch = None + else_branch = None + for attr in if_node.attribute: + if attr.name == "then_branch": + then_branch = attr.g + elif attr.name == "else_branch": + else_branch = attr.g + + assert then_branch is not None, "If node should have a then_branch attribute" + assert else_branch is not None, "If node should have an else_branch attribute" + + # Check that subgraph initializers in both branches were converted + assert len(then_branch.initializer) == 2, ( + "Then branch should have 2 initializers (W_true, b_true)" + ) + assert len(else_branch.initializer) == 2, ( + "Else branch should have 2 initializers (W_false, b_false)" + ) + + for init in then_branch.initializer: + assert init.data_type == expected_init_type, ( + f"Then branch initializer '{init.name}' should be {expected_init_type}, " + f"but is {init.data_type}" + ) + + for init in else_branch.initializer: + assert init.data_type == expected_init_type, ( + f"Else branch initializer '{init.name}' should be {expected_init_type}, " + f"but is {init.data_type}" + ) + + +@pytest.mark.parametrize("low_precision_type", ["fp16", "bf16"]) +def test_if_subgraph_mixed_precision_boundary(model_with_if_subgraph, low_precision_type): + """Test that types are correctly handled at If subgraph boundaries in mixed precision.""" + model, value_info_map, initializer_map, node_to_init_map = model_with_if_subgraph + + # Add another node after the If to create a mixed precision scenario + add_weight = numpy_helper.from_array(np.ones((2, 4), dtype=np.float32), name="add_weight") + model.graph.initializer.append(add_weight) + + add_node = helper.make_node("Add", ["Y", "add_weight"], ["output"], name="add_after_if") + model.graph.node.append(add_node) + + # Update output + output_tensor = helper.make_tensor_value_info("output", TensorProto.FLOAT, [2, 4]) + model.graph.output.append(output_tensor) + + # Refresh mappings + value_info_map, initializer_map, node_to_init_map = utils.setup_mappings(model) + + converter = PrecisionConverter( + model, + value_info_map, + initializer_map, + node_to_init_map, + keep_io_types=True, + low_precision_type=low_precision_type, + ) + + # If in low precision, Add in high precision + converted_model = converter.convert( + high_precision_nodes=["add_after_if"], low_precision_nodes=["if_node"] + ) + + # Verify the model is valid (this tests type inference through subgraph boundaries) + onnx.checker.check_model(converted_model) + + # Verify a cast was inserted between If output and Add input + cast_nodes = [n for n in converted_model.graph.node if n.op_type == "Cast"] + assert len(cast_nodes) > 0, "Should have cast nodes for mixed precision"