diff --git a/modelopt/onnx/utils.py b/modelopt/onnx/utils.py
index 0e1740c2a..4650e99b2 100644
--- a/modelopt/onnx/utils.py
+++ b/modelopt/onnx/utils.py
@@ -25,7 +25,6 @@
 import numpy as np
 import onnx
 import onnx_graphsurgeon as gs
-from onnx import TensorProto, ValueInfoProto, numpy_helper
 from onnx.helper import get_attribute_value
 from onnx_graphsurgeon import Constant, Node, Variable
 
@@ -289,7 +288,7 @@ def _convert_types_to_np(types: dict[str, int] | list[int] | int) -> Any:
 def get_tensor_by_name(
     onnx_model: onnx.ModelProto, tensor_name: str
-) -> ValueInfoProto | TensorProto | None:
+) -> onnx.ValueInfoProto | onnx.TensorProto | None:
     """This function returns a tensor from its name.
 
     This function searches for a tensor in the model's:
@@ -438,7 +437,7 @@ def randomize_weights_onnx_bytes(onnx_bytes: bytes, seed: int = 0) -> bytes:
             numpy_array = np.random.normal(float(avg), float(var), size=init.dims).astype(
                 dtype
             )
-            tensor = numpy_helper.from_array(numpy_array, init.name)
+            tensor = onnx.numpy_helper.from_array(numpy_array, init.name)
             model.graph.initializer[idx].CopyFrom(tensor)
 
     buffer = io.BytesIO()
@@ -751,3 +750,53 @@ def onnx_type_str_to_enum(dtype: str) -> int:
     dtype = dtype.split("tensor(")[-1].split(")")[0]
     dtype = "FLOAT" if dtype == "float32" else dtype.upper()
     return getattr(onnx.TensorProto, dtype)
+
+
+def remove_node_training_mode(onnx_model: onnx.ModelProto, node_op_type: str) -> onnx.ModelProto:
+    """Remove the `training_mode` attribute and extra training outputs from nodes of a given op type.
+
+    Extra outputs of affected nodes are removed only when no other node or graph
+    output consumes them; their `value_info` entries are cleaned up as well.
+
+    Args:
+        onnx_model: The onnx model.
+        node_op_type: The op type of the nodes to strip the `training_mode` attribute from.
+
+    Returns:
+        The onnx model with the `training_mode` attribute removed.
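+
+    Example (a minimal sketch, assuming ``model`` is a loaded ``onnx.ModelProto``
+    exported with BatchNormalization nodes in training mode):
+
+        model = remove_node_training_mode(model, "BatchNormalization")
+        onnx.checker.check_model(model)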
+    """
+    removed_output_names = set()
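+    # An output consumed by another node or exposed as a graph output must be
+    # kept; only dangling training-only outputs are safe to drop.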
+    all_inputs = {inp for n in onnx_model.graph.node for inp in n.input}
+    graph_outputs = {o.name for o in onnx_model.graph.output}
+    keep = all_inputs | graph_outputs
+
+    for node in onnx_model.graph.node:
+        if node.op_type != node_op_type:
+            continue
+
+        is_training_mode = False
+        # Drop the 'training_mode' attribute if present
+        for idx, attr in enumerate(list(node.attribute)):
+            if attr.name == "training_mode":
+                del node.attribute[idx]
+                if attr.i == 1:
+                    is_training_mode = True
+                break
+
+        # If the node was in training mode, drop any of its outputs that nothing
+        # else consumes (i.e., the training-only outputs)
+        if is_training_mode:
+            to_remove = []
+            for name in node.output:
+                if name not in keep:
+                    removed_output_names.add(name)
+                    to_remove.append(name)
+
+            for name in to_remove:
+                node.output.remove(name)
+
+    if removed_output_names:
+        # Clean up the value_info entries of the removed outputs
+        kept_value_infos = [
+            vi for vi in onnx_model.graph.value_info if vi.name not in removed_output_names
+        ]
+        del onnx_model.graph.value_info[:]
+        onnx_model.graph.value_info.extend(kept_value_infos)
+
+    return onnx_model
diff --git a/modelopt/torch/_deploy/utils/torch_onnx.py b/modelopt/torch/_deploy/utils/torch_onnx.py
index a922bb19f..e18a9d209 100644
--- a/modelopt/torch/_deploy/utils/torch_onnx.py
+++ b/modelopt/torch/_deploy/utils/torch_onnx.py
@@ -45,6 +45,7 @@
     get_node_names,
     get_output_names,
     get_output_shapes,
+    remove_node_training_mode,
 )
 from modelopt.torch.quantization.export_onnx import configure_linear_module_onnx_quantizers
 from modelopt.torch.utils import flatten_tree, standardize_named_model_args
@@ -569,25 +570,3 @@ def get_onnx_bytes(*args, **kwargs) -> bytes:
     onnx_bytes = get_onnx_bytes_and_metadata(*args, **kwargs)[0]
     onnx_bytes_obj = OnnxBytes.from_bytes(onnx_bytes)
     return onnx_bytes_obj.get_onnx_model_file_bytes()
-
-
-def remove_node_training_mode(onnx_model: ModelProto, node_op_type: str) -> ModelProto:
-    """Remove training_mode attribute from selected node type.
-
-    Args:
-        onnx_model: The onnx model.
-        node_op_type: The node type to remove training_mode attribute from.
-
-    Returns:
-        The onnx model with the training_mode attribute removed.
-    """
-    for node in onnx_model.graph.node:
-        if node.op_type == node_op_type:
-            for attribute in node.attribute:
-                if attribute.name == "training_mode":
-                    if attribute.i == 1:
-                        node.output.remove(node.output[1])
-                        node.output.remove(node.output[1])
-                    attribute.i = 0
-
-    return onnx_model
diff --git a/tests/unit/onnx/test_onnx_utils.py b/tests/unit/onnx/test_onnx_utils.py
index 44a155593..ede97302d 100644
--- a/tests/unit/onnx/test_onnx_utils.py
+++ b/tests/unit/onnx/test_onnx_utils.py
@@ -15,9 +15,29 @@
 
 import os
 
+import numpy as np
+import onnx
 import pytest
+from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
+from onnx.helper import (
+    make_graph,
+    make_model,
+    make_node,
+    make_opsetid,
+    make_tensor,
+    make_tensor_value_info,
+)
 
-from modelopt.onnx.utils import save_onnx_bytes_to_dir, validate_onnx
+from modelopt.onnx.utils import (
+    get_input_names_from_bytes,
+    get_output_names_from_bytes,
+    randomize_weights_onnx_bytes,
+    remove_node_training_mode,
+    remove_weights_data,
+    save_onnx_bytes_to_dir,
+    validate_onnx,
+)
+from modelopt.torch._deploy.utils import get_onnx_bytes
 
 
 @pytest.mark.parametrize(
@@ -31,3 +51,205 @@ def test_validate_onnx(onnx_bytes):
 def test_save_onnx(tmp_path):
     save_onnx_bytes_to_dir(b"test_onnx_bytes", tmp_path, "test")
     assert os.path.exists(os.path.join(tmp_path, "test.onnx"))
+
+
+def make_onnx_model_for_matmul_op():
+    input_left = np.array([1, 2])
+    input_right = np.array([1, 3])
+    output_shape = np.matmul(input_left, input_right).shape
+    node = make_node("MatMul", ["X", "Y"], ["Z"], name="matmul")
+    graph = make_graph(
+        [node],
+        "test_graph",
+        [
+            make_tensor_value_info("X", onnx.TensorProto.FLOAT, input_left.shape),
+            make_tensor_value_info("Y", onnx.TensorProto.FLOAT, input_right.shape),
+        ],
+        [make_tensor_value_info("Z", onnx.TensorProto.FLOAT, output_shape)],
+    )
+    model = make_model(graph, producer_name="Omniengine Tester")
+    return model.SerializeToString()
+
+
+def test_input_names():
+    model_bytes = make_onnx_model_for_matmul_op()
+    input_names = get_input_names_from_bytes(model_bytes)
+    assert input_names == ["X", "Y"]
+
+
+def test_output_names():
+    model_bytes = make_onnx_model_for_matmul_op()
+    output_names = get_output_names_from_bytes(model_bytes)
+    assert output_names == ["Z"]
+
+
+def _get_avg_var_of_weights(model):
+    inits = model.graph.initializer
+    avg_var_dict = {}
+
+    for init in inits:
+        if len(init.dims) > 1:
+            dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type)
+            if dtype in ["float16", "float32", "float64"]:
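+                # Note: np.frombuffer assumes the weights live in the
+                # initializer's raw_data field, which is the case for
+                # torch.onnx-exported models.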
+                np_tensor = np.frombuffer(init.raw_data, dtype=dtype)
+                avg_var_dict[init.name + "_avg"] = np.average(np_tensor)
+                avg_var_dict[init.name + "_var"] = np.var(np_tensor)
+
+    return avg_var_dict
+
+
+def test_random_onnx_weights():
+    model, args, kwargs = get_tiny_resnet_and_input()
+    assert not kwargs
+
+    onnx_bytes = get_onnx_bytes(model, args)
+    original_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
+    original_model_size = len(onnx_bytes)
+
+    onnx_bytes = remove_weights_data(onnx_bytes)
+    # The stripped weight data should amount to more than 18 MB
+    assert original_model_size - len(onnx_bytes) > 18e6
+
+    # After assigning random weights, the model should be slightly larger than the
+    # original due to some extra metadata
+    onnx_bytes = randomize_weights_onnx_bytes(onnx_bytes)
+    assert len(onnx_bytes) > original_model_size
+
+    randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
+    for key, value in original_avg_var_dict.items():
+        assert abs(value - randomized_avg_var_dict[key]) < 0.1
+
+
+def test_reproducible_random_weights():
+    model, args, kwargs = get_tiny_resnet_and_input()
+    assert not kwargs
+
+    original_onnx_bytes = get_onnx_bytes(model, args)
+    onnx_bytes_wo_weights = remove_weights_data(original_onnx_bytes)
+
+    # Check that the randomization produces the same weights on every call.
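+    # Reproducibility comes from the fixed default seed of
+    # randomize_weights_onnx_bytes (seed=0).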
+    onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
+    onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
+    assert onnx_bytes_1 == onnx_bytes_2
+
+
+def _make_bn_initializer(name: str, shape, value=1.0):
+    """Helper to create an initializer tensor for BatchNorm."""
+    data = np.full(shape, value, dtype=np.float32)
+    return make_tensor(name, onnx.TensorProto.FLOAT, shape, data.flatten())
+
+
+def _make_batchnorm_model(bn_node, extra_value_infos=None):
+    """Helper to create an ONNX model with a BatchNormalization node.
+
+    The created model has the following schematic structure:
+
+    graph name: "test_graph"
+    inputs:
+        - input: FLOAT [1, 3, 224, 224]
+    initializers:
+        - scale: FLOAT [3]
+        - bias: FLOAT [3]
+        - mean: FLOAT [3]
+        - var: FLOAT [3]
+    nodes:
+        - BatchNormalization (name comes from `bn_node`), with:
+            inputs = ["input", "scale", "bias", "mean", "var"]
+            outputs = as provided by `bn_node` (e.g., ["output"], or
+            ["output", "running_mean", "running_var", "saved_mean"])
+    outputs:
+        - output: FLOAT [1, 3, 224, 224]
+
+    If `extra_value_infos` is provided (e.g., value_info for non-training outputs
+    like "running_mean"/"running_var" and/or training-only outputs like
+    "saved_mean"/"saved_inv_std"), they are attached to the graph's value_info.
+    Some tests subsequently invoke utilities (e.g., remove_node_training_mode)
+    that prune training-only outputs and their value_info entries, while keeping
+    regular outputs such as "running_mean" and "running_var" intact.
+    """
+    initializers = [
+        _make_bn_initializer("scale", [3], 1.0),
+        _make_bn_initializer("bias", [3], 0.0),
+        _make_bn_initializer("mean", [3], 0.0),
+        _make_bn_initializer("var", [3], 1.0),
+    ]
+
+    graph_outputs = []
+    for output_name, shape in [
+        ("output", [1, 3, 224, 224]),
+        ("running_mean", [3]),
+        ("running_var", [3]),
+    ]:
+        if output_name in bn_node.output:
+            graph_outputs.append(make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape))
+
+    graph_def = make_graph(
+        [bn_node],
+        "test_graph",
+        [make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
+        graph_outputs,
+        initializer=initializers,
+        value_info=extra_value_infos or [],
+    )
+
+    return make_model(graph_def, opset_imports=[make_opsetid("", 14)])
+
+
+def test_remove_node_training_mode_attribute():
+    """Test removal of training_mode attribute from BatchNormalization nodes."""
+    bn_node = make_node(
+        "BatchNormalization",
+        inputs=["input", "scale", "bias", "mean", "var"],
+        outputs=["output"],
+        name="bn1",
+        training_mode=1,  # This attribute should be removed
+    )
+
+    model = _make_batchnorm_model(bn_node)
+    result_model = remove_node_training_mode(model, "BatchNormalization")
+
+    bn_node_result = result_model.graph.node[0]
+    assert bn_node_result.op_type == "BatchNormalization"
+
+    # Check that the training_mode attribute is not present
+    attr_names = [attr.name for attr in bn_node_result.attribute]
+    assert "training_mode" not in attr_names
+
+
+def test_remove_node_extra_training_outputs():
+    """Test removal of extra training outputs from BatchNormalization nodes."""
+    bn_node = make_node(
+        "BatchNormalization",
+        inputs=["input", "scale", "bias", "mean", "var"],
+        outputs=[
+            "output",
+            "running_mean",
+            "running_var",
+            "saved_mean",
+            "saved_inv_std",
+        ],
+        name="bn1",
+        training_mode=1,
+    )
+
+    # Extra training outputs are attached to the graph's value_info
+    value_infos = [
+        make_tensor_value_info("saved_mean", onnx.TensorProto.FLOAT, [3]),
+        make_tensor_value_info("saved_inv_std", onnx.TensorProto.FLOAT, [3]),
+    ]
+
+    model = _make_batchnorm_model(bn_node, extra_value_infos=value_infos)
+    result_model = remove_node_training_mode(model, "BatchNormalization")
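+
+    # running_mean and running_var survive because _make_batchnorm_model
+    # registers them as graph outputs, which remove_node_training_mode keeps.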
+
+    # Verify that only the non-training outputs remain
+    bn_node_result = result_model.graph.node[0]
+    assert len(bn_node_result.output) == 3
+    assert bn_node_result.output[0] == "output"
+    assert bn_node_result.output[1] == "running_mean"
+    assert bn_node_result.output[2] == "running_var"
+
+    # Verify that value_info entries for the removed outputs are cleaned up
+    value_info_names = [vi.name for vi in result_model.graph.value_info]
+    assert "saved_mean" not in value_info_names
+    assert "saved_inv_std" not in value_info_names
diff --git a/tests/unit/torch/deploy/utils/test_torch_onnx_utils.py b/tests/unit/torch/deploy/utils/test_torch_onnx_utils.py
index 8540a33d6..7363819ba 100644
--- a/tests/unit/torch/deploy/utils/test_torch_onnx_utils.py
+++ b/tests/unit/torch/deploy/utils/test_torch_onnx_utils.py
@@ -23,17 +23,8 @@
 import torch
 import torch.nn as nn
 from _test_utils.torch_model.deploy_models import BaseDeployModel, get_deploy_models
-from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
-from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
-
-from modelopt.onnx.utils import (
-    get_batch_size_from_bytes,
-    get_input_names_from_bytes,
-    get_output_names_from_bytes,
-    randomize_weights_onnx_bytes,
-    remove_weights_data,
-    validate_batch_size,
-)
+
+from modelopt.onnx.utils import get_batch_size_from_bytes, validate_batch_size
 from modelopt.torch._deploy.utils import (
     OnnxBytes,
     flatten_tree,
@@ -175,83 +166,3 @@ def test_get_and_validate_batch_size(model, n_args, batch_size):
 
     assert validate_batch_size(onnx_bytes, 3) is False
     assert batch_size == get_batch_size_from_bytes(onnx_bytes)
-
-
-def make_onnx_model_for_matmul_op():
-    input_left = np.array([1, 2])
-    input_right = np.array([1, 3])
-    output_shape = np.matmul(input_left, input_right).shape
-    node = make_node("MatMul", ["X", "Y"], ["Z"], name="matmul")
-    graph = make_graph(
-        [node],
-        "test_graph",
-        [
-            make_tensor_value_info("X", onnx.TensorProto.FLOAT, input_left.shape),
-            make_tensor_value_info("Y", onnx.TensorProto.FLOAT, input_right.shape),
-        ],
-        [make_tensor_value_info("Z", onnx.TensorProto.FLOAT, output_shape)],
-    )
-    model = make_model(graph, producer_name="Omniengine Tester")
-    return model.SerializeToString()
-
-
-def test_input_names():
-    model_bytes = make_onnx_model_for_matmul_op()
-    input_names = get_input_names_from_bytes(model_bytes)
-    assert input_names == ["X", "Y"]
-
-
-def test_output_names():
-    model_bytes = make_onnx_model_for_matmul_op()
-    output_names = get_output_names_from_bytes(model_bytes)
-    assert output_names == ["Z"]
-
-
-def _get_avg_var_of_weights(model):
-    inits = model.graph.initializer
-    avg_var_dict = {}
-
-    for init in inits:
-        if len(init.dims) > 1:
-            dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type)
-            if dtype in ["float16", "float32", "float64"]:
-                np_tensor = np.frombuffer(init.raw_data, dtype=dtype)
-                avg_var_dict[init.name + "_avg"] = np.average(np_tensor)
-                avg_var_dict[init.name + "_var"] = np.var(np_tensor)
-
-    return avg_var_dict
-
-
-def test_random_onnx_weights():
-    model, args, kwargs = get_tiny_resnet_and_input()
-    assert not kwargs
-
-    onnx_bytes = get_onnx_bytes(model, args)
-    original_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
-    original_model_size = len(onnx_bytes)
-
-    onnx_bytes = remove_weights_data(onnx_bytes)
-    # Removed model weights should be greater than 18 MB
-    assert original_model_size - len(onnx_bytes) > 18e6
-
-    # After assigning random weights, model size should be slightly greater than the the original
-    # size due to some extra metadata
-    onnx_bytes = randomize_weights_onnx_bytes(onnx_bytes)
-    assert len(onnx_bytes) > original_model_size
-
-    randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
-    for key, value in original_avg_var_dict.items():
-        assert abs(value - randomized_avg_var_dict[key]) < 0.1
-
-
-def test_reproducible_random_weights():
-    model, args, kwargs = get_tiny_resnet_and_input()
-    assert not kwargs
-
-    original_onnx_bytes = get_onnx_bytes(model, args)
-    onnx_bytes_wo_weights = remove_weights_data(original_onnx_bytes)
-
-    # Check if the randomization produces the same weights
-    onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
-    onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
-    assert onnx_bytes_1 == onnx_bytes_2