|
| 1 | +# Copyright 2021 Intel Corporation |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +import sys |
| 16 | +import copy |
| 17 | +import numbers |
| 18 | +import pytest |
| 19 | +import numpy as np |
| 20 | +import types as pytypes |
| 21 | + |
| 22 | +from numba_dpcomp import njit, jit, vectorize |
| 23 | + |
| 24 | +from numba_dpcomp.mlir.passes import ( |
| 25 | + print_pass_ir, |
| 26 | + get_print_buffer, |
| 27 | + is_print_buffer_empty, |
| 28 | +) |
| 29 | + |
| 30 | +import numba.tests.test_parfors |
| 31 | + |
| 32 | + |
| 33 | +def _gen_tests(): |
| 34 | + testcases = [ |
| 35 | + numba.tests.test_parfors.TestPrangeBasic, |
| 36 | + numba.tests.test_parfors.TestPrangeSpecific, |
| 37 | + numba.tests.test_parfors.TestParforsVectorizer, |
| 38 | + ] |
| 39 | + |
| 40 | + xfail_tests = { |
| 41 | + "test_prange03mul", |
| 42 | + "test_prange09", |
| 43 | + "test_prange03sub", |
| 44 | + "test_prange10", |
| 45 | + "test_prange03", |
| 46 | + "test_prange03div", |
| 47 | + "test_prange07", |
| 48 | + "test_prange06", |
| 49 | + "test_prange16", |
| 50 | + "test_prange12", |
| 51 | + "test_prange04", |
| 52 | + "test_prange13", |
| 53 | + "test_prange25", |
| 54 | + "test_prange21", |
| 55 | + "test_prange14", |
| 56 | + "test_prange18", |
| 57 | + "test_prange_nested_reduction1", |
| 58 | + "test_list_setitem_hoisting", |
| 59 | + "test_prange23", |
| 60 | + "test_prange24", |
| 61 | + "test_list_comprehension_prange", |
| 62 | + "test_prange22", |
| 63 | + "test_prange_raises_invalid_step_size", |
| 64 | + "test_issue7501", |
| 65 | + "test_parfor_race_1", |
| 66 | + "test_check_alias_analysis", |
| 67 | + "test_nested_parfor_push_call_vars", |
| 68 | + "test_record_array_setitem_yield_array", |
| 69 | + "test_record_array_setitem", |
| 70 | + "test_multiple_call_getattr_object", |
| 71 | + "test_prange_two_instances_same_reduction_var", |
| 72 | + "test_prange_conflicting_reduction_ops", |
| 73 | + "test_ssa_false_reduction", |
| 74 | + "test_prange26", |
| 75 | + "test_prange_two_conditional_reductions", |
| 76 | + "test_argument_alias_recarray_field", |
| 77 | + "test_mutable_list_param", |
| 78 | + "test_signed_vs_unsigned_vec_asm", |
| 79 | + "test_unsigned_refusal_to_vectorize", |
| 80 | + "test_vectorizer_fastmath_asm", |
| 81 | + } |
| 82 | + |
| 83 | + skip_tests = { |
| 84 | + "test_kde_example", |
| 85 | + "test_prange27", |
| 86 | + "test_copy_global_for_parfor", |
| 87 | + } |
| 88 | + |
| 89 | + def _wrap_test_class(test_base): |
| 90 | + class _Wrapper(test_base): |
| 91 | + def _gen_normal(self, func): |
| 92 | + return njit()(func) |
| 93 | + |
| 94 | + def _gen_parallel(self, func): |
| 95 | + def wrapper(*args, **kwargs): |
| 96 | + with print_pass_ir([], ["ParallelToTbbPass"]): |
| 97 | + res = njit(parallel=True)(func)(*args, **kwargs) |
| 98 | + ir = get_print_buffer() |
| 99 | + # Check some parallel loops were actually generated |
| 100 | + assert ir.count("plier_util.parallel") > 0, ir |
| 101 | + return res |
| 102 | + |
| 103 | + return wrapper |
| 104 | + |
| 105 | + def _gen_parallel_fastmath(self, func): |
| 106 | + def wrapper(*args, **kwargs): |
| 107 | + with print_pass_ir([], ["PostLLVMLowering"]): |
| 108 | + res = njit(parallel=True, fastmath=True)(func)(*args, **kwargs) |
| 109 | + ir = get_print_buffer() |
| 110 | + # Check some fastmath llvm flags were generated |
| 111 | + count = 0 |
| 112 | + for line in ir.splitlines(): |
| 113 | + for op in ("fadd", "fsub", "fmul", "fdiv", "frem", "fcmp"): |
| 114 | + if line.count("llvm." + op) and line.count( |
| 115 | + "llvm.fastmath<fast>" |
| 116 | + ): |
| 117 | + count += 1 |
| 118 | + assert count > 0, ir |
| 119 | + return res |
| 120 | + |
| 121 | + return wrapper |
| 122 | + |
| 123 | + def get_gufunc_asm(self, func, schedule_type, *args, **kwargs): |
| 124 | + assert False |
| 125 | + |
| 126 | + def prange_tester(self, pyfunc, *args, **kwargs): |
| 127 | + patch_instance = kwargs.pop("patch_instance", None) |
| 128 | + scheduler_type = kwargs.pop("scheduler_type", None) |
| 129 | + check_fastmath = kwargs.pop("check_fastmath", False) |
| 130 | + check_fastmath_result = kwargs.pop("check_fastmath_result", False) |
| 131 | + check_scheduling = kwargs.pop("check_scheduling", True) |
| 132 | + check_args_for_equality = kwargs.pop("check_arg_equality", None) |
| 133 | + assert not kwargs, "Unhandled kwargs: " + str(kwargs) |
| 134 | + |
| 135 | + pyfunc = self.generate_prange_func(pyfunc, patch_instance) |
| 136 | + |
| 137 | + cfunc = self._gen_normal(pyfunc) |
| 138 | + cpfunc = self._gen_parallel(pyfunc) |
| 139 | + |
| 140 | + if check_fastmath or check_fastmath_result: |
| 141 | + fastmath_pcres = self._gen_parallel_fastmath(pyfunc) |
| 142 | + |
| 143 | + def copy_args(*args): |
| 144 | + if not args: |
| 145 | + return tuple() |
| 146 | + new_args = [] |
| 147 | + for x in args: |
| 148 | + if isinstance(x, np.ndarray): |
| 149 | + new_args.append(x.copy("k")) |
| 150 | + elif isinstance(x, np.number): |
| 151 | + new_args.append(x.copy()) |
| 152 | + elif isinstance(x, numbers.Number): |
| 153 | + new_args.append(x) |
| 154 | + elif isinstance(x, tuple): |
| 155 | + new_args.append(copy.deepcopy(x)) |
| 156 | + elif isinstance(x, list): |
| 157 | + new_args.append(x[:]) |
| 158 | + else: |
| 159 | + raise ValueError("Unsupported argument type encountered") |
| 160 | + return tuple(new_args) |
| 161 | + |
| 162 | + # python result |
| 163 | + py_args = copy_args(*args) |
| 164 | + py_expected = pyfunc(*py_args) |
| 165 | + |
| 166 | + # njit result |
| 167 | + njit_args = copy_args(*args) |
| 168 | + njit_output = cfunc(*njit_args) |
| 169 | + |
| 170 | + # parfor result |
| 171 | + parfor_args = copy_args(*args) |
| 172 | + parfor_output = cpfunc(*parfor_args) |
| 173 | + |
| 174 | + if check_args_for_equality is None: |
| 175 | + np.testing.assert_almost_equal(njit_output, py_expected, **kwargs) |
| 176 | + np.testing.assert_almost_equal(parfor_output, py_expected, **kwargs) |
| 177 | + self.assertEqual(type(njit_output), type(parfor_output)) |
| 178 | + else: |
| 179 | + assert len(py_args) == len(check_args_for_equality) |
| 180 | + for pyarg, njitarg, parforarg, argcomp in zip( |
| 181 | + py_args, njit_args, parfor_args, check_args_for_equality |
| 182 | + ): |
| 183 | + argcomp(njitarg, pyarg, **kwargs) |
| 184 | + argcomp(parforarg, pyarg, **kwargs) |
| 185 | + |
| 186 | + # Ignore check_scheduling |
| 187 | + # if check_scheduling: |
| 188 | + # self.check_scheduling(cpfunc, scheduler_type) |
| 189 | + |
| 190 | + # if requested check fastmath variant |
| 191 | + if check_fastmath or check_fastmath_result: |
| 192 | + parfor_fastmath_output = fastmath_pcres(*copy_args(*args)) |
| 193 | + if check_fastmath_result: |
| 194 | + np.testing.assert_almost_equal( |
| 195 | + parfor_fastmath_output, py_expected, **kwargs |
| 196 | + ) |
| 197 | + |
| 198 | + return _Wrapper |
| 199 | + |
| 200 | + def _replace_global(func, name, newval): |
| 201 | + if name in func.__globals__: |
| 202 | + func.__globals__[name] = newval |
| 203 | + |
| 204 | + def _gen_test_func(func): |
| 205 | + _replace_global(func, "jit", jit) |
| 206 | + _replace_global(func, "njit", njit) |
| 207 | + _replace_global(func, "vectorize", vectorize) |
| 208 | + |
| 209 | + def wrapper(): |
| 210 | + return func() |
| 211 | + |
| 212 | + return wrapper |
| 213 | + |
| 214 | + this_module = sys.modules[__name__] |
| 215 | + for tc in testcases: |
| 216 | + inst = _wrap_test_class(tc)() |
| 217 | + for func_name in dir(tc): |
| 218 | + if func_name.startswith("test"): |
| 219 | + func = getattr(inst, func_name) |
| 220 | + if callable(func): |
| 221 | + func = _gen_test_func(func) |
| 222 | + if func_name in xfail_tests: |
| 223 | + func = pytest.mark.xfail(func) |
| 224 | + elif func_name in skip_tests: |
| 225 | + func = pytest.mark.skip(func) |
| 226 | + |
| 227 | + setattr(this_module, func_name, func) |
| 228 | + |
| 229 | + |
| 230 | +_gen_tests() |
| 231 | +del _gen_tests |
0 commit comments