Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 20 additions & 1 deletion include/TPP/Dialect/Tune/TuneTransformOps.td
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#ifndef TUNE_TRANSFORM_OPS
#define TUNE_TRANSFORM_OPS

include "mlir/IR/CommonAttrConstraints.td"
include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
Expand All @@ -17,7 +18,25 @@ def TuneSelectOp : Op<Transform_Dialect, "tune.select", [

let arguments = (ins SymbolRefAttr:$name,
ArrayAttr:$options);
let results = (outs TransformParamTypeInterface:$selected);
let results = (outs TransformParamTypeInterface:$result);
let assemblyFormat =
"$name `from` $options attr-dict `:` type(results)";
}


def TunePickOp : Op<Transform_Dialect, "tune.pick", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>
]> {
let summary = "Non-deterministically select a value from a set of values";
let description = [{
TODO
}];

let arguments = (ins SymbolRefAttr:$name,
ArrayAttr:$options,
OptionalAttr<AnyAttr>:$selected);
let results = (outs TransformParamTypeInterface:$result);
let assemblyFormat =
"$name `from` $options attr-dict `:` type(results)";
}
Expand Down
35 changes: 34 additions & 1 deletion lib/TPP/Dialect/Tune/TransformOps/TuneTransformOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,41 @@ DiagnosedSilenceableFailure
transform::TuneSelectOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
if (getOptions().size() == 1) {
results.setParams(getOperation()->getOpResults()[0], getOptions()[0]);
return DiagnosedSilenceableFailure::success();
}

return emitDefiniteFailure()
<< "this op does not have interpreted semantics!";
<< "this op does not resolve non-deterministic choice!";
}

//===----------------------------------------------------------------------===//
// TunePickOp
//===----------------------------------------------------------------------===//

void transform::TunePickOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
producesHandle(getOperation()->getOpResults(), effects);
onlyReadsPayload(effects);
}

DiagnosedSilenceableFailure
transform::TunePickOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
if (getSelected()) {
results.setParams(getOperation()->getOpResults()[0], *getSelected());
return DiagnosedSilenceableFailure::success();
}

if (getOptions().size() == 1) {
results.setParams(getOperation()->getOpResults()[0], getOptions()[0]);
return DiagnosedSilenceableFailure::success();
}

return emitDefiniteFailure() << "non-deterministic choice is only resolved "
"through providing a `selected` attr!";
}

//===----------------------------------------------------------------------===//
Expand Down
86 changes: 80 additions & 6 deletions python/mlir/dialects/transform/tune.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,101 @@

register_dialect_extension(get_dialect_registry())

from ...ir import ArrayAttr, SymbolRefAttr, Attribute, Type
from .._tune_transform_ops_gen import TuneSelectOp
from ...ir import (
ArrayAttr,
SymbolRefAttr,
Attribute,
Type,
StringAttr,
IntegerAttr,
IntegerType,
BoolAttr,
)
from .._tune_transform_ops_gen import *

from collections.abc import Sequence
from typing import Union
from typing import Union, Optional


def select(
selected: Type, # transform.any_param or transform.param<...>
result: Type, # transform.any_param or transform.param<...>
name: Union[str, Attribute],
options: Union[ArrayAttr, Sequence[Attribute]],
options: Union[ArrayAttr, Sequence[Union[Attribute, str, int, bool]]],
loc=None,
ip=None,
) -> TuneSelectOp:
if isinstance(name, str):
name = SymbolRefAttr.get([name])

if not isinstance(options, ArrayAttr):
option_attrs = []
for option in options:
if isinstance(option, str):
option_attrs.append(StringAttr.get(option))
elif isinstance(option, int):
int_type = IntegerType.get_signless(64)
option_attrs.append(IntegerAttr.get(int_type, option))
elif isinstance(option, bool):
option_attrs.append(BoolAttr.get(option))
elif isinstance(option, Attribute):
option_attrs.append(option)
options = ArrayAttr.get(option_attrs)

