Skip to content

Commit 675056f

Browse files
committed
[eudsl-python-extras] add arm sme tiled matmul
1 parent c06a2ee commit 675056f

File tree

1 file changed

+210
-0
lines changed

1 file changed

+210
-0
lines changed
Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,210 @@
1+
import mlir.extras.types as T
2+
import numpy as np
3+
from mlir.dialects import builtin
4+
from mlir.dialects.transform import any_op_t
5+
from mlir.dialects.transform.extras import named_sequence, apply_patterns
6+
from mlir.dialects.transform.structured import MatchInterfaceEnum, VectorizeOp
7+
from mlir.dialects.transform.vector import (
8+
VectorContractLowering,
9+
)
10+
from mlir.ir import StringAttr, UnitAttr, Attribute
11+
12+
# you need this to register the memref value caster
13+
# noinspection PyUnresolvedReferences
14+
import mlir.extras.dialects.memref
15+
from mlir.extras.context import RAIIMLIRContext, ExplicitlyManagedModule
16+
from mlir.extras.dialects import linalg
17+
from mlir.extras.dialects import transform, llvm
18+
from mlir.extras.dialects.func import func
19+
from mlir.extras.dialects.transform import (
20+
match,
21+
get_parent_op,
22+
)
23+
from mlir.extras.runtime.passes import Pipeline, run_pipeline
24+
from mlir.extras.runtime.refbackend import LLVMJITBackend
25+
from mlir.extras.util import find_ops
26+
27+
ctx = RAIIMLIRContext()
28+
backend = LLVMJITBackend()
29+
module = ExplicitlyManagedModule()
30+
31+
M, K, N = 7, 13, 7
32+
33+
34+
@func
35+
def matmul_armsme(
36+
A: T.tensor(M, K, T.f32()),
37+
B: T.tensor(K, N, T.f32()),
38+
C: T.tensor(M, N, T.f32()),
39+
):
40+
return linalg.matmul(A, B, C)
41+
42+
43+
@builtin.module(attrs={"transform.target_tag": StringAttr.get("payload")})
44+
def payload():
45+
matmul_armsme.emit(force=True)
46+
47+
48+
# based on https://github.com/llvm/llvm-project/blob/ad656d3a1954dd6157ba689b3003b6fbb97a0833/mlir/test/Integration/Dialect/Linalg/CPU/ArmSME/matmul.mlir
49+
@builtin.module(attrs={"transform.with_named_sequence": UnitAttr.get()})
50+
def mod_transform():
51+
@named_sequence("main", [any_op_t()], [])
52+
def main(module_op: any_op_t()):
53+
# Step 1: Match the linalg.matmul operation
54+
matmul_op = match(module_op, ops=["linalg.matmul"])
55+
56+
# Step 2: Tile for size [4] x [4], which corresponds to SVLs x SVLs
57+
tiled_linalg_op, loops = transform.tile_to_scf_for(
58+
matmul_op, sizes=[[4], [4], 1]
59+
)
60+
61+
# Step 3: Vectorize
62+
VectorizeOp(tiled_linalg_op, vector_sizes=[[4], [4], 1])
63+
64+
# Step 4: Bufferize ahead of TransferReadDropUnitDimsPattern
65+
bufferize = transform.bufferization.one_shot_bufferize(
66+
module_op, bufferize_function_boundaries=True
67+
)
68+
69+
# Step 5: Match func.func operations
70+
func_op = match(bufferize, ops=["func.func"])
71+
72+
# Step 6: Lower vector.multi_reduction to vector.contract (+ some helpful patterns)
73+
@apply_patterns(func_op)
74+
def patterns1():
75+
transform.apply_patterns.vector.lower_masked_transfers()
76+
transform.apply_patterns.vector.transfer_permutation_patterns()
77+
transform.apply_patterns.vector.reduction_to_contract()
78+
transform.apply_patterns.vector.sink_ops()
79+
80+
# Step 7: Lower vector.contract to vector.outerproduct
81+
@apply_patterns(func_op)
82+
def patterns2():
83+
transform.apply_patterns.vector.lower_contraction(
84+
lowering_strategy=VectorContractLowering.OuterProduct
85+
)
86+
transform.apply_patterns.vector.lower_masks()
87+
transform.apply_patterns.vector.rank_reducing_subview_patterns()
88+
transform.apply_patterns.canonicalization()
89+
90+
# # Step 8 (optional optimization): Hoist accumulator load/store
91+
func_h = transform.structured.hoist_redundant_vector_transfers(
92+
any_op_t(), func_op
93+
)
94+
95+
all_loops = match(bufferize, interface=MatchInterfaceEnum.LoopLikeInterface)
96+
97+
transform.apply_licm(all_loops)
98+
transform.loop.hoist_loop_invariant_subsets(all_loops)
99+
100+
101+
module = module.finish()
102+
103+
vectorized_module = run_pipeline(
104+
module,
105+
pipeline=Pipeline()
106+
.transform_interpreter(entry_point="main", debug_payload_root_tag="payload")
107+
.canonicalize()
108+
.cse(),
109+
)
110+
111+
# print(vectorized_module)
112+
113+
kernel_funcs = find_ops(
114+
vectorized_module.operation, lambda o: isinstance(o.opview, llvm.LLVMFuncOp)
115+
)
116+
for k in kernel_funcs:
117+
k.attributes["target_features"] = Attribute.parse(
118+
'#llvm.target_features<["+sme", "+sve"]>'
119+
)
120+
121+
122+
lower_to_llvm = (
123+
Pipeline()
124+
# https://github.com/llvm/llvm-project/blob/9146ef5df0543f08a86686cfeb3bd1ea7338f4c6/mlir/test/lib/Dialect/ArmSME/TestLowerToArmSME.cpp#L45
125+
# Legalize vector operations so they can be converted to ArmSME.
126+
.arm_sme_vector_legalization()
127+
# Sprinkle some cleanups.
128+
.canonicalize()
129+
.cse()
130+
# Passes that convert operations on vectors to ArmSME operations.
131+
# Convert Arith to ArmSME.
132+
.convert_arith_to_arm_sme()
133+
# Convert Vector to ArmSME.
134+
.convert_vector_to_arm_sme()
135+
# Convert operations on high-level vectors to loops.
136+
# Convert ArmSME to SCF.
137+
.convert_arm_sme_to_scf()
138+
# Convert Vector to SCF (with full unroll enabled).
139+
.convert_vector_to_scf(full_unroll=True)
140+
# Enable streaming-mode and ZA.
141+
.Func(
142+
Pipeline().enable_arm_streaming(
143+
streaming_mode="streaming-locally",
144+
za_mode="new-za",
145+
if_required_by_ops=True,
146+
)
147+
)
148+
# Convert SCF to CF (required for ArmSME tile allocation).
149+
.convert_scf_to_cf()
150+
# Convert ArmSME to LLVM.
151+
.Func(Pipeline().convert_arm_sme_to_llvm())
152+
# Sprinkle some cleanups.
153+
.canonicalize()
154+
.cse()
155+
# https://github.com/makslevental/llvm-project/blob/f6643263631bcb0d191ef923963ac1a5ca9ac5fd/mlir/test/lib/Dialect/LLVM/TestLowerToLLVM.cpp#L44
156+
.Func(
157+
Pipeline()
158+
# Blanket-convert any remaining high-level vector ops to loops if any remain.
159+
.convert_vector_to_scf()
160+
# Blanket-convert any remaining linalg ops to loops if any remain.
161+
.convert_linalg_to_loops()
162+
)
163+
# Blanket-convert any remaining affine ops if any remain.
164+
.lower_affine()
165+
# Convert SCF to CF (always needed).
166+
.convert_scf_to_cf()
167+
# Sprinkle some cleanups.
168+
.canonicalize()
169+
.cse()
170+
# Convert vector to LLVM (always needed).
171+
.convert_vector_to_llvm()
172+
# Convert Math to LLVM (always needed).
173+
.Func(Pipeline().convert_math_to_llvm())
174+
# Expand complicated MemRef operations before lowering them.
175+
.expand_strided_metadata()
176+
# The expansion may create affine expressions. Get rid of them.
177+
.lower_affine()
178+
# Convert MemRef to LLVM (always needed).
179+
.finalize_memref_to_llvm()
180+
# Convert Func to LLVM (always needed).
181+
.convert_func_to_llvm()
182+
.convert_arith_to_llvm()
183+
.convert_cf_to_llvm()
184+
# Convert Index to LLVM (always needed).
185+
.convert_index_to_llvm()
186+
# Convert UB to LLVM (always needed).
187+
.convert_ub_to_llvm()
188+
# Convert remaining unrealized_casts (always needed).
189+
.reconcile_unrealized_casts()
190+
)
191+
192+
compiled_module = backend.compile(
193+
find_ops(
194+
vectorized_module.operation,
195+
lambda x: "transform.target_tag" in x.attributes
196+
and x.attributes["transform.target_tag"].value == "payload",
197+
single=True,
198+
),
199+
kernel_name=matmul_armsme.__name__,
200+
pipeline=lower_to_llvm,
201+
)
202+
203+
print(compiled_module)
204+
205+
A = np.random.randint(0, 10, (M, K)).astype(np.float32)
206+
B = np.random.randint(0, 10, (K, N)).astype(np.float32)
207+
C = np.zeros((M, N), dtype=np.float32)
208+
209+
backend.load(compiled_module).matmul_armsme_capi_wrapper(A, B, C)
210+
assert np.allclose(A @ B, C)

0 commit comments

Comments
 (0)