Skip to content

Commit a4306d4

Browse files
committed
Add --pack-block-factors cmd arg to control -pack-matmul options
1 parent df10928 commit a4306d4

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

python/mlir/tpp/sched/bundles.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from typing import Optional, Sequence
1+
from typing import Optional, Sequence, Union
22

3+
from mlir import ir
34
from mlir.dialects import transform
45
from .common import apply_registered_pass, match, select
56
from .utils import GpuBackend, PipelineInterrupt
@@ -21,7 +22,14 @@ def cleanup(op, **_config):
2122

2223

2324
# TODO: make bundle into a NamedSequence to call with IncludeOp
24-
def tpp_mapping(mod, lower_pack_unpack_without_transpose: bool = False, **_config):
25+
def tpp_mapping(
26+
mod,
27+
lower_pack_unpack_without_transpose: bool = False,
28+
pack_block_factors: Optional[
29+
Sequence[Union[Sequence[Union[int, ir.IntegerAttr]], int, ir.IntegerAttr]]
30+
] = None,
31+
**_config,
32+
):
2533
"High-level transforms that map operations to TPP-compatible forms."
2634

2735
# Preprocess convolutions.
@@ -33,12 +41,14 @@ def tpp_mapping(mod, lower_pack_unpack_without_transpose: bool = False, **_confi
3341
func = apply_registered_pass(func, "pack-conv2DNchwFchw")
3442
func = apply_registered_pass(func, "pack-conv2DNhwcHwcf")
3543
func = apply_registered_pass(func, "rewrite-conv-to-matmul-or-brgemm")
36-
m = select("m", [2, 4, 8])
37-
n = select("n", [4, 8, 16])
38-
k = select("k", [2, 4, 8, 16])
39-
func = apply_registered_pass(
40-
func, "pack-matmul", options={"block-factors": [m, n, k]}
41-
)
44+
options = None
45+
if pack_block_factors:
46+
m_vals, n_vals, k_vals = pack_block_factors
47+
m = select("m", m_vals if isinstance(m_vals, Sequence) else [m_vals])
48+
n = select("n", n_vals if isinstance(n_vals, Sequence) else [n_vals])
49+
k = select("k", k_vals if isinstance(k_vals, Sequence) else [k_vals])
50+
options = {"block-factors": [m, n, k]}
51+
func = apply_registered_pass(func, "pack-matmul", options=options)
4252
apply_registered_pass(func, "pack-vnni")
4353
if lower_pack_unpack_without_transpose:
4454
mod = apply_registered_pass(mod, "lower-packs-unpacks-without-transpose")

python/mlir/tpp/sched/main.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,15 @@ def comma_separated_ints(arg: str):
3232
"--payload", type=str, help="payload file to print with schedule"
3333
)
3434

35+
def block_factors(arg: str):
36+
m, n, k = arg.split(",")
37+
38+
convert = lambda dim: list(map(int, dim))
39+
40+
return convert(m.split(";")), convert(n.split(";")), convert(k.split(";"))
41+
42+
parser.add_argument("--pack-block-factors", type=block_factors, default=None)
43+
3544
parser.add_argument("--split-input-file", action="store_true")
3645

3746
def comma_separated_bundles(arg: str):

0 commit comments

Comments
 (0)