return TuneSelectOp(
selected=selected,
result=result,
name=name,
options=options,
loc=loc,
ip=ip,
)


def pick(
result: Type, # transform.any_param or transform.param<...>
name: Union[str, Attribute],
options: Union[ArrayAttr, Sequence[Union[Attribute, str, int, bool]]],
*,
selected: Optional[Union[Attribute, str, int, bool]] = None,
loc=None,
ip=None,
) -> TunePickOp:
if isinstance(name, str):
name = SymbolRefAttr.get([name])

if not isinstance(options, ArrayAttr):
option_attrs = []
for option in options:
if isinstance(option, str):
option_attrs.append(StringAttr.get(option))
elif isinstance(option, int):
int_type = IntegerType.get_signless(64)
option_attrs.append(IntegerAttr.get(int_type, option))
elif isinstance(option, bool):
option_attrs.append(BoolAttr.get(option))
elif isinstance(option, Attribute):
option_attrs.append(option)
else:
assert False
options = ArrayAttr.get(option_attrs)


if selected is None:
pass
elif isinstance(selected, str):
selected = StringAttr.get(selected)
elif isinstance(selected, int):
int_type = IntegerType.get_signless(64)
selected = IntegerAttr.get(int_type, selected)
elif isinstance(selected, bool):
selected = BoolAttr.get(selected)
elif not isinstance(selected, Attribute):
assert False

return TunePickOp(
result=result,
name=name,
options=options,
selected=selected,
loc=loc,
ip=ip,
)
2 changes: 1 addition & 1 deletion python/mlir/dialects/tune.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .._mlir_libs import get_dialect_registry
from .._mlir_libs._tppDialects.tune import register_dialect
from .._mlir_libs._tppDialects.tune import *

register_dialect(get_dialect_registry())
22 changes: 18 additions & 4 deletions python/mlir/tpp/sched/bundles.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from typing import Optional, Sequence
from typing import Optional, Sequence, Union

from mlir import ir
from mlir.dialects import transform
from .common import apply_registered_pass, match
from .common import apply_registered_pass, match, select, pick
from .utils import GpuBackend, PipelineInterrupt

from ..xsmm import utils as xsmm_utils
Expand All @@ -22,7 +22,14 @@ def cleanup(op, **_config):


# TODO: make bundle into a NamedSequence to call with IncludeOp
def tpp_mapping(mod, lower_pack_unpack_without_transpose: bool = False, **_config):
def tpp_mapping(
mod,
lower_pack_unpack_without_transpose: bool = False,
pack_block_factors: Optional[
Sequence[Union[Sequence[Union[int, ir.IntegerAttr]], int, ir.IntegerAttr]]
] = None,
**_config,
):
"High-level transforms that map operations to TPP-compatible forms."

# Preprocess convolutions.
Expand All @@ -34,7 +41,14 @@ def tpp_mapping(mod, lower_pack_unpack_without_transpose: bool = False, **_confi
func = apply_registered_pass(func, "pack-conv2DNchwFchw")
func = apply_registered_pass(func, "pack-conv2DNhwcHwcf")
func = apply_registered_pass(func, "rewrite-conv-to-matmul-or-brgemm")
func = apply_registered_pass(func, "pack-matmul")
options = None
if pack_block_factors:
m_vals, n_vals, k_vals = pack_block_factors
m = select("m", m_vals if isinstance(m_vals, Sequence) else [m_vals])
n = select("n", n_vals if isinstance(n_vals, Sequence) else [n_vals])
k = pick("k", k_vals if isinstance(k_vals, Sequence) else [k_vals])
options = {"block-factors": [m, n, k]}
func = apply_registered_pass(func, "pack-matmul", options=options)
apply_registered_pass(func, "pack-vnni")
if lower_pack_unpack_without_transpose:
mod = apply_registered_pass(mod, "lower-packs-unpacks-without-transpose")
Expand Down
12 changes: 11 additions & 1 deletion python/mlir/tpp/sched/common.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,15 @@
from mlir.dialects import transform
from mlir.dialects.transform import structured
from mlir.dialects.transform import structured, tune


# Wrapper to addresss verbosity.
def select(*args, **kwargs):
return tune.select(transform.AnyParamType.get(), *args, **kwargs)


# Wrapper to addresss verbosity.
def pick(*args, **kwargs):
return tune.pick(transform.AnyParamType.get(), *args, **kwargs)


# Wrapper to addresss verbosity.
Expand Down
9 changes: 9 additions & 0 deletions python/mlir/tpp/sched/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ def comma_separated_ints(arg: str):
"--payload", type=str, help="payload file to print with schedule"
)

