Skip to content

Commit 28063a2

Browse files
committed
[mlir][sparse] refactored python setup of sparse compiler
Reviewed By: bixia Differential Revision: https://reviews.llvm.org/D123419
1 parent b40e901 commit 28063a2

File tree

6 files changed

+46
-39
lines changed

6 files changed

+46
-39
lines changed

mlir/test/Integration/Dialect/SparseTensor/python/test_SDDMM.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from mlir import ir
1010
from mlir import runtime as rt
11-
from mlir import execution_engine
1211

1312
from mlir.dialects import sparse_tensor as st
1413
from mlir.dialects import builtin
@@ -69,17 +68,14 @@ def boilerplate(attr: st.EncodingAttr):
6968
"""
7069

7170

72-
def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, opt: str,
73-
support_lib: str, compiler):
71+
def build_compile_and_run_SDDMMM(attr: st.EncodingAttr, compiler):
7472
# Build.
7573
module = build_SDDMM(attr)
7674
func = str(module.operation.regions[0].blocks[0].operations[0].operation)
7775
module = ir.Module.parse(func + boilerplate(attr))
7876

7977
# Compile.
80-
compiler(module)
81-
engine = execution_engine.ExecutionEngine(
82-
module, opt_level=0, shared_libs=[support_lib])
78+
engine = compiler.compile_and_jit(module)
8379

8480
# Set up numpy input and buffer for output.
8581
a = np.array([[1.1, 2.1, 3.1, 4.1, 5.1, 6.1, 7.1, 8.1],
@@ -156,8 +152,9 @@ def main():
156152
opt = (f'parallelization-strategy={par} '
157153
f'vectorization-strategy={vec} '
158154
f'vl={vl} enable-simd-index32={e}')
159-
compiler = sparse_compiler.SparseCompiler(options=opt)
160-
build_compile_and_run_SDDMMM(attr, opt, support_lib, compiler)
155+
compiler = sparse_compiler.SparseCompiler(
156+
options=opt, opt_level=0, shared_libs=[support_lib])
157+
build_compile_and_run_SDDMMM(attr, compiler)
161158
count = count + 1
162159
# CHECK: Passed 16 tests
163160
print('Passed ', count, 'tests')

mlir/test/Integration/Dialect/SparseTensor/python/test_SpMM.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88

99
from mlir import ir
1010
from mlir import runtime as rt
11-
from mlir import execution_engine
1211

1312
from mlir.dialects import sparse_tensor as st
1413
from mlir.dialects import builtin
@@ -69,17 +68,14 @@ def boilerplate(attr: st.EncodingAttr):
6968
"""
7069

7170

72-
def build_compile_and_run_SpMM(attr: st.EncodingAttr, support_lib: str,
73-
compiler):
71+
def build_compile_and_run_SpMM(attr: st.EncodingAttr, compiler):
7472
# Build.
7573
module = build_SpMM(attr)
7674
func = str(module.operation.regions[0].blocks[0].operations[0].operation)
7775
module = ir.Module.parse(func + boilerplate(attr))
7876

7977
# Compile.
80-
compiler(module)
81-
engine = execution_engine.ExecutionEngine(
82-
module, opt_level=0, shared_libs=[support_lib])
78+
engine = compiler.compile_and_jit(module)
8379

8480
# Set up numpy input and buffer for output.
8581
a = np.array(
@@ -140,13 +136,14 @@ def main():
140136
ir.AffineMap.get_permutation([1, 0])
141137
]
142138
bitwidths = [0]
139+
compiler = sparse_compiler.SparseCompiler(
140+
options=opt, opt_level=0, shared_libs=[support_lib])
143141
for level in levels:
144142
for ordering in orderings:
145143
for pwidth in bitwidths:
146144
for iwidth in bitwidths:
147145
attr = st.EncodingAttr.get(level, ordering, pwidth, iwidth)
148-
compiler = sparse_compiler.SparseCompiler(options=opt)
149-
build_compile_and_run_SpMM(attr, support_lib, compiler)
146+
build_compile_and_run_SpMM(attr, compiler)
150147
count = count + 1
151148
# CHECK: Passed 8 tests
152149
print('Passed ', count, 'tests')

mlir/test/Integration/Dialect/SparseTensor/python/test_elementwise_add_sparse_output.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77

88
from mlir import ir
99
from mlir import runtime as rt
10-
from mlir import execution_engine
1110
from mlir.dialects import sparse_tensor as st
1211
from mlir.dialects import builtin
1312
from mlir.dialects.linalg.opdsl import lang as dsl
@@ -61,10 +60,10 @@
6160

6261
def _run_test(support_lib, kernel):
6362
"""Compiles, runs and checks results."""
63+
compiler = sparse_compiler.SparseCompiler(
64+
options='', opt_level=2, shared_libs=[support_lib])
6465
module = ir.Module.parse(kernel)
65-
sparse_compiler.SparseCompiler(options='')(module)
66-
engine = execution_engine.ExecutionEngine(
67-
module, opt_level=0, shared_libs=[support_lib])
66+
engine = compiler.compile_and_jit(module)
6867

