1515
1616import os
1717
18+ import numpy as np
19+ import onnx
1820import pytest
21+ from _test_utils .torch_model .vision_models import get_tiny_resnet_and_input
22+ from onnx .helper import (
23+ make_graph ,
24+ make_model ,
25+ make_node ,
26+ make_opsetid ,
27+ make_tensor ,
28+ make_tensor_value_info ,
29+ )
1930
20- from modelopt .onnx .utils import save_onnx_bytes_to_dir , validate_onnx
31+ from modelopt .onnx .utils import (
32+ get_input_names_from_bytes ,
33+ get_output_names_from_bytes ,
34+ randomize_weights_onnx_bytes ,
35+ remove_node_training_mode ,
36+ remove_weights_data ,
37+ save_onnx_bytes_to_dir ,
38+ validate_onnx ,
39+ )
40+ from modelopt .torch ._deploy .utils import get_onnx_bytes
2141
2242
2343@pytest .mark .parametrize (
@@ -31,3 +51,205 @@ def test_validate_onnx(onnx_bytes):
def test_save_onnx(tmp_path):
    """save_onnx_bytes_to_dir should persist the given bytes as <tmp_path>/test.onnx."""
    dummy_onnx_bytes = b"test_onnx_bytes"
    save_onnx_bytes_to_dir(dummy_onnx_bytes, tmp_path, "test")
    expected_path = os.path.join(tmp_path, "test.onnx")
    assert os.path.exists(expected_path)
54+
55+
def make_onnx_model_for_matmul_op():
    """Build a minimal ONNX model with a single MatMul node and return its serialized bytes."""
    lhs = np.array([1, 2])
    rhs = np.array([1, 3])
    # Derive the result shape from an actual matmul of the sample operands.
    result_shape = np.matmul(lhs, rhs).shape

    matmul_node = make_node("MatMul", ["X", "Y"], ["Z"], name="matmul")
    graph_inputs = [
        make_tensor_value_info("X", onnx.TensorProto.FLOAT, lhs.shape),
        make_tensor_value_info("Y", onnx.TensorProto.FLOAT, rhs.shape),
    ]
    graph_outputs = [make_tensor_value_info("Z", onnx.TensorProto.FLOAT, result_shape)]

    graph = make_graph([matmul_node], "test_graph", graph_inputs, graph_outputs)
    model = make_model(graph, producer_name="Omniengine Tester")
    return model.SerializeToString()
72+
73+
def test_input_names():
    """The MatMul test model should expose exactly the inputs X and Y, in order."""
    serialized = make_onnx_model_for_matmul_op()
    assert get_input_names_from_bytes(serialized) == ["X", "Y"]
78+
79+
def test_output_names():
    """The MatMul test model should expose exactly one output named Z."""
    serialized = make_onnx_model_for_matmul_op()
    assert get_output_names_from_bytes(serialized) == ["Z"]
84+
85+
def _get_avg_var_of_weights(model):
    """Collect the mean and variance of every multi-dimensional float initializer.

    Returns a dict mapping "<initializer name>_avg" and "<initializer name>_var"
    to the corresponding statistic. One-dimensional tensors (e.g. biases) and
    non-float tensors are ignored.
    """
    stats = {}
    for init in model.graph.initializer:
        # Skip 1-D tensors; only weight matrices/kernels are of interest.
        if len(init.dims) <= 1:
            continue
        np_dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type)
        if np_dtype not in ("float16", "float32", "float64"):
            continue
        weights = np.frombuffer(init.raw_data, dtype=np_dtype)
        stats[init.name + "_avg"] = np.average(weights)
        stats[init.name + "_var"] = np.var(weights)

    return stats
99+
100+
def test_random_onnx_weights():
    """Weights can be stripped from an ONNX model and re-randomized with similar statistics."""
    model, args, kwargs = get_tiny_resnet_and_input()
    assert not kwargs

    original_bytes = get_onnx_bytes(model, args)
    original_size = len(original_bytes)
    original_stats = _get_avg_var_of_weights(onnx.load_from_string(original_bytes))

    stripped_bytes = remove_weights_data(original_bytes)
    # Removed model weights should be greater than 18 MB
    assert original_size - len(stripped_bytes) > 18e6

    # After assigning random weights, model size should be slightly greater than the the original
    # size due to some extra metadata
    randomized_bytes = randomize_weights_onnx_bytes(stripped_bytes)
    assert len(randomized_bytes) > original_size

    randomized_stats = _get_avg_var_of_weights(onnx.load_from_string(randomized_bytes))
    for stat_name, original_value in original_stats.items():
        # Random weights should roughly reproduce the original distribution.
        assert abs(original_value - randomized_stats[stat_name]) < 0.1
121+
122+
def test_reproducible_random_weights():
    """Randomizing the same stripped model twice must yield byte-identical results."""
    model, args, kwargs = get_tiny_resnet_and_input()
    assert not kwargs

    full_model_bytes = get_onnx_bytes(model, args)
    stripped_model_bytes = remove_weights_data(full_model_bytes)

    # Two independent randomization passes over identical input must agree.
    first_pass = randomize_weights_onnx_bytes(stripped_model_bytes)
    second_pass = randomize_weights_onnx_bytes(stripped_model_bytes)
    assert first_pass == second_pass