def block_factors(arg: str):
m, n, k = arg.split(",")

convert = lambda dim: list(map(int, dim))

return convert(m.split(";")), convert(n.split(";")), convert(k.split(";"))

parser.add_argument("--pack-block-factors", type=block_factors, default=None)

parser.add_argument("--split-input-file", action="store_true")

def comma_separated_bundles(arg: str):
Expand Down
1 change: 1 addition & 0 deletions tools/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ add_subdirectory(mlir-gen)
add_subdirectory(tpp-opt)
add_subdirectory(tpp-run)
add_subdirectory(tpp-sched)
add_subdirectory(tpp-tune)
add_subdirectory(fpcmp)
add_subdirectory(bench-ref)
9 changes: 9 additions & 0 deletions tools/tpp-tune/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
file(MAKE_DIRECTORY
${CMAKE_BINARY_DIR}/bin)
file(CREATE_LINK
${CMAKE_CURRENT_SOURCE_DIR}/tpp-tune.py
${CMAKE_BINARY_DIR}/bin/tpp-tune
SYMBOLIC)


add_custom_target(tpp-tune DEPENDS ${CMAKE_BINARY_DIR}/bin/tpp-tune TPPPythonModules)
75 changes: 75 additions & 0 deletions tools/tpp-tune/tpp-tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#!/usr/bin/env python3

import sys
from pathlib import Path
from typing import Union, Sequence, Dict
import random

# Enable automagically finding TPP-MLIR's python modules (which include
# and extend MLIR's Python bindings).
python_packages_path = Path(__file__).parent.parent / "python_packages"
if python_packages_path.exists():
sys.path = [str(python_packages_path)] + sys.path


from mlir import ir
from mlir.dialects import transform
from mlir.dialects.transform import tune as transform_tune


def walker(f):
def wrapper(op: Union[ir.OpView, ir.Operation]):
f(op)
for region in op.regions:
for block in region.blocks:
for child_op in block:
wrapper(child_op)

return wrapper


def autotune(choices: Dict[str, Sequence[ir.Attribute]]) -> Dict[str, ir.Attribute]:
# Aint tuning easy!!
return {key: random.choice(values) for key, values in choices.items()}


file = sys.stdin
if len(sys.argv) > 1 and sys.argv[1] != "-":
file = open(sys.argv[1])


with ir.Context(), ir.Location.unknown():
schedule = ir.Module.parse(file.read())

choices = {}

@walker
def choices_finder(op):
if isinstance(op, transform_tune.TuneSelectOp):
if op.name in choices:
raise RuntimeError(f"options name collision: {op.name} used twice")
choices[op.name] = tuple(op.options)
elif isinstance(op, transform_tune.TunePickOp):
if op.name in choices:
raise RuntimeError(f"options name collision: {op.name} used twice")
choices[op.name] = tuple(op.options)

choices_finder(schedule.operation)

selected = autotune(choices)

@walker
def selected_rewriter(op: Union[ir.OpView, ir.Operation]):
if isinstance(op, transform_tune.TuneSelectOp):
with ir.InsertionPoint(op):
param = transform.param_constant(
transform.AnyParamType.get(), selected[op.name]
)
for use in op.result.uses:
use.owner.operands[use.operand_number] = param
elif isinstance(op, transform_tune.TunePickOp):
op.attributes["selected"] = selected[op.name]

selected_rewriter(schedule.operation)

print(schedule)