Skip to content

Commit 1f35e17

Browse files
authored
[python] Improve tests (#198)
1 parent aba6492 commit 1f35e17

File tree

3 files changed

+16
-11
lines changed

3 files changed

+16
-11
lines changed

numba_dpcomp/numba_dpcomp/mlir/kernel_sim.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ def barrier_proxy(flags):
8585
global _greenlet_found
8686
assert _greenlet_found, "greenlet package not installed"
8787
state = get_exec_state()
88+
assert len(state.tasks) > 0
8889
wg_size = state.wg_size[0]
8990
assert wg_size > 0
9091
if wg_size > 1:
@@ -175,6 +176,9 @@ def wrapper():
175176
return wrapper
176177

177178

179+
_barrier_ops = ["barrier"]
180+
181+
178182
def _execute_kernel(global_size, local_size, func, *args):
179183
if len(local_size) == 0:
180184
local_size = (1,) * len(global_size)
@@ -184,6 +188,7 @@ def _execute_kernel(global_size, local_size, func, *args):
184188
state = _setup_execution_state(global_size, local_size)
185189
try:
186190
groups = tuple((g + l - 1) // l for g, l in zip(global_size, local_size))
191+
need_barrier = any(n in func.__globals__ for n in _barrier_ops)
187192
for gid in product(*(range(g) for g in groups)):
188193
offset = tuple(g * l for g, l in zip(gid, local_size))
189194
size = tuple(
@@ -195,8 +200,9 @@ def _execute_kernel(global_size, local_size, func, *args):
195200

196201
indices_range = (range(o, o + s) for o, s in zip(offset, size))
197202

198-
global _greenlet_found
199-
if _greenlet_found:
203+
if need_barrier:
204+
global _greenlet_found
205+
assert _greenlet_found
200206
tasks = state.tasks
201207
assert len(tasks) == 0
202208
for indices in product(*indices_range):

numba_dpcomp/numba_dpcomp/mlir/tests/test_gpu.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -688,12 +688,11 @@ def func(c):
688688

689689

690690
@require_gpu
691-
def test_barrier1():
691+
@pytest.mark.parametrize("global_size", [1, 2, 27])
692+
@pytest.mark.parametrize("local_size", [1, 2, 7])
693+
def test_barrier1(global_size, local_size):
692694
atomic_add = atomic.add
693695

694-
global_size = 27
695-
local_size = 7
696-
697696
def func(a, b):
698697
i = get_global_id(0)
699698
off = i // local_size

numba_dpcomp/numba_dpcomp/mlir/tests/test_numba_parfor.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -147,19 +147,19 @@ def _gen_tests():
147147
skip_tests = {}
148148

149149
def countParfors(test_func, args, **kws):
150-
pytest.xfail()
150+
pytest.skip()
151151

152152
def countArrays(test_func, args, **kws):
153-
pytest.xfail()
153+
pytest.skip()
154154

155155
def countArrayAllocs(test_func, args, **kws):
156-
pytest.xfail()
156+
pytest.skip()
157157

158158
def countNonParforArrayAccesses(test_func, args, **kws):
159-
pytest.xfail()
159+
pytest.skip()
160160

161161
def get_optimized_numba_ir(test_func, args, **kws):
162-
pytest.xfail()
162+
pytest.skip()
163163

164164
def _wrap_test_class(test_base):
165165
class _Wrapper(test_base):

0 commit comments

Comments
 (0)