Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions python/mlir/dialects/transform/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,16 @@

register_dialect_extension(get_dialect_registry())

from ...ir import ArrayAttr, SymbolRefAttr, Attribute, Type
from ...ir import (
ArrayAttr,
SymbolRefAttr,
Attribute,
Type,
StringAttr,
IntegerAttr,
IntegerType,
BoolAttr,
)
from .._tune_transform_ops_gen import TuneSelectOp

from collections.abc import Sequence
Expand All @@ -13,13 +22,27 @@
def select(
selected: Type, # transform.any_param or transform.param<...>
name: Union[str, Attribute],
options: Union[ArrayAttr, Sequence[Attribute]],
options: Union[ArrayAttr, Sequence[Union[Attribute, str, int, bool]]],
loc=None,
ip=None,
) -> TuneSelectOp:
if isinstance(name, str):
name = SymbolRefAttr.get([name])

if not isinstance(options, ArrayAttr):
option_attrs = []
for option in options:
if isinstance(option, str):
option_attrs.append(StringAttr.get(option))
elif isinstance(option, int):
int_type = IntegerType.get_signless(64)
option_attrs.append(IntegerAttr.get(int_type, option))
elif isinstance(option, bool):
option_attrs.append(BoolAttr.get(option))
elif isinstance(option, Attribute):
option_attrs.append(option)
options = ArrayAttr.get(option_attrs)

return TuneSelectOp(
selected=selected,
name=name,
Expand Down
2 changes: 1 addition & 1 deletion python/mlir/dialects/tune.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .._mlir_libs import get_dialect_registry
from .._mlir_libs._tppDialects.tune import register_dialect
from .._mlir_libs._tppDialects.tune import *

register_dialect(get_dialect_registry())
10 changes: 7 additions & 3 deletions python/mlir/tpp/sched/bundles.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import Optional, Sequence

from mlir import ir
from mlir.dialects import transform
from .common import apply_registered_pass, match
from .common import apply_registered_pass, match, select
from .utils import GpuBackend, PipelineInterrupt

from ..xsmm import utils as xsmm_utils
Expand Down Expand Up @@ -34,7 +33,12 @@ def tpp_mapping(mod, lower_pack_unpack_without_transpose: bool = False, **_confi
func = apply_registered_pass(func, "pack-conv2DNchwFchw")
func = apply_registered_pass(func, "pack-conv2DNhwcHwcf")
func = apply_registered_pass(func, "rewrite-conv-to-matmul-or-brgemm")
func = apply_registered_pass(func, "pack-matmul")
m = select("m", [2, 4, 8])
n = select("n", [4, 8, 16])
k = select("k", [2, 4, 8, 16])
func = apply_registered_pass(
func, "pack-matmul", options={"block-factors": [m, n, k]}
)
apply_registered_pass(func, "pack-vnni")
if lower_pack_unpack_without_transpose:
mod = apply_registered_pass(mod, "lower-packs-unpacks-without-transpose")
Expand Down
7 changes: 6 additions & 1 deletion python/mlir/tpp/sched/common.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from mlir.dialects import transform
from mlir.dialects.transform import structured
from mlir.dialects.transform import structured, tune


# Wrapper to addresss verbosity.
def apply_registered_pass(*args, **kwargs):
return transform.apply_registered_pass(transform.AnyOpType.get(), *args, **kwargs)


# Wrapper to addresss verbosity.
def select(*args, **kwargs):
return tune.select(transform.AnyParamType.get(), *args, **kwargs)


# Wrapper to addresss verbosity.
def match(*args, **kwargs):
return structured.MatchOp(transform.AnyOpType.get(), *args, **kwargs)
1 change: 1 addition & 0 deletions tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ add_subdirectory(mlir-gen)
add_subdirectory(tpp-opt)
add_subdirectory(tpp-run)
add_subdirectory(tpp-sched)
add_subdirectory(tpp-tune)
add_subdirectory(fpcmp)
add_subdirectory(bench-ref)
9 changes: 9 additions & 0 deletions tools/tpp-tune/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
file(MAKE_DIRECTORY
${CMAKE_BINARY_DIR}/bin)
file(CREATE_LINK
${CMAKE_CURRENT_SOURCE_DIR}/tpp-tune.py
${CMAKE_BINARY_DIR}/bin/tpp-tune
SYMBOLIC)


add_custom_target(tpp-tune DEPENDS ${CMAKE_BINARY_DIR}/bin/tpp-tune TPPPythonModules)
70 changes: 70 additions & 0 deletions tools/tpp-tune/tpp-tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
#!/usr/bin/env python3

import sys
from pathlib import Path
from typing import Union, Sequence, Dict
from pprint import pprint


# Enable automagically finding TPP-MLIR's python modules (which include
# and extend MLIR's Python bindings).
python_packages_path = Path(__file__).parent.parent / "python_packages"
if python_packages_path.exists():
sys.path = [str(python_packages_path)] + sys.path


from mlir import ir
from mlir.dialects import transform
from mlir.dialects.transform import tune as transform_tune


def walker(f):
def wrapper(op: Union[ir.OpView, ir.Operation]):
f(op)
for region in op.regions:
for block in region.blocks:
for child_op in block:
wrapper(child_op)

return wrapper


def autotune(choices: Dict[str, Sequence[ir.Attribute]]) -> Dict[str, ir.Attribute]:
# Aint tuning easy!!
return {key: values[0] for key, values in choices.items()}
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just need to replace this function's implementation to make it useful.



file = sys.stdin
if len(sys.argv) > 1 and sys.argv[1] != "-":
file = open(sys.argv[1])


with ir.Context(), ir.Location.unknown():
schedule = ir.Module.parse(file.read())

choices = {}

@walker
def choices_finder(op):
if isinstance(op, transform_tune.TuneSelectOp):
choices[op.name] = tuple(op.options)

choices_finder(schedule.operation)

pprint(choices, stream=sys.stderr)

selected = autotune(choices)

@walker
def selected_rewriter(op: Union[ir.OpView, ir.Operation]):
if isinstance(op, transform_tune.TuneSelectOp):
with ir.InsertionPoint(op):
param = transform.param_constant(
transform.AnyParamType.get(), selected[op.name]
)
for use in op.result.uses:
use.owner.operands[use.operand_number] = param

selected_rewriter(schedule.operation)

print(schedule)
Loading