Skip to content

Commit 7315aef

Browse files
authored
Merge pull request #1194 from Hardcode84/numba-mlir-integr2
Numba-mlir integration
2 parents 35890ba + ac4cb5d commit 7315aef

File tree

7 files changed

+115
-15
lines changed

7 files changed

+115
-15
lines changed

.github/workflows/conda-package.yml

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,8 +104,9 @@ jobs:
104104
python: ['3.9', '3.10', '3.11']
105105
os: [ubuntu-20.04, ubuntu-latest, windows-latest]
106106
experimental: [false]
107+
use_mlir: [false]
107108

108-
continue-on-error: ${{ matrix.experimental }}
109+
continue-on-error: ${{ matrix.experimental || matrix.use_mlir }}
109110

110111
steps:
111112
- name: Setup miniconda
@@ -169,6 +170,10 @@ jobs:
169170
- name: Install builded package
170171
run: mamba install ${{ env.PACKAGE_NAME }}=${{ env.PACKAGE_VERSION }} intel::intel-opencl-rt pytest -c ${{ env.CHANNEL_PATH }}
171172

173+
- name: Install numba-mlir
174+
if: matrix.use_mlir
175+
run: mamba install numba-mlir -c dppy/label/dev -c conda-forge -c intel
176+
172177
- name: Setup OpenCL CPU device
173178
if: runner.os == 'Windows'
174179
shell: pwsh
@@ -184,9 +189,13 @@ jobs:
184189
python -c "import dpcpp_llvm_spirv as p; print(p.get_llvm_spirv_path())"
185190
186191
- name: Smoke test
192+
env:
193+
NUMBA_DPEX_USE_MLIR: ${{ matrix.use_mlir && '1' || '0' }}
187194
run: python -c "import dpnp, dpctl, numba_dpex; dpctl.lsplatform()"
188195

189196
- name: Run tests
197+
env:
198+
NUMBA_DPEX_USE_MLIR: ${{ matrix.use_mlir && '1' || '0' }}
190199
run: |
191200
pytest -q -ra --disable-warnings --pyargs ${{ env.MODULE_NAME }} -vv
192201

numba_dpex/config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,3 +96,5 @@ def __getattr__(name):
9696
DPEX_OPT = _readenv("NUMBA_DPEX_OPT", int, 2)
9797

9898
INLINE_THRESHOLD = _readenv("NUMBA_DPEX_INLINE_THRESHOLD", int, None)
99+
100+
USE_MLIR = _readenv("NUMBA_DPEX_USE_MLIR", int, 0)

numba_dpex/core/descriptor.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,12 +39,14 @@ class DpexTargetOptions(CPUTargetOptions):
3939
experimental = _option_mapping("experimental")
4040
release_gil = _option_mapping("release_gil")
4141
no_compile = _option_mapping("no_compile")
42+
use_mlir = _option_mapping("use_mlir")
4243

4344
def finalize(self, flags, options):
4445
super().finalize(flags, options)
4546
_inherit_if_not_set(flags, options, "experimental", False)
4647
_inherit_if_not_set(flags, options, "release_gil", False)
4748
_inherit_if_not_set(flags, options, "no_compile", True)
49+
_inherit_if_not_set(flags, options, "use_mlir", False)
4850

4951

5052
class DpexKernelTarget(TargetDescriptor):

numba_dpex/core/pipelines/dpjit_compiler.py

