Skip to content

Commit 1903cb5

Browse files
authored
Run parfor tests from upstream Numba (#177)
1 parent 240a3e4 commit 1903cb5

File tree

2 files changed

+233
-2
lines changed

2 files changed

+233
-2
lines changed

azure-pipelines.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -183,7 +183,7 @@ jobs:
183183
- script: |
184184
source /usr/local/miniconda/bin/activate
185185
cd numba_dpcomp
186-
conda create -y -n test_env python=3.9 numba=0.54 scikit-learn ninja scipy pybind11 tbb=2021.1 pytest lit -c conda-forge
186+
conda create -y -n test_env python=3.9 numba=0.54 scikit-learn pytest-xdist ninja scipy pybind11 tbb=2021.1 pytest lit -c conda-forge
187187
conda activate test_env
188188
cmake --version
189189
chmod -R 777 $(System.DefaultWorkingDirectory)/llvm_cache
@@ -196,7 +196,7 @@ jobs:
196196
source /usr/local/miniconda/bin/activate
197197
cd numba_dpcomp
198198
conda activate test_env
199-
pytest -vv --capture=tee-sys
199+
pytest -n1 -vv --capture=tee-sys
200200
displayName: 'Tests'
201201
202202
- script: |
Lines changed: 231 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,231 @@
1+
# Copyright 2021 Intel Corporation
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import sys
16+
import copy
17+
import numbers
18+
import pytest
19+
import numpy as np
20+
import types as pytypes
21+
22+
from numba_dpcomp import njit, jit, vectorize
23+
24+
from numba_dpcomp.mlir.passes import (
25+
print_pass_ir,
26+
get_print_buffer,
27+
is_print_buffer_empty,
28+
)
29+
30+
import numba.tests.test_parfors
31+
32+
33+
def _gen_tests():
34+
testcases = [
35+
numba.tests.test_parfors.TestPrangeBasic,
36+
numba.tests.test_parfors.TestPrangeSpecific,
37+
numba.tests.test_parfors.TestParforsVectorizer,
38+
]
39+
40+
xfail_tests = {
41+
"test_prange03mul",
42+
"test_prange09",
43+
"test_prange03sub",
44+
"test_prange10",
45+
"test_prange03",
46+
"test_prange03div",
47+
"test_prange07",
48+
"test_prange06",
49+
"test_prange16",
50+
"test_prange12",
51+
"test_prange04",
52+
"test_prange13",
53+
"test_prange25",
54+
"test_prange21",
55+
"test_prange14",
56+
"test_prange18",
57+
"test_prange_nested_reduction1",
58+
"test_list_setitem_hoisting",
59+
"test_prange23",
60+
"test_prange24",
61+
"test_list_comprehension_prange",
62+
"test_prange22",
63+
"test_prange_raises_invalid_step_size",
64+
"test_issue7501",
65+
"test_parfor_race_1",
66+
"test_check_alias_analysis",
67+
"test_nested_parfor_push_call_vars",
68+
"test_record_array_setitem_yield_array",
69+
"test_record_array_setitem",
70+
"test_multiple_call_getattr_object",
71+
"test_prange_two_instances_same_reduction_var",
72+
"test_prange_conflicting_reduction_ops",
73+
"test_ssa_false_reduction",
74+
"test_prange26",
75+
"test_prange_two_conditional_reductions",
76+
"test_argument_alias_recarray_field",
77+
"test_mutable_list_param",
78+
"test_signed_vs_unsigned_vec_asm",
79+
"test_unsigned_refusal_to_vectorize",
80+
"test_vectorizer_fastmath_asm",
81+
}
82+
83+
skip_tests = {
84+
"test_kde_example",
85+
"test_prange27",
86+
"test_copy_global_for_parfor",
87+
}
88+
89+
def _wrap_test_class(test_base):
90+
class _Wrapper(test_base):
91+
def _gen_normal(self, func):
92+
return njit()(func)
93+
94+
def _gen_parallel(self, func):
95+
def wrapper(*args, **kwargs):
96+
with print_pass_ir([], ["ParallelToTbbPass"]):
97+
res = njit(parallel=True)(func)(*args, **kwargs)
98+
ir = get_print_buffer()
99+
# Check some parallel loops were actually generated
100+
assert ir.count("plier_util.parallel") > 0, ir
101+
return res
102+
103+
return wrapper
104+
105+
def _gen_parallel_fastmath(self, func):
106+
def wrapper(*args, **kwargs):
107+
with print_pass_ir([], ["PostLLVMLowering"]):
108+
res = njit(parallel=True, fastmath=True)(func)(*args, **kwargs)
109+
ir = get_print_buffer()
110+
# Check some fastmath llvm flags were generated
111+
count = 0
112+
for line in ir.splitlines():
113+
for op in ("fadd", "fsub", "fmul", "fdiv", "frem", "fcmp"):
114+
if line.count("llvm." + op) and line.count(
115+
"llvm.fastmath<fast>"
116+
):
117+
count += 1
118+
assert count > 0, ir
119+
return res
120+
121+
return wrapper
122+
123+
def get_gufunc_asm(self, func, schedule_type, *args, **kwargs):
124+
assert False
125+
126+
def prange_tester(self, pyfunc, *args, **kwargs):
127+
patch_instance = kwargs.pop("patch_instance", None)
128+
scheduler_type = kwargs.pop("scheduler_type", None)
129+
check_fastmath = kwargs.pop("check_fastmath", False)
130+
check_fastmath_result = kwargs.pop("check_fastmath_result", False)
131+
check_scheduling = kwargs.pop("check_scheduling", True)
132+
check_args_for_equality = kwargs.pop("check_arg_equality", None)
133+
assert not kwargs, "Unhandled kwargs: " + str(kwargs)
134+
135+
pyfunc = self.generate_prange_func(pyfunc, patch_instance)
136+
137+
cfunc = self._gen_normal(pyfunc)
138+
cpfunc = self._gen_parallel(pyfunc)
139+
140+
if check_fastmath or check_fastmath_result:
141+
fastmath_pcres = self._gen_parallel_fastmath(pyfunc)
142+
143+
def copy_args(*args):
144+
if not args:
145+
return tuple()
146+
new_args = []
147+
for x in args:
148+
if isinstance(x, np.ndarray):
149+
new_args.append(x.copy("k"))
150+
elif isinstance(x, np.number):
151+
new_args.append(x.copy())
152+
elif isinstance(x, numbers.Number):
153+
new_args.append(x)
154+
elif isinstance(x, tuple):
155+
new_args.append(copy.deepcopy(x))
156+
elif isinstance(x, list):
157+
new_args.append(x[:])
158+
else:
159+
raise ValueError("Unsupported argument type encountered")
160+
return tuple(new_args)
161+
162+
# python result
163+
py_args = copy_args(*args)
164+
py_expected = pyfunc(*py_args)
165+
166+
# njit result
167+
njit_args = copy_args(*args)
168+
njit_output = cfunc(*njit_args)
169+
170+
# parfor result
171+
parfor_args = copy_args(*args)
172+
parfor_output = cpfunc(*parfor_args)
173+
174+
if check_args_for_equality is None:
175+
np.testing.assert_almost_equal(njit_output, py_expected, **kwargs)
176+
np.testing.assert_almost_equal(parfor_output, py_expected, **kwargs)
177+
self.assertEqual(type(njit_output), type(parfor_output))
178+
else:
179+
assert len(py_args) == len(check_args_for_equality)
180+
for pyarg, njitarg, parforarg, argcomp in zip(
181+
py_args, njit_args, parfor_args, check_args_for_equality
182+
):
183+
argcomp(njitarg, pyarg, **kwargs)
184+
argcomp(parforarg, pyarg, **kwargs)
185+
186+
# Ignore check_scheduling
187+
# if check_scheduling:
188+
# self.check_scheduling(cpfunc, scheduler_type)
189+
190+
# if requested check fastmath variant
191+
if check_fastmath or check_fastmath_result:
192+
parfor_fastmath_output = fastmath_pcres(*copy_args(*args))
193+
if check_fastmath_result:
194+
np.testing.assert_almost_equal(
195+
parfor_fastmath_output, py_expected, **kwargs
196+
)
197+
198+
return _Wrapper
199+
200+
def _replace_global(func, name, newval):
201+
if name in func.__globals__:
202+
func.__globals__[name] = newval
203+
204+
def _gen_test_func(func):
205+
_replace_global(func, "jit", jit)
206+
_replace_global(func, "njit", njit)
207+
_replace_global(func, "vectorize", vectorize)
208+
209+
def wrapper():
210+
return func()
211+
212+
return wrapper
213+
214+
this_module = sys.modules[__name__]
215+
for tc in testcases:
216+
inst = _wrap_test_class(tc)()
217+
for func_name in dir(tc):
218+
if func_name.startswith("test"):
219+
func = getattr(inst, func_name)
220+
if callable(func):
221+
func = _gen_test_func(func)
222+
if func_name in xfail_tests:
223+
func = pytest.mark.xfail(func)
224+
elif func_name in skip_tests:
225+
func = pytest.mark.skip(func)
226+
227+
setattr(this_module, func_name, func)
228+
229+
230+
_gen_tests()
231+
del _gen_tests

0 commit comments

Comments
 (0)