Commit 36413a6

[eudsl-python-extras] add arm sme tiled matmul
1 parent c06a2ee

File tree: 4 files changed (+226, -0 lines)

.github/workflows/build_test_release_eudsl_python_extras.yml

Lines changed: 4 additions & 0 deletions

@@ -186,6 +186,10 @@ jobs:
           python projects/eudsl-python-extras/examples/mwe.py
           python projects/eudsl-python-extras/examples/rdna_matmul_opt.py

+          if [[ "${{ matrix.name }}" == "ubuntu_aarch64" ]] || [[ "${{ matrix.name }}" == "macos_arm64" ]]; then
+            python projects/eudsl-python-extras/examples/tiled_arm_matmul.py
+          fi
+
           if [[ $(python -c "print(__import__('sys').version_info >= (3, 13))") == "True" ]]; then
             python projects/eudsl-python-extras/examples/cuda_matmul_opt.py
           fi
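
Outside CI, one might mirror this guard before invoking the example. The following is only a sketch (the platform.machine() check and subprocess call are my own, not part of this commit), and note that the example additionally requires SME support, which a bare architecture check does not verify:

# Sketch: run the SME example only on an ARM64 host (assumed helper, not part of the commit).
import platform
import subprocess
import sys

if platform.machine().lower() in ("arm64", "aarch64"):
    # Same script the CI step invokes on the ubuntu_aarch64 / macos_arm64 runners.
    subprocess.run(
        [sys.executable, "projects/eudsl-python-extras/examples/tiled_arm_matmul.py"],
        check=True,
    )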

projects/eudsl-python-extras/examples/tiled_arm_matmul.py

Lines changed: 212 additions & 0 deletions

@@ -0,0 +1,212 @@
# NB: this only works on aarch64/arm64 which supports SME

import mlir.extras.types as T
import numpy as np
from mlir.dialects import builtin
from mlir.dialects.transform import any_op_t
from mlir.dialects.transform.extras import named_sequence, apply_patterns
from mlir.dialects.transform.structured import MatchInterfaceEnum, VectorizeOp
from mlir.dialects.transform.vector import (
    VectorContractLowering,
)
from mlir.ir import StringAttr, UnitAttr, Attribute

# you need this to register the memref value caster
# noinspection PyUnresolvedReferences
import mlir.extras.dialects.memref
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
from mlir.extras.dialects import linalg
from mlir.extras.dialects import transform, llvm
from mlir.extras.dialects.func import func
from mlir.extras.dialects.transform import (
    match,
    get_parent_op,
)
from mlir.extras.runtime.passes import Pipeline, run_pipeline
from mlir.extras.runtime.refbackend import LLVMJITBackend
from mlir.extras.util import find_ops

ctx = RAIIMLIRContext()
backend = LLVMJITBackend()
module = ExplicitlyManagedModule()

M, K, N = 7, 13, 7


@func
def matmul_armsme(
    A: T.tensor(M, K, T.f32()),
    B: T.tensor(K, N, T.f32()),
    C: T.tensor(M, N, T.f32()),
):
    return linalg.matmul(A, B, C)


@builtin.module(attrs={"transform.target_tag": StringAttr.get("payload")})
def payload():
    matmul_armsme.emit(force=True)


# based on https://github.com/llvm/llvm-project/blob/ad656d3a1954dd6157ba689b3003b6fbb97a0833/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()})
def mod_transform():
    @named_sequence("main", [any_op_t()], [])
    def main(module_op: any_op_t()):
        # Step 1: Match the linalg.matmul operation
        matmul_op = match(module_op, ops=["linalg.matmul"])

        # Step 2: Tile for size [4] x [4], which corresponds to SVLs x SVLs
        tiled_linalg_op, loops = transform.tile_to_scf_for(
            matmul_op, sizes=[[4], [4], 1]
        )

        # Step 3: Vectorize
        VectorizeOp(tiled_linalg_op, vector_sizes=[[4], [4], 1])

        # Step 4: Bufferize ahead of TransferReadDropUnitDimsPattern
        bufferize = transform.bufferization.one_shot_bufferize(
            module_op, bufferize_function_boundaries=True
        )

        # Step 5: Match func.func operations
        func_op = match(bufferize, ops=["func.func"])

        # Step 6: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
        @apply_patterns(func_op)
        def patterns1():
            transform.apply_patterns.vector.lower_masked_transfers()
            transform.apply_patterns.vector.transfer_permutation_patterns()
            transform.apply_patterns.vector.reduction_to_contract()
            transform.apply_patterns.vector.sink_ops()

        # Step 7: Lower vector.contract to vector.outerproduct
        @apply_patterns(func_op)
        def patterns2():
            transform.apply_patterns.vector.lower_contraction(
                lowering_strategy=VectorContractLowering.OuterProduct
            )
            transform.apply_patterns.vector.lower_masks()
            transform.apply_patterns.vector.rank_reducing_subview_patterns()
            transform.apply_patterns.canonicalization()

        # Step 8 (optional optimization): Hoist accumulator load/store
        func_h = transform.structured.hoist_redundant_vector_transfers(
            any_op_t(), func_op
        )

        all_loops = match(bufferize, interface=MatchInterfaceEnum.LoopLikeInterface)

        transform.apply_licm(all_loops)
        transform.loop.hoist_loop_invariant_subsets(all_loops)


module = module.finish()

vectorized_module = run_pipeline(
    module,
    pipeline=Pipeline()
    .transform_interpreter(entry_point="main", debug_payload_root_tag="payload")
    .canonicalize()
    .cse(),
)

# print(vectorized_module)

kernel_funcs = find_ops(
    vectorized_module.operation, lambda o: isinstance(o.opview, llvm.LLVMFuncOp)
)
for k in kernel_funcs:
    k.attributes["target_features"] = Attribute.parse(
        '#llvm.target_features<["+sme", "+sve"]>'
    )


lower_to_llvm = (
    Pipeline()
    # https://github.com/llvm/llvm-project/blob/9146ef5df0543f08a86686cfeb3bd1ea7338f4c6/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp#L45
    # Legalize vector operations so they can be converted to ArmSME.
    .arm_sme_vector_legalization()
    # Sprinkle some cleanups.
    .canonicalize()
    .cse()
    # Passes that convert operations on vectors to ArmSME operations.
    # Convert Arith to ArmSME.
    .convert_arith_to_arm_sme()
    # Convert Vector to ArmSME.
    .convert_vector_to_arm_sme()
    # Convert operations on high-level vectors to loops.
    # Convert ArmSME to SCF.
    .convert_arm_sme_to_scf()
    # Convert Vector to SCF (with full unroll enabled).
    .convert_vector_to_scf(full_unroll=True)
    # Enable streaming-mode and ZA.
    .Func(
        Pipeline().enable_arm_streaming(
            streaming_mode="streaming-locally",
            za_mode="new-za",
            if_required_by_ops=True,
        )
    )
    # Convert SCF to CF (required for ArmSME tile allocation).
    .convert_scf_to_cf()
    # Convert ArmSME to LLVM.
    .Func(Pipeline().convert_arm_sme_to_llvm())
    # Sprinkle some cleanups.
    .canonicalize()
    .cse()
    # https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44
    .Func(
        Pipeline()
        # Blanket-convert any remaining high-level vector ops to loops if any remain.
        .convert_vector_to_scf()
        # Blanket-convert any remaining linalg ops to loops if any remain.
        .convert_linalg_to_loops()
    )
    # Blanket-convert any remaining affine ops if any remain.
    .lower_affine()
    # Convert SCF to CF (always needed).
    .convert_scf_to_cf()
    # Sprinkle some cleanups.
    .canonicalize()
    .cse()
    # Convert vector to LLVM (always needed).
    .convert_vector_to_llvm()
    # Convert Math to LLVM (always needed).
    .Func(Pipeline().convert_math_to_llvm())
    # Expand complicated MemRef operations before lowering them.
    .expand_strided_metadata()
    # The expansion may create affine expressions. Get rid of them.
    .lower_affine()
    # Convert MemRef to LLVM (always needed).
    .finalize_memref_to_llvm()
    # Convert Func to LLVM (always needed).
    .convert_func_to_llvm()
    .convert_arith_to_llvm()
    .convert_cf_to_llvm()
    # Convert Index to LLVM (always needed).
    .convert_index_to_llvm()
    # Convert UB to LLVM (always needed).
    .convert_ub_to_llvm()
    # Convert remaining unrealized_casts (always needed).
    .reconcile_unrealized_casts()
)

compiled_module = backend.compile(
    find_ops(
        vectorized_module.operation,
        lambda x: "transform.target_tag" in x.attributes
        and x.attributes["transform.target_tag"].value == "payload",
        single=True,
    ),
    kernel_name=matmul_armsme.__name__,
    pipeline=lower_to_llvm,
)

# print(compiled_module)

A = np.random.randint(0, 10, (M, K)).astype(np.float32)
B = np.random.randint(0, 10, (K, N)).astype(np.float32)
C = np.zeros((M, N), dtype=np.float32)

backend.load(compiled_module).matmul_armsme_capi_wrapper(A, B, C)
assert np.allclose(A @ B, C)
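
As an aside, the effect of Step 2 (tiling with sizes=[[4], [4], 1]) and Step 7 (lowering contractions to outer products) can be illustrated with a plain-NumPy analogue. This sketch is illustration only and not part of the commit; the fixed tile size 4 stands in for the scalable [4] sizes, which at runtime scale with the streaming vector length (SVL).

# Illustration only: NumPy analogue of tiling sizes=[[4], [4], 1] followed by
# outer-product accumulation; the real kernel uses scalable SME tiles.
import numpy as np


def tiled_matmul_reference(A, B, tile=4):
    M, K = A.shape
    _, N = B.shape
    C = np.zeros((M, N), dtype=A.dtype)
    for i in range(0, M, tile):          # tile the M dimension
        for j in range(0, N, tile):      # tile the N dimension
            for k in range(K):           # reduction dimension tiled by 1
                # One rank-1 (outer-product) update per k, as after Step 7.
                C[i:i + tile, j:j + tile] += np.outer(
                    A[i:i + tile, k], B[k, j:j + tile]
                )
    return C


A = np.random.randint(0, 10, (7, 13)).astype(np.float32)
B = np.random.randint(0, 10, (13, 7)).astype(np.float32)
assert np.allclose(tiled_matmul_reference(A, B), A @ B)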

projects/eudsl-python-extras/mlir/extras/runtime/refbackend.py

Lines changed: 2 additions & 0 deletions

@@ -58,6 +58,8 @@ def _try_find_runtime_libraries(local_vars: dict):
         "c_runner_utils",
         "runner_utils",
         "cuda_runtime",
+        "arm_sme_abi_stubs",
+        "arm_runner_utils"
     }
     # TODO(max): for some reason adding cuda runtime lib to execengine
     # causes a segfault (or something)
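
For context, these stems conventionally correspond to MLIR's shared runtime libraries (e.g. libmlir_arm_sme_abi_stubs.so / .dylib), which the JIT execution engine loads at run time. A hypothetical helper that resolves such stems to files on disk might look like the following; the function name, directory argument, and naming scheme are assumptions for illustration, not this project's actual implementation:

# Hypothetical sketch, not the project's implementation: map runtime-library
# name stems to the libmlir_<stem>.so / .dylib files that MLIR ships.
import platform
from pathlib import Path


def runtime_lib_candidates(lib_dir, stems=("arm_sme_abi_stubs", "arm_runner_utils")):
    suffix = ".dylib" if platform.system() == "Darwin" else ".so"
    found = []
    for stem in stems:
        candidate = Path(lib_dir) / f"libmlir_{stem}{suffix}"
        if candidate.exists():
            found.append(str(candidate))
    return found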

projects/mlir-python-bindings/CMakeLists.txt

Lines changed: 8 additions & 0 deletions

@@ -142,6 +142,14 @@ if (TARGET mlir_cuda_runtime)
   list(APPEND _runtimes mlir_cuda_runtime)
 endif()

+if (TARGET mlir_arm_sme_abi_stubs)
+  list(APPEND _runtimes mlir_arm_sme_abi_stubs)
+endif()
+
+if (TARGET mlir_arm_runner_utils)
+  list(APPEND _runtimes mlir_arm_runner_utils)
+endif()
+
 if (TARGET omp)
   list(APPEND _runtimes omp)
 endif()
