# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
"""Test op correctness by comparing with PyTorch results.

Usage:

    pytest onnxscript/tests/function_libs/torch_lib/ops_test.py

To run tests on a specific operator (e.g. torch.ceil):

    pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k ceil

To run tests on an nn operator (e.g. nn.functional.scaled_dot_product_attention):

    pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k nn_functional_scaled_dot_product_attention

## Environment variables

1. Set environment variable `CATCH_ORT_SEGFAULT=1` to catch segmentation faults
   in onnxruntime by running the inference sessions in a separate process.
2. Set `CREATE_REPRODUCTION_REPORT=1` to create markdown files for reproducing
   errors.
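
For example, to generate a reproduction report for a failing op (an illustrative
combination of the documented flag with the pytest filter shown above):

    CREATE_REPRODUCTION_REPORT=1 pytest onnxscript/tests/function_libs/torch_lib/ops_test.py -k ceil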
"""

from __future__ import annotations

import os
import unittest
from typing import Callable, Optional, Sequence, Tuple

import numpy as np
import onnx
import onnxruntime as ort
import parameterized
import torch
from torch.testing._internal import common_device_type
from torch.testing._internal.opinfo import core as opinfo_core
from torch.utils import _pytree as pytree

import onnxscript
from onnxscript._internal import version_utils
from tests.function_libs.torch_lib import (
    error_reproduction,
    ops_test_common,
    ops_test_data,
)

# All dtypes will be tested on the generated symbolic functions.
# complex64 will be flattened to float32.
TESTED_DTYPES = (
    torch.float16,
    torch.float32,
    # Uncomment the items below when we need to test them.
    # torch.bfloat16,
    # torch.float64,
    torch.bool,
    # torch.int8,
    # torch.int16,
    torch.int32,
    torch.int64,
    # torch.uint8,
)

# NOTE: torch.complex32 is experimental in torch
COMPLEX_TYPES = (torch.complex64,)


def dtypes_except(*dtypes: torch.dtype) -> Sequence[torch.dtype]:
    """Returns all dtypes except the ones specified."""
    return tuple(dtype for dtype in TESTED_DTYPES if dtype not in dtypes)
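
# Example (illustrative; the exact result follows from TESTED_DTYPES above):
#   dtypes_except(torch.bool, torch.float16)
#   -> (torch.float32, torch.int32, torch.int64)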


def _should_skip_xfail_test_sample(
    op_name: str, sample, dtype: torch.dtype, device_type: str
) -> Tuple[Optional[str], Optional[str]]:
    """Returns a reason if a test sample should be skipped."""
    if op_name not in ops_test_data.OP_WITH_SKIPPED_XFAIL_SUBTESTS:
        return None, None
    for decorator_meta in ops_test_data.SKIP_XFAIL_SUBTESTS:
        # Linear search on ops_test_data.SKIP_XFAIL_SUBTESTS. That's fine because the list is small.
        if decorator_meta.op_name == op_name:
            assert decorator_meta.matcher is not None, "Matcher must be defined"
            if not decorator_meta.enabled_if:
                # Do not skip the test if the decorator meta is not enabled
                continue
            if decorator_meta.dtypes is not None and dtype not in decorator_meta.dtypes:
                # Not applicable for this dtype
                continue
            if (
                decorator_meta.device_type is not None
                and decorator_meta.device_type != device_type
            ):
                # Not applicable for this device_type
                continue
            if decorator_meta.matcher(sample):
                return decorator_meta.test_behavior, decorator_meta.reason
    return None, None
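
# Illustrative return values (hypothetical entries; the real ones live in
# ops_test_data.SKIP_XFAIL_SUBTESTS):
#   _should_skip_xfail_test_sample("ceil", sample, torch.int64, "cpu")
#   -> ("xfail", "ceil is not supported for int64") when a matcher matches the sample,
#   -> (None, None) otherwise.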


class TestFunctionValidity(unittest.TestCase):
    @parameterized.parameterized.expand(
        [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS]
    )
    def test_script_function_passes_checker(
        self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo
    ):
        if not isinstance(torchlib_op_info.op, onnxscript.OnnxFunction):
            self.skipTest("Traced functions do not have a function proto")
        function_proto = torchlib_op_info.op.to_function_proto()
        onnx.checker.check_function(function_proto)  # type: ignore[attr-defined]

    @parameterized.parameterized.expand(
        [(info.op.name, info) for info in ops_test_data.TESTED_TORCHLIB_OPS]
    )
    def test_function_has_op_schema(self, _, torchlib_op_info: ops_test_data.TorchLibOpInfo):
        func = torchlib_op_info.op
        schema = func.op_schema
        self.assertIsNotNone(schema)
        self.assertEqual(schema.name, func.name)


def run_test_output_match(
    test_suite: unittest.TestCase,
    device: str,
    dtype: torch.dtype,
    op: opinfo_core.OpInfo,
    function_executor: Callable,
    tested_op_mapping: dict[
        str,
        ops_test_data.TorchLibOpInfo,
    ],
):
    """Base test method for testing each op, used by instantiate_device_type_tests.

    Args:
        test_suite: The test class instance.
        device: The PyTorch device. instantiate_device_type_tests provides this.
        dtype: The PyTorch dtype. instantiate_device_type_tests provides this.
        op: The OpInfo instance. instantiate_device_type_tests provides this.
        function_executor: The function executor. This is a function that takes
            a function and its arguments and returns the output of the function.
        tested_op_mapping: The mapping of op name to the tested op.
    """
    samples = op.sample_inputs(
        device,
        dtype,
        requires_grad=False,
    )

    torchlib_op_info = tested_op_mapping[op.name]
    # Obtain the input_wrangler that manipulates the OpInfo inputs
    # to match the aten operator signature.
    # An example is nn.functional.upsample_nearest2d, which has a different signature
    # than the aten operator upsample_nearest2d.
    onnx_function = torchlib_op_info.op
    input_wrangler = torchlib_op_info.input_wrangler

    if (
        not ops_test_common.dtype_op_schema_compatible(dtype, onnx_function.op_schema)
        and dtype not in COMPLEX_TYPES
    ):
        test_suite.skipTest(
            f"dtype '{dtype}' is not supported by the op '{op.name}'. "
            f"Type constraints: {onnx_function.op_schema.type_constraints}"
        )

    # Obtain the tolerance for the op
    rtol, atol = torchlib_op_info.get_tolerance(dtype)
    for i, cpu_sample in enumerate(samples):
        inputs = (cpu_sample.input, *cpu_sample.args)
        # Provide the repr to subtest because tensors are not serializable in parallel test runs
        with test_suite.subTest(
            sample_num=i,
            inputs=repr(
                [
                    f"Tensor<{inp.shape}, dtype={inp.dtype}>"
                    if isinstance(inp, torch.Tensor)
                    else inp
                    for inp in inputs
                ]
            ),
            kwargs=repr(cpu_sample.kwargs),
        ):
            try:
                device_type = cpu_sample.args[0].device.type
            except (AttributeError, IndexError):
                device_type = "cpu"
            test_behavior, reason = _should_skip_xfail_test_sample(
                op.name, cpu_sample, dtype, device_type
            )
            with ops_test_common.normal_xfail_skip_test_behaviors(test_behavior, reason):
                input_onnx = [ops_test_common.convert_tensor_to_numpy(x) for x in inputs]
                kwargs_onnx = ops_test_common.convert_kwargs_for_onnx(cpu_sample.kwargs)
                if input_wrangler:
                    input_onnx, kwargs_onnx = input_wrangler(input_onnx, kwargs_onnx)
                torch_output = op(*inputs, **cpu_sample.kwargs)

                if isinstance(torch_output, torch.Tensor) and torch.is_complex(torch_output):
                    torch_output = torch.view_as_real(torch_output.resolve_conj())

                reference_torch_outputs, _ = pytree.tree_flatten(torch_output)
                if (
                    op.name.startswith("split")
                    or (op.name.startswith("unbind") and version_utils.torch_older_than("2.7"))
                    or op.name
                    in {"atleast_1d_Sequence", "atleast_2d_Sequence", "atleast_3d_Sequence"}
                ):
                    # Hack for handling split, chunk and unbind, which rely on the
                    # SplitToSequence op. Split returns a Sequence that should be
                    # treated as a single value, so we wrap it in a tuple.
                    # TODO(justinchuby): Find a more general solution
                    reference_torch_outputs = [reference_torch_outputs]

                test_name = test_suite.id()
                function_output = function_executor(test_name, reference_torch_outputs)(
                    onnx_function, input_onnx, kwargs_onnx
                )
                # Finally we re-flatten everything
                # TODO: add pytree structure comparison.
                flattened_torch_outputs, _ = pytree.tree_flatten(torch_output)
                flattened_function_outputs, _ = pytree.tree_flatten(function_output)

                assert flattened_torch_outputs
                assert len(flattened_torch_outputs) == len(flattened_function_outputs)

                for j, (torch_output, function_output) in enumerate(
                    zip(flattened_torch_outputs, flattened_function_outputs)
                ):
                    if not isinstance(function_output, np.ndarray):
                        # An onnxscript tensor
                        function_output = function_output.value

                    actual = torch.tensor(function_output)
                    expected = (
                        torch_output
                        if isinstance(torch_output, torch.Tensor)
                        else torch.tensor(torch_output)
                    )

                    if (
                        op.name in ops_test_data.NONDETERMINISTIC_OPS
                        or j in ops_test_data.COMPARE_SHAPE_ONLY_OPS[op.name]
                    ):
                        # Check shape and dtype only for ops that are known to be
                        # nondeterministic
                        test_suite.assertEqual(actual.shape, expected.shape)
                        test_suite.assertEqual(actual.dtype, expected.dtype)
                        continue

                    # Use torch.testing as opposed to np.testing to ensure dtypes and shapes match
                    try:
                        torch.testing.assert_close(
                            actual,
                            expected,
                            rtol=rtol,
                            atol=atol,
                            equal_nan=True,
                            check_device=False,
                        )
                    except AssertionError as e:
                        if os.environ.get("CREATE_REPRODUCTION_REPORT") == "1":
                            error_reproduction.create_mismatch_report(
                                test_name, i, inputs, cpu_sample.kwargs, actual, expected, e
                            )
                        if len(flattened_torch_outputs) > 1:
                            raise AssertionError(f"Output {j} mismatch") from e
                        raise


class TestOutputConsistencyFullGraph(unittest.TestCase):
    """Test output consistency between exported ONNX op run as a graph and PyTorch eager mode.

    This is a parameterized test suite.
    """

    def setUp(self) -> None:
        torch.manual_seed(42)
        np.random.seed(42)
        ort.set_seed(42)

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [info for info in ops_test_data.OPS_DB if info.name in ops_test_data.TESTED_OPS],
        allowed_dtypes=TESTED_DTYPES,
    )
    def test_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        """Base test method for testing each op by running the full ONNX graph."""
        run_test_output_match(
            self,
            device,
            dtype,
            op,
            ops_test_common.graph_executor,
            ops_test_data.TORCHLIB_OPINFO_MAPPING,
        )

    @ops_test_common.add_decorate_info(
        ops_test_data.OPS_DB,
        "TestOutputConsistencyFullGraph",
        "test_complex_output_match_opinfo_",
        skip_or_xfails=ops_test_data.EXPECTED_SKIPS_OR_FAILS,
    )
    @common_device_type.ops(  # type: ignore[misc]
        [
            info
            for info in ops_test_data.OPS_DB
            if info.name in ops_test_data.COMPLEX_FUNCTION_MAPPING
        ],
        allowed_dtypes=COMPLEX_TYPES,
    )
    def test_complex_output_match_opinfo_(
        self, device: str, dtype: torch.dtype, op: opinfo_core.OpInfo
    ):
        """Base test method for testing each op with complex inputs by running the full ONNX graph."""
        run_test_output_match(
            self,
            device,
            dtype,
            op,
            ops_test_common.graph_executor,
            ops_test_data.COMPLEX_FUNCTION_MAPPING,
        )


common_device_type.instantiate_device_type_tests(
    TestOutputConsistencyFullGraph, globals(), only_for=["cpu", "cuda"]
)
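
# instantiate_device_type_tests registers device-specific copies of the suite
# (e.g. TestOutputConsistencyFullGraphCPU) in this module's globals, with one
# generated test per (op, dtype) combination, along the lines of
# test_output_match_opinfo__ceil_cpu_float32. (The names here illustrate the
# naming scheme; they are not an exhaustive list.)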

if __name__ == "__main__":
    unittest.main()