Lines changed: 35 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ class _DpjitPassBuilder(object):
3737
execution.
3838
"""
3939

40+
_use_mlir = False
41+
4042
@staticmethod
4143
def define_typed_pipeline(state, name="dpex_dpjit_typed"):
4244
"""Returns the typed part of the nopython pipeline"""
@@ -55,19 +57,31 @@ def define_typed_pipeline(state, name="dpex_dpjit_typed"):
5557
pm.add_pass(NopythonRewrites, "nopython rewrites")
5658
pm.add_pass(ParforPass, "convert to parfors")
5759
pm.add_pass(
58-
ParforLegalizeCFDPass, "Legalize parfors for compute follows data"
60+
ParforLegalizeCFDPass,
61+
"Legalize parfors for compute follows data",
5962
)
6063
pm.add_pass(ParforFusionPass, "fuse parfors")
6164
pm.add_pass(ParforPreLoweringPass, "parfor prelowering")
6265

6366
pm.finalize()
6467
return pm
6568

66-
@staticmethod
67-
def define_nopython_lowering_pipeline(state, name="dpex_dpjit_lowering"):
69+
@classmethod
70+
def define_nopython_lowering_pipeline(
71+
cls, state, name="dpex_dpjit_lowering"
72+
):
6873
"""Returns an nopython mode pipeline based PassManager"""
6974
pm = PassManager(name)
7075

76+
flags = state.flags
77+
if cls._use_mlir or hasattr(flags, "use_mlir") and flags.use_mlir:
78+
from numba_mlir.mlir.passes import MlirReplaceParfors
79+
80+
pm.add_pass(
81+
MlirReplaceParfors,
82+
"Lower parfor using MLIR pipeline",
83+
)
84+
7185
# legalize
7286
pm.add_pass(
7387
NoPythonSupportedFeatureValidation,
@@ -85,11 +99,11 @@ def define_nopython_lowering_pipeline(state, name="dpex_dpjit_lowering"):
8599
pm.finalize()
86100
return pm
87101

88-
@staticmethod
89-
def define_nopython_pipeline(state, name="dpex_dpjit_nopython"):
102+
@classmethod
103+
def define_nopython_pipeline(cls, state, name="dpex_dpjit_nopython"):
90104
"""Returns an nopython mode pipeline based PassManager"""
91105
# compose pipeline from untyped, typed and lowering parts
92-
dpb = _DpjitPassBuilder
106+
dpb = cls
93107
pm = PassManager(name)
94108
untyped_passes = DefaultPassBuilder.define_untyped_pipeline(state)
95109
pm.passes.extend(untyped_passes.passes)
@@ -104,17 +118,31 @@ def define_nopython_pipeline(state, name="dpex_dpjit_nopython"):
104118
return pm
105119

106120

121+
class _DpjitPassBuilderMlir(_DpjitPassBuilder):
122+
_use_mlir = True
123+
124+
107125
class DpjitCompiler(CompilerBase):
108126
"""Dpex's compiler pipeline to offload parfor nodes into SYCL kernels."""
109127

128+
_pass_builder = _DpjitPassBuilder
129+
110130
def define_pipelines(self):
111131
pms = []
112132
self.state.parfor_diagnostics = ExtendedParforDiagnostics()
113133
self.state.metadata[
114134
"parfor_diagnostics"
115135
] = self.state.parfor_diagnostics
116136
if not self.state.flags.force_pyobject:
117-
pms.append(_DpjitPassBuilder.define_nopython_pipeline(self.state))
137+
pms.append(self._pass_builder.define_nopython_pipeline(self.state))
118138
if self.state.status.can_fallback or self.state.flags.force_pyobject:
119139
raise UnsupportedCompilationModeError()
120140
return pms
141+
142+
143+
class DpjitCompilerMlir(DpjitCompiler):
144+
_pass_builder = _DpjitPassBuilderMlir
145+
146+
147+
def get_compiler(use_mlir):
148+
return DpjitCompilerMlir if use_mlir else DpjitCompiler

numba_dpex/decorators.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313
compile_func,
1414
compile_func_template,
1515
)
16-
from numba_dpex.core.pipelines.dpjit_compiler import DpjitCompiler
16+
from numba_dpex.core.pipelines.dpjit_compiler import get_compiler
17+
18+
from .config import USE_MLIR
1719

1820

