Skip to content

Commit 1193f50

Browse files
authored
E2E tests for math ops. (#20169)
This is thought to be needed before going ahead with a batch of math ops codegen changes (#19970 (review)) and this also discovered a few bugs: #20163 #20164 #20165. Signed-off-by: Benoit Jacob <[email protected]>
1 parent 6eadf3d commit 1193f50

File tree

12 files changed

+1288
-10
lines changed

12 files changed

+1288
-10
lines changed

CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -568,6 +568,7 @@ include(iree_cc_binary_benchmark)
568568
include(iree_hal_cts_test_suite)
569569
include(iree_static_linker_test)
570570
include(iree_plugin_register)
571+
include(iree_genrule)
571572

572573
# Default any sub-tree which doesn't provide its own package namespacing
573574
# to derive it relative to this directory and prefixed with iree/.

build_tools/bazel/iree_check_test.bzl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ def iree_check_test(
2525
runner_args = [],
2626
tags = [],
2727
timeout = None,
28+
deps = [],
2829
**kwargs):
2930
"""Creates an iree-check-module test for the specified source file.
3031
@@ -53,11 +54,13 @@ def iree_check_test(
5354
"--iree-hal-target-backends=%s" % target_backend,
5455
] + compiler_flags + input_type_flags
5556
bytecode_module_name = name + "_bytecode_module"
57+
5658
iree_bytecode_module(
5759
name = bytecode_module_name,
5860
src = src,
5961
flags = flags,
6062
tags = ["target=%s" % target_backend],
63+
deps = deps,
6164
visibility = ["//visibility:private"],
6265
)
6366

@@ -86,6 +89,7 @@ def iree_check_single_backend_test_suite(
8689
input_type = None,
8790
runner_args = [],
8891
tags = [],
92+
deps = [],
8993
timeout = None,
9094
**kwargs):
9195
"""Creates a test suite of iree-check-module tests for a single backend/driver pair.
@@ -124,7 +128,7 @@ def iree_check_single_backend_test_suite(
124128

125129
tests = []
126130
for src in srcs:
127-
test_name = "_".join([name, src])
131+
test_name = "_".join([name, src]).replace("/", "_").replace(":", "_")
128132
iree_check_test(
129133
name = test_name,
130134
src = src,
@@ -135,6 +139,7 @@ def iree_check_single_backend_test_suite(
135139
runner_args = runner_args,
136140
tags = tags,
137141
timeout = timeout,
142+
deps = deps,
138143
**kwargs
139144
)
140145
tests.append(test_name)
@@ -161,6 +166,7 @@ def iree_check_test_suite(
161166
runner_args = [],
162167
tags = [],
163168
target_cpu_features_variants = [],
169+
deps = [],
164170
**kwargs):
165171
"""Creates a test suite of iree-check-module tests.
166172
@@ -205,6 +211,7 @@ def iree_check_test_suite(
205211
input_type = input_type,
206212
runner_args = runner_args,
207213
tags = tags,
214+
deps = deps,
208215
**kwargs
209216
)
210217
tests.append(suite_name)

build_tools/bazel/iree_genrule.bzl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2025 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
"""A minimal subset of Bazel genrule for common use cases"""
8+
9+
def iree_genrule(
10+
name,
11+
srcs,
12+
outs,
13+
cmd,
14+
**kwargs):
15+
"""A minimal subset of Bazel genrule for common use cases.
16+
17+
Args:
18+
name: Name of the target.
19+
srcs: Source files, including any script run in the command.
20+
Unlike Bazel's genrule, we do not try to distinguish between the
21+
two. The distinction is needed when tools need to be compiled for
22+
host, but that doesn't concern us if we only need to run python
23+
scripts.
24+
outs: Files generated by the command.
25+
cmd: The command to be executed. The only supported special Bazel
26+
genrule syntax is:
27+
* "$(rootpath x)", which expands to the path to a file in the
28+
source tree.
29+
* "$(execpath x)", which expands to the path to a file in the
30+
directory where Bazel runs the build action.
31+
**kwargs: any additional attributes to pass to the underlying rules.
32+
"""
33+
34+
native.genrule(
35+
name = name,
36+
srcs = srcs,
37+
outs = outs,
38+
cmd = cmd,
39+
**kwargs
40+
)

build_tools/bazel_to_cmake/bazel_to_cmake_converter.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,7 @@ def iree_check_single_backend_test_suite(
789789
runner_args=None,
790790
tags=None,
791791
timeout=None,
792+
deps=None,
792793
**kwargs,
793794
):
794795
if self._should_skip_target(tags=tags, **kwargs):
@@ -806,6 +807,7 @@ def iree_check_single_backend_test_suite(
806807
runner_args_block = self._convert_string_list_block("RUNNER_ARGS", runner_args)
807808
labels_block = self._convert_string_list_block("LABELS", tags)
808809
timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout)
810+
deps_block = self._convert_string_list_block("DEPS", deps)
809811

810812
self._converter.body += (
811813
f"iree_check_single_backend_test_suite(\n"
@@ -818,6 +820,7 @@ def iree_check_single_backend_test_suite(
818820
f"{runner_args_block}"
819821
f"{labels_block}"
820822
f"{timeout_block}"
823+
f"{deps_block}"
821824
f")\n\n"
822825
)
823826

@@ -832,6 +835,7 @@ def iree_check_test_suite(
832835
tags=None,
833836
target_cpu_features_variants=None,
834837
timeout=None,
838+
deps=None,
835839
**kwargs,
836840
):
837841
if self._should_skip_target(tags=tags, **kwargs):
@@ -858,6 +862,7 @@ def iree_check_test_suite(
858862
"TARGET_CPU_FEATURES_VARIANTS", target_cpu_features_variants
859863
)
860864
timeout_block = self._convert_timeout_arg_block("TIMEOUT", timeout)
865+
deps_block = self._convert_string_list_block("DEPS", deps)
861866

862867
self._converter.body += (
863868
f"iree_check_test_suite(\n"
@@ -871,6 +876,7 @@ def iree_check_test_suite(
871876
f"{labels_block}"
872877
f"{target_cpu_features_variants_block}"
873878
f"{timeout_block}"
879+
f"{deps_block}"
874880
f")\n\n"
875881
)
876882

@@ -1007,6 +1013,21 @@ def iree_cmake_extra_content(self, content, inline=False):
10071013
else:
10081014
self._converter.header += f"\n{content}\n"
10091015

1016+
def iree_genrule(self, name, srcs, outs, cmd):
1017+
name_block = self._convert_string_arg_block("NAME", name, quote=False)
1018+
srcs_block = self._convert_srcs_block(srcs)
1019+
outs_block = self._convert_target_list_block("OUTS", outs)
1020+
cmd_block = self._convert_string_arg_block("CMD", cmd, quote=True)
1021+
1022+
self._converter.body += (
1023+
f"iree_genrule(\n"
1024+
f"{name_block}"
1025+
f"{srcs_block}"
1026+
f"{outs_block}"
1027+
f"{cmd_block}"
1028+
f")\n\n"
1029+
)
1030+
10101031

10111032
class Converter(object):
10121033
"""Conversion state tracking and full file template substitution."""
Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
# Copyright 2025 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
include(CMakeParseArguments)
8+
9+
# A wrapper around add_custom_command and a minimal subset of Bazel genrule.
10+
#
11+
# Parameters:
12+
# NAME: Name of the target.
13+
# SRCS: Source files, including any script run in the command.
14+
# Unlike Bazel's genrule, we do not try to distinguish between the
15+
# two. The distinction is needed when tools need to be compiled for
16+
# host, but that doesn't concern us if we only need to run python
17+
# scripts.
18+
# OUTS: Files generated by the command.
19+
# CMD: The command to be executed. The only supported special Bazel genrule
20+
# syntax is:
21+
# * "$(rootpath x)", only supported for source files. In conversion
22+
# to CMake, this expands to the path of x relatively to the current
23+
# source dir.
24+
# * "$(execpath x)", only supported for generated files. In
25+
# conversion to CMake, this expands to just x, as the binary dir is
26+
# the default working dir for custom commands anyway.
27+
function(iree_genrule)
28+
cmake_parse_arguments(
29+
_RULE
30+
""
31+
"NAME"
32+
"SRCS;OUTS;CMD"
33+
${ARGN}
34+
)
35+
36+
set(_CMD "${_RULE_CMD}")
37+
38+
# Replace Bazel syntax $(rootpath x) with the path into the source dir.
39+
string(REGEX REPLACE "\\$\\(rootpath ([^)]+)\\)" "${CMAKE_CURRENT_SOURCE_DIR}/\\1" _CMD "${_CMD}")
40+
41+
# Simply drop Bazel syntax $(execpath x) as Bazel custom commands are executed
42+
# by default in the build directory.
43+
string(REGEX REPLACE "\\$\\(execpath ([^)]+)\\)" "\\1" _CMD "${_CMD}")
44+
45+
# Convert CMake/Unix-style paths with forward slashes to Windows-style with
46+
# backslashes. It is a bit incorrect to do it as a single cmake_path command
47+
# on the whole command string, which isn't technically a path, but this should
48+
# not matter if all what this does is this character substitution.
49+
# It is not worth implementing a cumbersome fix here, when CMake 4.0 brings
50+
# the $<PATH:NATIVE_PATH,...> generator expression which is a simpler, better
51+
# fix here. TODO(bjacob): use that generator expression in the above string
52+
# replace command directly, whenever we can rely on CMake 4.0.
53+
cmake_path(NATIVE_PATH _CMD _CMD)
54+
55+
# CMake add_custom_command expects a list as the command, so we replace spaces
56+
# by semicolon here. Careful to avoid replacing backslash-escaped spaces.
57+
string(REGEX REPLACE "([^\\]) " "\\1;" _CMD "${_CMD}")
58+
59+
add_custom_command(
60+
OUTPUT
61+
"${_RULE_OUTS}"
62+
COMMAND
63+
${_CMD}
64+
DEPENDS
65+
"${_RULE_SRCS}"
66+
COMMENT
67+
"Generating ${_RULE_OUTS}"
68+
VERBATIM
69+
)
70+
71+
add_custom_target("${_RULE_NAME}"
72+
DEPENDS "${_RULE_OUTS}"
73+
)
74+
endfunction()

compiler/src/iree/compiler/Modules/Check/IR/CheckOps.td

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,11 +138,14 @@ def CHECK_ExpectAlmostEqOp :
138138
let summary = [{Checks that the operands are almost equal}];
139139
let description = [{
140140
Verifies that the buffer view or tensor operands with float elements satisfy
141-
the Numpy-style fuzzy-comparision condition with pararameters `atol`,
142-
`rtol`, which is the following element-wise on array elements `lhs`, `rhs`:
143-
```
144-
abs(lhs - rhs) <= atol + rtol * abs(rhs).
145-
```
141+
the Numpy-style fuzzy-comparision condition with parameters `atol`,
142+
`rtol`, defined exactly as in NumPy isclose():
143+
https://github.com/numpy/numpy/blob/7297f3117d84745bfade1e2f9aec3531e5917500/numpy/_core/numeric.py#L2447-L2449
144+
145+
The condition being verified on each lhs and rhs value is:
146+
lhs == rhs || (isfinite(rhs) && abs(lhs - rhs) <= atol + rtol * abs(rhs)).
147+
Note that the `lhs == rhs` part is needed for the case (lhs=+inf, rhs+inf)
148+
to return true. Indeed, in that case, lhs-rhs is NaN.
146149

147150
Issues a non-fatal failure if the verification fails.
148151

runtime/src/iree/modules/check/module.cc

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,13 +69,17 @@ bool EqByteSpan(iree_byte_span_t lhs_bytes, iree_byte_span_t rhs_bytes) {
6969
// Numpy-compatible fuzzy comparison of floating-point values lhs, rhs with
7070
// respect to tolerance parameters atol, rtol.
7171
//
72-
// The meaning of the tolerance parameters atol and rtol is exactly as in:
73-
// https://numpy.org/doc/stable/reference/generated/numpy.testing.assert_allclose.html
72+
// The meaning of the tolerance parameters atol and rtol is exactly as in NumPy
73+
// isclose():
74+
// https://github.com/numpy/numpy/blob/7297f3117d84745bfade1e2f9aec3531e5917500/numpy/_core/numeric.py#L2447-L2449
7475
// The condition being verified on each lhs and rhs value is:
75-
// abs(lhs - rhs) <= atol + rtol * abs(rhs).
76+
// lhs == rhs || (isfinite(rhs) && abs(lhs - rhs) <= atol + rtol * abs(rhs)).
77+
// Note that the `lhs == rhs` part is needed for the case (lhs=+inf, rhs+inf)
78+
// to return true. Indeed, in that case, lhs-rhs is NaN.
7679
template <typename T>
7780
bool NumpyFuzzyCompare(T lhs, T rhs, float atol, float rtol) {
78-
return std::abs(lhs - rhs) <= atol + rtol * std::abs(rhs);
81+
return lhs == rhs || (std::isfinite(rhs) &&
82+
std::abs(lhs - rhs) <= atol + rtol * std::abs(rhs));
7983
}
8084

8185
// Records information about some LHS/RHS scalars that failed a fuzzy comparison

tests/e2e/math/BUILD.bazel

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
# Copyright 2025 The IREE Authors
2+
#
3+
# Licensed under the Apache License v2.0 with LLVM Exceptions.
4+
# See https://llvm.org/LICENSE.txt for license information.
5+
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
7+
load("//build_tools/bazel:iree_check_test.bzl", "iree_check_single_backend_test_suite")
8+
load("//build_tools/bazel:iree_genrule.bzl", "iree_genrule")
9+
10+
package(
11+
features = ["layering_check"],
12+
licenses = ["notice"], # Apache 2.0
13+
)
14+
15+
testcases = [
16+
(
17+
# Input JSON file describing testcases
18+
"math_ops_%s.json" % backend,
19+
# Output generated MLIR test file.
20+
"math_ops_%s.mlir" % backend,
21+
)
22+
for backend in [
23+
"llvm-cpu",
24+
"rocm",
25+
]
26+
]
27+
28+
[iree_genrule(
29+
name = "gen_%s" % generated_src,
30+
srcs = [
31+
"generate.py",
32+
testcases_json,
33+
],
34+
outs = [generated_src],
35+
cmd = " ".join([
36+
"python3",
37+
"$(rootpath generate.py)",
38+
"--testcases=$(rootpath %s)" % testcases_json,
39+
"> $(execpath %s)" % generated_src,
40+
]),
41+
) for testcases_json, generated_src in testcases]
42+
43+
[iree_check_single_backend_test_suite(
44+
name = "math_ops_%s" % backend,
45+
srcs = ["//tests/e2e/math:math_ops_%s.mlir" % backend],
46+
compiler_flags = ["--iree-llvmcpu-target-cpu=generic"] if backend == "llvm-cpu" else [],
47+
driver = driver,
48+
target_backend = backend,
49+
deps = [
50+
"gen_math_ops_%s.mlir" % backend,
51+
],
52+
) for backend, driver in [
53+
("llvm-cpu", "local-task"),
54+
("rocm", "hip"),
55+
]]

0 commit comments

Comments
 (0)