Skip to content

Commit e70aeef

Browse files
committed
Tests relocated
Signed-off-by: Riyad Islam <[email protected]>
1 parent 062224a commit e70aeef

File tree

2 files changed

+224
-221
lines changed

2 files changed

+224
-221
lines changed

tests/unit/onnx/test_onnx_utils.py

Lines changed: 223 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,29 @@
1515

1616
import os
1717

18+
import numpy as np
19+
import onnx
1820
import pytest
21+
from _test_utils.torch_model.vision_models import get_tiny_resnet_and_input
22+
from onnx.helper import (
23+
make_graph,
24+
make_model,
25+
make_node,
26+
make_opsetid,
27+
make_tensor,
28+
make_tensor_value_info,
29+
)
1930

20-
from modelopt.onnx.utils import save_onnx_bytes_to_dir, validate_onnx
31+
from modelopt.onnx.utils import (
32+
get_input_names_from_bytes,
33+
get_output_names_from_bytes,
34+
randomize_weights_onnx_bytes,
35+
remove_node_training_mode,
36+
remove_weights_data,
37+
save_onnx_bytes_to_dir,
38+
validate_onnx,
39+
)
40+
from modelopt.torch._deploy.utils import get_onnx_bytes
2141

2242

2343
@pytest.mark.parametrize(
@@ -31,3 +51,205 @@ def test_validate_onnx(onnx_bytes):
3151
def test_save_onnx(tmp_path):
3252
save_onnx_bytes_to_dir(b"test_onnx_bytes", tmp_path, "test")
3353
assert os.path.exists(os.path.join(tmp_path, "test.onnx"))
54+
55+
56+
def make_onnx_model_for_matmul_op():
57+
input_left = np.array([1, 2])
58+
input_right = np.array([1, 3])
59+
output_shape = np.matmul(input_left, input_right).shape
60+
node = make_node("MatMul", ["X", "Y"], ["Z"], name="matmul")
61+
graph = make_graph(
62+
[node],
63+
"test_graph",
64+
[
65+
make_tensor_value_info("X", onnx.TensorProto.FLOAT, input_left.shape),
66+
make_tensor_value_info("Y", onnx.TensorProto.FLOAT, input_right.shape),
67+
],
68+
[make_tensor_value_info("Z", onnx.TensorProto.FLOAT, output_shape)],
69+
)
70+
model = make_model(graph, producer_name="Omniengine Tester")
71+
return model.SerializeToString()
72+
73+
74+
def test_input_names():
75+
model_bytes = make_onnx_model_for_matmul_op()
76+
input_names = get_input_names_from_bytes(model_bytes)
77+
assert input_names == ["X", "Y"]
78+
79+
80+
def test_output_names():
81+
model_bytes = make_onnx_model_for_matmul_op()
82+
output_names = get_output_names_from_bytes(model_bytes)
83+
assert output_names == ["Z"]
84+
85+
86+
def _get_avg_var_of_weights(model):
87+
inits = model.graph.initializer
88+
avg_var_dict = {}
89+
90+
for init in inits:
91+
if len(init.dims) > 1:
92+
dtype = onnx.helper.tensor_dtype_to_np_dtype(init.data_type)
93+
if dtype in ["float16", "float32", "float64"]:
94+
np_tensor = np.frombuffer(init.raw_data, dtype=dtype)
95+
avg_var_dict[init.name + "_avg"] = np.average(np_tensor)
96+
avg_var_dict[init.name + "_var"] = np.var(np_tensor)
97+
98+
return avg_var_dict
99+
100+
101+
def test_random_onnx_weights():
102+
model, args, kwargs = get_tiny_resnet_and_input()
103+
assert not kwargs
104+
105+
onnx_bytes = get_onnx_bytes(model, args)
106+
original_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
107+
original_model_size = len(onnx_bytes)
108+
109+
onnx_bytes = remove_weights_data(onnx_bytes)
110+
# Removed model weights should be greater than 18 MB
111+
assert original_model_size - len(onnx_bytes) > 18e6
112+
113+
# After assigning random weights, model size should be slightly greater than the the original
114+
# size due to some extra metadata
115+
onnx_bytes = randomize_weights_onnx_bytes(onnx_bytes)
116+
assert len(onnx_bytes) > original_model_size
117+
118+
randomized_avg_var_dict = _get_avg_var_of_weights(onnx.load_from_string(onnx_bytes))
119+
for key, value in original_avg_var_dict.items():
120+
assert abs(value - randomized_avg_var_dict[key]) < 0.1
121+
122+
123+
def test_reproducible_random_weights():
124+
model, args, kwargs = get_tiny_resnet_and_input()
125+
assert not kwargs
126+
127+
original_onnx_bytes = get_onnx_bytes(model, args)
128+
onnx_bytes_wo_weights = remove_weights_data(original_onnx_bytes)
129+
130+
# Check if the randomization produces the same weights
131+
onnx_bytes_1 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
132+
onnx_bytes_2 = randomize_weights_onnx_bytes(onnx_bytes_wo_weights)
133+
assert onnx_bytes_1 == onnx_bytes_2
134+
135+
136+
def _make_bn_initializer(name: str, shape, value=1.0):
137+
"""Helper to create an initializer tensor for BatchNorm."""
138+
data = np.full(shape, value, dtype=np.float32)
139+
return make_tensor(name, onnx.TensorProto.FLOAT, shape, data.flatten())
140+
141+
142+
def _make_batchnorm_model(bn_node, extra_value_infos=None):
143+
"""Helper to create an ONNX model with a BatchNormalization node.
144+
145+
The created model has the following schematic structure:
146+
147+
graph name: "test_graph"
148+
inputs:
149+
- input: FLOAT [1, 3, 224, 224]
150+
initializers:
151+
- scale: FLOAT [3]
152+
- bias: FLOAT [3]
153+
- mean: FLOAT [3]
154+
- var: FLOAT [3]
155+
nodes:
156+
- BatchNormalization (name comes from `bn_node`), with:
157+
inputs = ["input", "scale", "bias", "mean", "var"]
158+
outputs = as provided by `bn_node` (e.g., ["output"], or
159+
["output", "running_mean", "running_var", "saved_mean"])
160+
outputs:
161+
- output: FLOAT [1, 3, 224, 224]
162+
163+
If `extra_value_infos` is provided (e.g., value_info for non-training outputs
164+
like "running_mean"/"running_var" and/or training-only outputs like
165+
"saved_mean"/"saved_inv_std"), they are attached to the graph's value_info.
166+
Some tests subsequently invoke utilities (e.g., remove_node_training_mode)
167+
that prune training-only outputs and their value_info entries, while keeping
168+
regular outputs such as "running_mean" and "running_var" intact.
169+
"""
170+
initializers = [
171+
_make_bn_initializer("scale", [3], 1.0),
172+
_make_bn_initializer("bias", [3], 0.0),
173+
_make_bn_initializer("mean", [3], 0.0),
174+
_make_bn_initializer("var", [3], 1.0),
175+
]
176+
177+
graph_outputs = []
178+
for output_name, shape in [
179+
("output", [1, 3, 224, 224]),
180+
("running_mean", [3]),
181+
("running_var", [3]),
182+
]:
183+
if output_name in bn_node.output:
184+
graph_outputs.append(make_tensor_value_info(output_name, onnx.TensorProto.FLOAT, shape))
185+
186+
graph_def = make_graph(
187+
[bn_node],
188+
"test_graph",
189+
[make_tensor_value_info("input", onnx.TensorProto.FLOAT, [1, 3, 224, 224])],
190+
graph_outputs,
191+
initializer=initializers,
192+
value_info=extra_value_infos or [],
193+
)
194+
195+
return make_model(graph_def, opset_imports=[make_opsetid("", 14)])
196+
197+
198+
def test_remove_node_training_mode_attribute():
199+
"""Test removal of training_mode attribute from BatchNormalization nodes."""
200+
bn_node = make_node(
201+
"BatchNormalization",
202+
inputs=["input", "scale", "bias", "mean", "var"],
203+
outputs=["output"],
204+
name="bn1",
205+
training_mode=1, # This attribute should be removed
206+
)
207+
208+
model = _make_batchnorm_model(bn_node)
209+
result_model = remove_node_training_mode(model, "BatchNormalization")
210+
211+
bn_node_result = result_model.graph.node[0]
212+
assert bn_node_result.op_type == "BatchNormalization"
213+
214+
# Check that training_mode attribute is not present
215+
attr_names = [attr.name for attr in bn_node_result.attribute]
216+
assert "training_mode" not in attr_names
217+
218+
219+
def test_remove_node_extra_training_outputs():
220+
"""Test removal of extra training outputs from BatchNormalization nodes."""
221+
bn_node = make_node(
222+
"BatchNormalization",
223+
inputs=["input", "scale", "bias", "mean", "var"],
224+
outputs=[
225+
"output",
226+
"running_mean",
227+
"running_var",
228+
"saved_mean",
229+
"saved_inv_std",
230+
],
231+
name="bn1",
232+
training_mode=1,
233+
)
234+
235+
# Extra training outputs are attached to the graph's value_info
236+
value_infos = [
237+
make_tensor_value_info("saved_mean", onnx.TensorProto.FLOAT, [3]),
238+
make_tensor_value_info("saved_inv_std", onnx.TensorProto.FLOAT, [3]),
239+
]
240+
241+
model = _make_batchnorm_model(bn_node, extra_value_infos=value_infos)
242+
result_model = remove_node_training_mode(model, "BatchNormalization")
243+
244+
# Verify only the non-training outputs remain
245+
bn_node_result = result_model.graph.node[0]
246+
print(bn_node_result.output)
247+
assert len(bn_node_result.output) == 3
248+
assert bn_node_result.output[0] == "output"
249+
assert bn_node_result.output[1] == "running_mean"
250+
assert bn_node_result.output[2] == "running_var"
251+
252+
# Verify value_info entries for removed outputs are cleaned up
253+
value_info_names = [vi.name for vi in result_model.graph.value_info]
254+
assert "saved_mean" not in value_info_names
255+
assert "saved_inv_std" not in value_info_names

0 commit comments

Comments
 (0)