134+
135+
def _make_bn_initializer(name: str, shape, value=1.0):
    """Create a constant float32 initializer tensor for a BatchNorm parameter."""
    constant = np.full(shape, value, dtype=np.float32)
    return make_tensor(name, onnx.TensorProto.FLOAT, shape, constant.flatten())
140+
141+
def _make_batchnorm_model(bn_node, extra_value_infos=None):
    """Helper to create an ONNX model wrapping a given BatchNormalization node.

    The resulting model has this schematic structure:

        graph name: "test_graph"
        inputs:
            - input: FLOAT [1, 3, 224, 224]
        initializers:
            - scale: FLOAT [3]
            - bias:  FLOAT [3]
            - mean:  FLOAT [3]
            - var:   FLOAT [3]
        nodes:
            - the supplied `bn_node`, wired as
              inputs = ["input", "scale", "bias", "mean", "var"]
              outputs = whatever `bn_node` declares (e.g. ["output"], or
              ["output", "running_mean", "running_var", "saved_mean"])
        outputs:
            - every name among "output", "running_mean", "running_var"
              that `bn_node` actually produces

    If `extra_value_infos` is given (e.g. value_info entries for training-only
    outputs such as "saved_mean"/"saved_inv_std"), they are attached to the
    graph's value_info. Tests may then run utilities like
    remove_node_training_mode that prune training-only outputs and their
    value_info entries while keeping regular outputs such as "running_mean"
    and "running_var" intact.
    """
    # (name, fill value) for each BatchNorm parameter initializer.
    bn_params = [("scale", 1.0), ("bias", 0.0), ("mean", 0.0), ("var", 1.0)]
    initializers = [_make_bn_initializer(name, [3], fill) for name, fill in bn_params]

    # Only declare graph outputs that the node actually produces.
    candidate_outputs = {
        "output": [1, 3, 224, 224],
        "running_mean": [3],
        "running_var": [3],
    }
    graph_outputs = [
        make_tensor_value_info(name, onnx.TensorProto.FLOAT, shape)
        for name, shape in candidate_outputs.items()
        if name in bn_node.output
    ]

    graph_def = make_graph(
        [bn_node],
        "test_graph",
        [make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
        graph_outputs,
        initializer=initializers,
        value_info=extra_value_infos or [],
    )

    return make_model(graph_def, opset_imports=[make_opsetid("", 14)])
196+
197+
def test_remove_node_training_mode_attribute():
    """Test removal of training_mode attribute from BatchNormalization nodes."""
    node_with_training_mode = make_node(
        "BatchNormalization",
        inputs=["input", "scale", "bias", "mean", "var"],
        outputs=["output"],
        name="bn1",
        training_mode=1,  # This attribute should be removed
    )

    cleaned_model = remove_node_training_mode(
        _make_batchnorm_model(node_with_training_mode), "BatchNormalization"
    )

    cleaned_node = cleaned_model.graph.node[0]
    assert cleaned_node.op_type == "BatchNormalization"
    # The training_mode attribute must be gone after cleanup.
    assert all(attr.name != "training_mode" for attr in cleaned_node.attribute)
217+
218+
def test_remove_node_extra_training_outputs():
    """Test removal of extra training outputs from BatchNormalization nodes.

    The node declares the training-only outputs "saved_mean" and
    "saved_inv_std"; remove_node_training_mode must drop them (and their
    value_info entries) while keeping "output", "running_mean", "running_var".
    """
    bn_node = make_node(
        "BatchNormalization",
        inputs=["input", "scale", "bias", "mean", "var"],
        outputs=[
            "output",
            "running_mean",
            "running_var",
            "saved_mean",
            "saved_inv_std",
        ],
        name="bn1",
        training_mode=1,
    )

    # Extra training outputs are attached to the graph's value_info
    value_infos = [
        make_tensor_value_info("saved_mean", onnx.TensorProto.FLOAT, [3]),
        make_tensor_value_info("saved_inv_std", onnx.TensorProto.FLOAT, [3]),
    ]

    model = _make_batchnorm_model(bn_node, extra_value_infos=value_infos)
    result_model = remove_node_training_mode(model, "BatchNormalization")

    # Verify only the non-training outputs remain
    # (fix: removed a leftover debug print of the node outputs)
    bn_node_result = result_model.graph.node[0]
    assert len(bn_node_result.output) == 3
    assert bn_node_result.output[0] == "output"
    assert bn_node_result.output[1] == "running_mean"
    assert bn_node_result.output[2] == "running_var"

    # Verify value_info entries for removed outputs are cleaned up
    value_info_names = [vi.name for vi in result_model.graph.value_info]
    assert "saved_mean" not in value_info_names
    assert "saved_inv_std" not in value_info_names