6968
# Set up numpy inputs and buffer for output.
7069
a = np.array(

mlir/test/Integration/Dialect/SparseTensor/python/test_output.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import sys
77
import tempfile
88

9-
from mlir import execution_engine
109
from mlir import ir
1110
from mlir import runtime as rt
1211

@@ -49,13 +48,10 @@ def expected():
4948
"""
5049

5150

52-
def build_compile_and_run_output(attr: st.EncodingAttr, support_lib: str,
53-
compiler):
51+
def build_compile_and_run_output(attr: st.EncodingAttr, compiler):
5452
# Build and Compile.
5553
module = ir.Module.parse(boilerplate(attr))
56-
compiler(module)
57-
engine = execution_engine.ExecutionEngine(
58-
module, opt_level=0, shared_libs=[support_lib])
54+
engine = compiler.compile_and_jit(module)
5955

6056
# Invoke the kernel and compare output.
6157
with tempfile.TemporaryDirectory() as test_dir:
@@ -88,12 +84,13 @@ def main():
8884
ir.AffineMap.get_permutation([1, 0])
8985
]
9086
bitwidths = [8, 16, 32, 64]
87+
compiler = sparse_compiler.SparseCompiler(
88+
options='', opt_level=2, shared_libs=[support_lib])
9189
for level in levels:
9290
for ordering in orderings:
9391
for bwidth in bitwidths:
9492
attr = st.EncodingAttr.get(level, ordering, bwidth, bwidth)
95-
compiler = sparse_compiler.SparseCompiler(options='')
96-
build_compile_and_run_output(attr, support_lib, compiler)
93+
build_compile_and_run_output(attr, compiler)
9794
count = count + 1
9895

9996
# CHECK: Passed 16 tests

mlir/test/Integration/Dialect/SparseTensor/python/test_stress.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313

1414
from mlir import ir
1515
from mlir import runtime as rt
16-
from mlir.execution_engine import ExecutionEngine
1716

1817
from mlir.dialects import builtin
1918
from mlir.dialects import func
@@ -139,15 +138,13 @@ def writeTo(self, filename):
139138
f.write(str(self._module))
140139
return self
141140

142-
def compile(self, compiler, support_lib: str):
141+
def compile(self, compiler):
143142
"""Compile the ir.Module."""
144143
assert self._module is not None, \
145144
'StressTest: must call build() before compile()'
146145
assert self._engine is None, \
147146
'StressTest: must not call compile() repeatedly'
148-
compiler(self._module)
149-
self._engine = ExecutionEngine(
150-
self._module, opt_level=0, shared_libs=[support_lib])
147+
self._engine = compiler.compile_and_jit(self._module)
151148
return self
152149

153150
def run(self, np_arg0: np.ndarray) -> np.ndarray:
@@ -194,7 +191,8 @@ def main():
194191
f'vectorization-strategy={vec} '
195192
f'vl={vl} '
196193
f'enable-simd-index32={e}')
197-
compiler = sparse_compiler.SparseCompiler(options=sparsification_options)
194+
compiler = sparse_compiler.SparseCompiler(
195+
options=sparsification_options, opt_level=0, shared_libs=[support_lib])
198196
f64 = ir.F64Type.get()
199197
# Be careful about increasing this because
200198
# len(types) = 1 + 2^rank * rank! * len(bitwidths)^2
@@ -230,9 +228,8 @@ def main():
230228
np_arg0 = np.arange(size, dtype=tyconv.irtype_to_dtype(f64)).reshape(*shape)
231229
np_out = (
232230
StressTest(tyconv).build(types).writeTo(
233-
sys.argv[1] if len(sys.argv) > 1 else None).compile(
234-
compiler, support_lib).writeTo(
235-
sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
231+
sys.argv[1] if len(sys.argv) > 1 else None).compile(compiler)
232+
.writeTo(sys.argv[2] if len(sys.argv) > 2 else None).run(np_arg0))
236233
# CHECK: Passed
237234
if np.allclose(np_out, np_arg0):
238235
print('Passed')

mlir/test/Integration/Dialect/SparseTensor/python/tools/sparse_compiler.py

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,35 @@
55
# This file contains the sparse compiler class.
66

77
from mlir import all_passes_registration
8+
from mlir import execution_engine
89
from mlir import ir
910
from mlir import passmanager
11+
from typing import Sequence
1012

1113
class SparseCompiler:
12-
"""Sparse compiler definition."""
14+
"""Sparse compiler class for compiling and building MLIR modules."""
1315

14-
def __init__(self, options: str):
16+
def __init__(self, options: str, opt_level: int, shared_libs: Sequence[str]):
1517
pipeline = f'sparse-compiler{{{options} reassociate-fp-reductions=1 enable-index-optimizations=1}}'
1618
self.pipeline = pipeline
19+
self.opt_level = opt_level
20+
self.shared_libs = shared_libs
1721

1822
def __call__(self, module: ir.Module):
23+
"""Convenience application method."""
24+
self.compile(module)
25+
26+
def compile(self, module: ir.Module):
27+
"""Compiles the module by invoking the sparse copmiler pipeline."""
1928
passmanager.PassManager.parse(self.pipeline).run(module)
29+
30+
def jit(self, module: ir.Module) -> execution_engine.ExecutionEngine:
31+
"""Wraps the module in a JIT execution engine."""
32+
return execution_engine.ExecutionEngine(
33+
module, opt_level=self.opt_level, shared_libs=self.shared_libs)
34+
35+
def compile_and_jit(self,
36+
module: ir.Module) -> execution_engine.ExecutionEngine:
37+
"""Compiles and jits the module."""
38+
self.compile(module)
39+
return self.jit(module)

0 commit comments

Comments
 (0)