@@ -15,9 +15,29 @@

import os

+import numpy as np
+import onnx
import pytest
+from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
+from onnx.helper import (
+    make_graph,
+    make_model,
+    make_node,
+    make_opsetid,
+    make_tensor,
+    make_tensor_value_info,
+)

-from modelopt.onnx.utils import save_onnx_bytes_to_dir, validate_onnx
+from modelopt.onnx.utils import (
+    get_input_names_from_bytes,
+    get_output_names_from_bytes,
+    randomize_weights_onnx_bytes,
+    remove_node_training_mode,
+    remove_weights_data,
+    save_onnx_bytes_to_dir,
+    validate_onnx,
+)
+from modelopt.torch._deploy.utils import get_onnx_bytes


@pytest.mark.parametrize(
@@ -31,3 +51,205 @@ def test_validate_onnx(onnx_bytes):
def test_save_onnx(tmp_path):
    save_onnx_bytes_to_dir(b"test_onnx_bytes", tmp_path, "test")
    assert os.path.exists(os.path.join(tmp_path, "test.onnx"))
+
+
+def make_onnx_model_for_matmul_op():
+    input_left = np.array([1, 2])
+    input_right = np.array([1, 3])
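+    # np.matmul of two 1-D vectors is a dot product, so output_shape is ()
+    # and Z is a scalar tensor.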
+    output_shape = np.matmul(input_left, input_right).shape
+    node = make_node("MatMul", ["X", "Y"], ["Z"], name="matmul")
+    graph = make_graph(
+        [node],
+        "test_graph",
+        [
+            make_tensor_value_info("X", onnx.TensorProto.FLOAT, input_left.shape),
+            make_tensor_value_info("Y", onnx.TensorProto.FLOAT, input_right.shape),
+        ],
+        [make_tensor_value_info("Z", onnx.TensorProto.FLOAT, output_shape)],
+    )
+    model = make_model(graph, producer_name="Omniengine Tester")
+    return model.SerializeToString()
+
+
+def test_input_names():
+    model_bytes = make_onnx_model_for_matmul_op()
+    input_names = get_input_names_from_bytes(model_bytes)
+    assert input_names == ["X", "Y"]
+
+
+def test_output_names():
+    model_bytes = make_onnx_model_for_matmul_op()
+    output_names = get_output_names_from_bytes(model_bytes)
+    assert output_names == ["Z"]
+
+
+def _get_avg_var_of_weights(model):
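+    """Return the mean and variance of each multi-dimensional float initializer.
+
+    The resulting dict is keyed by "<name>_avg" and "<name>_var" and is used to
+    compare weight distributions before and after randomization.
+    """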
+    inits = model.graph.initializer
+    avg_var_dict = {}
+
+    for init in inits:
+        if len(init.dims) > 1:
+            dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type)
+            if dtype in ["float16", "float32", "float64"]:
+                np_tensor = np.frombuffer(init.raw_data, dtype=dtype)
+                avg_var_dict[init.name + "_avg"] = np.average(np_tensor)
+                avg_var_dict[init.name + "_var"] = np.var(np_tensor)
+
+    return avg_var_dict
99
+
100
+
101
+ def test_random_onnx_weights ():
102
+ model , args , kwargs = get_tiny_resnet_and_input ()
103
+ assert not kwargs
104
+
105
+ onnx_bytes = get_onnx_bytes (model , args )
106
+ original_avg_var_dict = _get_avg_var_of_weights (onnx .load_from_string (onnx_bytes ))
107
+ original_model_size = len (onnx_bytes )
108
+
109
+ onnx_bytes = remove_weights_data (onnx_bytes )
110
+ # Removed model weights should be greater than 18 MB
111
+ assert original_model_size - len (onnx_bytes ) > 18e6
112
+
113
+ # After assigning random weights, model size should be slightly greater than the the original
114
+ # size due to some extra metadata
115
+ onnx_bytes = randomize_weights_onnx_bytes (onnx_bytes )
116
+ assert len (onnx_bytes ) > original_model_size
117
+
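+    # The randomizer is expected to roughly preserve each weight tensor's
+    # statistics, hence the loose 0.1 tolerance on mean and variance below.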
118
+ randomized_avg_var_dict = _get_avg_var_of_weights (onnx .load_from_string (onnx_bytes ))
119
+ for key , value in original_avg_var_dict .items ():
120
+ assert abs (value - randomized_avg_var_dict [key ]) < 0.1
121
+
122
+
123
+ def test_reproducible_random_weights ():
124
+ model , args , kwargs = get_tiny_resnet_and_input ()
125
+ assert not kwargs
126
+
127
+ original_onnx_bytes = get_onnx_bytes (model , args )
128
+ onnx_bytes_wo_weights = remove_weights_data (original_onnx_bytes )
129
+
130
+ # Check if the randomization produces the same weights
131
+ onnx_bytes_1 = randomize_weights_onnx_bytes (onnx_bytes_wo_weights )
132
+ onnx_bytes_2 = randomize_weights_onnx_bytes (onnx_bytes_wo_weights )
133
+ assert onnx_bytes_1 == onnx_bytes_2
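+    # Byte-identical output implies the randomizer presumably seeds its RNG
+    # deterministically; the test only asserts equality of the serialized bytes.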
+
+
+def _make_bn_initializer(name: str, shape, value=1.0):
+    """Helper to create an initializer tensor for BatchNorm."""
+    data = np.full(shape, value, dtype=np.float32)
+    return make_tensor(name, onnx.TensorProto.FLOAT, shape, data.flatten())
+
+
+def _make_batchnorm_model(bn_node, extra_value_infos=None):
+    """Helper to create an ONNX model with a BatchNormalization node.
+
+    The created model has the following schematic structure:
+
+        graph name: "test_graph"
+        inputs:
+            - input: FLOAT [1, 3, 224, 224]
+        initializers:
+            - scale: FLOAT [3]
+            - bias: FLOAT [3]
+            - mean: FLOAT [3]
+            - var: FLOAT [3]
+        nodes:
+            - BatchNormalization (name comes from `bn_node`), with:
+                inputs = ["input", "scale", "bias", "mean", "var"]
+                outputs = as provided by `bn_node` (e.g., ["output"], or
+                    ["output", "running_mean", "running_var", "saved_mean"])
+        outputs:
+            - output: FLOAT [1, 3, 224, 224]
+
+    If `extra_value_infos` is provided (e.g., value_info for non-training outputs
+    like "running_mean"/"running_var" and/or training-only outputs like
+    "saved_mean"/"saved_inv_std"), they are attached to the graph's value_info.
+    Some tests subsequently invoke utilities (e.g., remove_node_training_mode)
+    that prune training-only outputs and their value_info entries, while keeping
+    regular outputs such as "running_mean" and "running_var" intact.
+    """
+    initializers = [
+        _make_bn_initializer("scale", [3], 1.0),
+        _make_bn_initializer("bias", [3], 0.0),
+        _make_bn_initializer("mean", [3], 0.0),
+        _make_bn_initializer("var", [3], 1.0),
+    ]
+
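+    # Expose only the candidate outputs that `bn_node` actually declares.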
+    graph_outputs = []
+    for output_name, shape in [
+        ("output", [1, 3, 224, 224]),
+        ("running_mean", [3]),
+        ("running_var", [3]),
+    ]:
+        if output_name in bn_node.output:
+            graph_outputs.append(make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape))
+
+    graph_def = make_graph(
+        [bn_node],
+        "test_graph",
+        [make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
+        graph_outputs,
+        initializer=initializers,
+        value_info=extra_value_infos or [],
+    )
+
+    return make_model(graph_def, opset_imports=[make_opsetid("", 14)])
+
+
+def test_remove_node_training_mode_attribute():
+    """Test removal of the training_mode attribute from BatchNormalization nodes."""
+    bn_node = make_node(
+        "BatchNormalization",
+        inputs=["input", "scale", "bias", "mean", "var"],
+        outputs=["output"],
+        name="bn1",
+        training_mode=1,  # This attribute should be removed
+    )
+
+    model = _make_batchnorm_model(bn_node)
+    result_model = remove_node_training_mode(model, "BatchNormalization")
+
+    bn_node_result = result_model.graph.node[0]
+    assert bn_node_result.op_type == "BatchNormalization"
+
+    # Check that the training_mode attribute is not present
+    attr_names = [attr.name for attr in bn_node_result.attribute]
+    assert "training_mode" not in attr_names
+
+
+def test_remove_node_extra_training_outputs():
+    """Test removal of extra training outputs from BatchNormalization nodes."""
+    bn_node = make_node(
+        "BatchNormalization",
+        inputs=["input", "scale", "bias", "mean", "var"],
+        outputs=[
+            "output",
+            "running_mean",
+            "running_var",
+            "saved_mean",
+            "saved_inv_std",
+        ],
+        name="bn1",
+        training_mode=1,
+    )
+
+    # Extra training outputs are attached to the graph's value_info
+    value_infos = [
+        make_tensor_value_info("saved_mean", onnx.TensorProto.FLOAT, [3]),
+        make_tensor_value_info("saved_inv_std", onnx.TensorProto.FLOAT, [3]),
+    ]
+
+    model = _make_batchnorm_model(bn_node, extra_value_infos=value_infos)
+    result_model = remove_node_training_mode(model, "BatchNormalization")
+
+    # Verify only the non-training outputs remain
+    bn_node_result = result_model.graph.node[0]
+    assert len(bn_node_result.output) == 3
+    assert bn_node_result.output[0] == "output"
+    assert bn_node_result.output[1] == "running_mean"
+    assert bn_node_result.output[2] == "running_var"
+
+    # Verify value_info entries for removed outputs are cleaned up
+    value_info_names = [vi.name for vi in result_model.graph.value_info]
+    assert "saved_mean" not in value_info_names
+    assert "saved_inv_std" not in value_info_names