1921
def kernel(
@@ -152,9 +154,12 @@ def dpjit(*args, **kws):
152154
"pipeline class is set for dpjit and is ignored", RuntimeWarning
153155
)
154156
del kws["forceobj"]
157+
158+
use_mlir = kws.pop("use_mlir", bool(USE_MLIR))
159+
155160
kws.update({"nopython": True})
156161
kws.update({"parallel": True})
157-
kws.update({"pipeline_class": DpjitCompiler})
162+
kws.update({"pipeline_class": get_compiler(use_mlir)})
158163

159164
# FIXME: When trying to use dpex's target context, overloads do not work
160165
# properly. We will turn on dpex target once the issue is fixed.

numba_dpex/tests/_helper.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,23 @@
66

77
import contextlib
88
import shutil
9+
from functools import cache
910

1011
import dpctl
1112
import dpnp
1213
import pytest
1314

14-
from numba_dpex import config, numba_sem_version
15+
from numba_dpex import config, dpjit, numba_sem_version
16+
17+
18+
@cache
19+
def has_numba_mlir():
20+
try:
21+
import numba_mlir
22+
except ImportError:
23+
return False
24+
25+
return True
1526

1627

1728
def has_opencl_gpu():
@@ -89,6 +100,10 @@ def is_windows():
89100
not has_level_zero(),
90101
reason="No level-zero GPU platforms available",
91102
)
103+
skip_no_numba_mlir = pytest.mark.skipif(
104+
not has_numba_mlir(),
105+
reason="numba-mlir package is not availabe",
106+
)
92107

93108
filter_strings = [
94109
pytest.param("level_zero:gpu:0", marks=skip_no_level_zero_gpu),
@@ -123,6 +138,14 @@ def is_windows():
123138
)
124139

125140

141+
decorators = [
142+
pytest.param(dpjit, id="dpjit"),
143+
pytest.param(
144+
dpjit(use_mlir=True), id="dpjit_mlir", marks=skip_no_numba_mlir
145+
),
146+
]
147+
148+
126149
@contextlib.contextmanager
127150
def override_config(name, value, config=config):
128151
"""

numba_dpex/tests/test_prange.py

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,12 @@
1212

1313
from numba_dpex import dpjit, prange
1414

15+
from ._helper import decorators
1516

16-
def test_one_prange_mul():
17-
@dpjit
17+
18+
@pytest.mark.parametrize("jit", decorators)
19+
def test_one_prange_mul(jit):
20+
@jit
1821
def f(a, b):
1922
for i in prange(4):
2023
b[i, 0] = a[i, 0] * 10
@@ -35,6 +38,33 @@ def f(a, b):
3538
assert nb[i, 0] == na[i, 0] * 10
3639

3740

41+
@pytest.mark.parametrize("jit", decorators)
42+
def test_one_prange_mul_nested(jit):
43+
@jit
44+
def f_inner(a, b):
45+
for i in prange(4):
46+
b[i, 0] = a[i, 0] * 10
47+
return
48+
49+
@jit
50+
def f(a, b):
51+
return f_inner(a, b)
52+
53+
device = dpctl.select_default_device()
54+
55+
m = 8
56+
n = 8
57+
a = dpnp.ones((m, n), device=device)
58+
b = dpnp.ones((m, n), device=device)
59+
60+
f(a, b)
61+
na = dpnp.asnumpy(a)
62+
nb = dpnp.asnumpy(b)
63+
64+
for i in range(4):
65+
assert nb[i, 0] == na[i, 0] * 10
66+
67+
3868
@pytest.mark.skip(reason="dpnp.add() doesn't support variable + scalar.")
3969
def test_one_prange_add_scalar():
4070
@dpjit
@@ -155,8 +185,9 @@ def f(a, b):
155185
assert np.all(b.asnumpy() == 12)
156186

157187

158-
def test_two_consecutive_prange():
159-
@dpjit
188+
@pytest.mark.parametrize("jit", decorators)
189+
def test_two_consecutive_prange(jit):
190+
@jit
160191
def prange_example(a, b, c, d):
161192
for i in prange(n):
162193
c[i] = a[i] + b[i]

0 commit comments

Comments
 (0)