diff --git a/include/TPP/Dialect/Tune/TuneTransformOps.td b/include/TPP/Dialect/Tune/TuneTransformOps.td index b6e75aee4..197a43393 100644 --- a/include/TPP/Dialect/Tune/TuneTransformOps.td +++ b/include/TPP/Dialect/Tune/TuneTransformOps.td @@ -1,6 +1,7 @@ #ifndef TUNE_TRANSFORM_OPS #define TUNE_TRANSFORM_OPS +include "mlir/IR/CommonAttrConstraints.td" include "mlir/Dialect/Transform/IR/TransformDialect.td" include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td" include "mlir/Interfaces/SideEffectInterfaces.td" @@ -17,7 +18,25 @@ def TuneSelectOp : Op, + DeclareOpInterfaceMethods +]> { + let summary = "Non-deterministically select a value from a set of values"; + let description = [{ + TODO + }]; + + let arguments = (ins SymbolRefAttr:$name, + ArrayAttr:$options, + OptionalAttr:$selected); + let results = (outs TransformParamTypeInterface:$result); let assemblyFormat = "$name `from` $options attr-dict `:` type(results)"; } diff --git a/lib/TPP/Dialect/Tune/TransformOps/TuneTransformOps.cpp b/lib/TPP/Dialect/Tune/TransformOps/TuneTransformOps.cpp index c28294835..c4c94b9a8 100644 --- a/lib/TPP/Dialect/Tune/TransformOps/TuneTransformOps.cpp +++ b/lib/TPP/Dialect/Tune/TransformOps/TuneTransformOps.cpp @@ -20,8 +20,41 @@ DiagnosedSilenceableFailure transform::TuneSelectOp::apply(transform::TransformRewriter &rewriter, transform::TransformResults &results, transform::TransformState &state) { + if (getOptions().size() == 1) { + results.setParams(getOperation()->getOpResults()[0], getOptions()[0]); + return DiagnosedSilenceableFailure::success(); + } + return emitDefiniteFailure() - << "this op does not have interpreted semantics!"; + << "this op does not resolve non-deterministic choice!"; +} + +//===----------------------------------------------------------------------===// +// TunePickOp +//===----------------------------------------------------------------------===// + +void transform::TunePickOp::getEffects( + SmallVectorImpl &effects) { + producesHandle(getOperation()->getOpResults(), effects); + onlyReadsPayload(effects); +} + +DiagnosedSilenceableFailure +transform::TunePickOp::apply(transform::TransformRewriter &rewriter, + transform::TransformResults &results, + transform::TransformState &state) { + if (getSelected()) { + results.setParams(getOperation()->getOpResults()[0], *getSelected()); + return DiagnosedSilenceableFailure::success(); + } + + if (getOptions().size() == 1) { + results.setParams(getOperation()->getOpResults()[0], getOptions()[0]); + return DiagnosedSilenceableFailure::success(); + } + + return emitDefiniteFailure() << "non-deterministic choice is only resolved " + "through providing a `selected` attr!"; } //===----------------------------------------------------------------------===// diff --git a/python/mlir/dialects/transform/tune.py b/python/mlir/dialects/transform/tune.py index d8ff218ac..fa39620f5 100644 --- a/python/mlir/dialects/transform/tune.py +++ b/python/mlir/dialects/transform/tune.py @@ -3,27 +3,101 @@ register_dialect_extension(get_dialect_registry()) -from ...ir import ArrayAttr, SymbolRefAttr, Attribute, Type -from .._tune_transform_ops_gen import TuneSelectOp +from ...ir import ( + ArrayAttr, + SymbolRefAttr, + Attribute, + Type, + StringAttr, + IntegerAttr, + IntegerType, + BoolAttr, +) +from .._tune_transform_ops_gen import * from collections.abc import Sequence -from typing import Union +from typing import Union, Optional def select( - selected: Type, # transform.any_param or transform.param<...> + result: Type, # transform.any_param or transform.param<...> name: Union[str, Attribute], - options: Union[ArrayAttr, Sequence[Attribute]], + options: Union[ArrayAttr, Sequence[Union[Attribute, str, int, bool]]], loc=None, ip=None, ) -> TuneSelectOp: if isinstance(name, str): name = SymbolRefAttr.get([name]) + if not isinstance(options, ArrayAttr): + option_attrs = [] + for option in options: + if isinstance(option, str): + option_attrs.append(StringAttr.get(option)) + elif isinstance(option, int): + int_type = IntegerType.get_signless(64) + option_attrs.append(IntegerAttr.get(int_type, option)) + elif isinstance(option, bool): + option_attrs.append(BoolAttr.get(option)) + elif isinstance(option, Attribute): + option_attrs.append(option) + options = ArrayAttr.get(option_attrs) + return TuneSelectOp( - selected=selected, + result=result, name=name, options=options, loc=loc, ip=ip, ) + + +def pick( + result: Type, # transform.any_param or transform.param<...> + name: Union[str, Attribute], + options: Union[ArrayAttr, Sequence[Union[Attribute, str, int, bool]]], + *, + selected: Optional[Union[Attribute, str, int, bool]] = None, + loc=None, + ip=None, +) -> TunePickOp: + if isinstance(name, str): + name = SymbolRefAttr.get([name]) + + if not isinstance(options, ArrayAttr): + option_attrs = [] + for option in options: + if isinstance(option, str): + option_attrs.append(StringAttr.get(option)) + elif isinstance(option, int): + int_type = IntegerType.get_signless(64) + option_attrs.append(IntegerAttr.get(int_type, option)) + elif isinstance(option, bool): + option_attrs.append(BoolAttr.get(option)) + elif isinstance(option, Attribute): + option_attrs.append(option) + else: + assert False + options = ArrayAttr.get(option_attrs) + + + if selected is None: + pass + elif isinstance(selected, str): + selected = StringAttr.get(selected) + elif isinstance(selected, int): + int_type = IntegerType.get_signless(64) + selected = IntegerAttr.get(int_type, selected) + elif isinstance(selected, bool): + selected = BoolAttr.get(selected) + elif not isinstance(selected, Attribute): + assert False + + return TunePickOp( + result=result, + name=name, + options=options, + selected=selected, + loc=loc, + ip=ip, + ) diff --git a/python/mlir/dialects/tune.py b/python/mlir/dialects/tune.py index 82ba6ca43..2e1c42e92 100644 --- a/python/mlir/dialects/tune.py +++ b/python/mlir/dialects/tune.py @@ -1,4 +1,4 @@ from .._mlir_libs import get_dialect_registry -from .._mlir_libs._tppDialects.tune import register_dialect +from .._mlir_libs._tppDialects.tune import * register_dialect(get_dialect_registry()) diff --git a/python/mlir/tpp/sched/bundles.py b/python/mlir/tpp/sched/bundles.py index 5edd1d713..7de196f39 100755 --- a/python/mlir/tpp/sched/bundles.py +++ b/python/mlir/tpp/sched/bundles.py @@ -1,8 +1,8 @@ -from typing import Optional, Sequence +from typing import Optional, Sequence, Union from mlir import ir from mlir.dialects import transform -from .common import apply_registered_pass, match +from .common import apply_registered_pass, match, select, pick from .utils import GpuBackend, PipelineInterrupt from ..xsmm import utils as xsmm_utils @@ -22,7 +22,14 @@ def cleanup(op, **_config): # TODO: make bundle into a NamedSequence to call with IncludeOp -def tpp_mapping(mod, lower_pack_unpack_without_transpose: bool = False, **_config): +def tpp_mapping( + mod, + lower_pack_unpack_without_transpose: bool = False, + pack_block_factors: Optional[ + Sequence[Union[Sequence[Union[int, ir.IntegerAttr]], int, ir.IntegerAttr]] + ] = None, + **_config, +): "High-level transforms that map operations to TPP-compatible forms." # Preprocess convolutions. @@ -34,7 +41,14 @@ def tpp_mapping(mod, lower_pack_unpack_without_transpose: bool = False, **_confi func = apply_registered_pass(func, "pack-conv2DNchwFchw") func = apply_registered_pass(func, "pack-conv2DNhwcHwcf") func = apply_registered_pass(func, "rewrite-conv-to-matmul-or-brgemm") - func = apply_registered_pass(func, "pack-matmul") + options = None + if pack_block_factors: + m_vals, n_vals, k_vals = pack_block_factors + m = select("m", m_vals if isinstance(m_vals, Sequence) else [m_vals]) + n = select("n", n_vals if isinstance(n_vals, Sequence) else [n_vals]) + k = pick("k", k_vals if isinstance(k_vals, Sequence) else [k_vals]) + options = {"block-factors": [m, n, k]} + func = apply_registered_pass(func, "pack-matmul", options=options) apply_registered_pass(func, "pack-vnni") if lower_pack_unpack_without_transpose: mod = apply_registered_pass(mod, "lower-packs-unpacks-without-transpose") diff --git a/python/mlir/tpp/sched/common.py b/python/mlir/tpp/sched/common.py index e51ebb3c5..965226c8c 100644 --- a/python/mlir/tpp/sched/common.py +++ b/python/mlir/tpp/sched/common.py @@ -1,5 +1,15 @@ from mlir.dialects import transform -from mlir.dialects.transform import structured +from mlir.dialects.transform import structured, tune + + +# Wrapper to addresss verbosity. +def select(*args, **kwargs): + return tune.select(transform.AnyParamType.get(), *args, **kwargs) + + +# Wrapper to addresss verbosity. +def pick(*args, **kwargs): + return tune.pick(transform.AnyParamType.get(), *args, **kwargs) # Wrapper to addresss verbosity. diff --git a/python/mlir/tpp/sched/main.py b/python/mlir/tpp/sched/main.py index be31fb2e5..072f22ad9 100644 --- a/python/mlir/tpp/sched/main.py +++ b/python/mlir/tpp/sched/main.py @@ -32,6 +32,15 @@ def comma_separated_ints(arg: str): "--payload", type=str, help="payload file to print with schedule" ) + def block_factors(arg: str): + m, n, k = arg.split(",") + + convert = lambda dim: list(map(int, dim)) + + return convert(m.split(";")), convert(n.split(";")), convert(k.split(";")) + + parser.add_argument("--pack-block-factors", type=block_factors, default=None) + parser.add_argument("--split-input-file", action="store_true") def comma_separated_bundles(arg: str): diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index f76b6a820..56f2a7365 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -2,5 +2,6 @@ add_subdirectory(mlir-gen) add_subdirectory(tpp-opt) add_subdirectory(tpp-run) add_subdirectory(tpp-sched) +add_subdirectory(tpp-tune) add_subdirectory(fpcmp) add_subdirectory(bench-ref) diff --git a/tools/tpp-tune/CMakeLists.txt b/tools/tpp-tune/CMakeLists.txt new file mode 100644 index 000000000..86aba32af --- /dev/null +++ b/tools/tpp-tune/CMakeLists.txt @@ -0,0 +1,9 @@ +file(MAKE_DIRECTORY + ${CMAKE_BINARY_DIR}/bin) +file(CREATE_LINK + ${CMAKE_CURRENT_SOURCE_DIR}/tpp-tune.py + ${CMAKE_BINARY_DIR}/bin/tpp-tune + SYMBOLIC) + + +add_custom_target(tpp-tune DEPENDS ${CMAKE_BINARY_DIR}/bin/tpp-tune TPPPythonModules) diff --git a/tools/tpp-tune/tpp-tune.py b/tools/tpp-tune/tpp-tune.py new file mode 100755 index 000000000..06daa84c1 --- /dev/null +++ b/tools/tpp-tune/tpp-tune.py @@ -0,0 +1,75 @@ +#!/usr/bin/env python3 + +import sys +from pathlib import Path +from typing import Union, Sequence, Dict +import random + +# Enable automagically finding TPP-MLIR's python modules (which include +# and extend MLIR's Python bindings). +python_packages_path = Path(__file__).parent.parent / "python_packages" +if python_packages_path.exists(): + sys.path = [str(python_packages_path)] + sys.path + + +from mlir import ir +from mlir.dialects import transform +from mlir.dialects.transform import tune as transform_tune + + +def walker(f): + def wrapper(op: Union[ir.OpView, ir.Operation]): + f(op) + for region in op.regions: + for block in region.blocks: + for child_op in block: + wrapper(child_op) + + return wrapper + + +def autotune(choices: Dict[str, Sequence[ir.Attribute]]) -> Dict[str, ir.Attribute]: + # Aint tuning easy!! + return {key: random.choice(values) for key, values in choices.items()} + + +file = sys.stdin +if len(sys.argv) > 1 and sys.argv[1] != "-": + file = open(sys.argv[1]) + + +with ir.Context(), ir.Location.unknown(): + schedule = ir.Module.parse(file.read()) + + choices = {} + + @walker + def choices_finder(op): + if isinstance(op, transform_tune.TuneSelectOp): + if op.name in choices: + raise RuntimeError(f"options name collision: {op.name} used twice") + choices[op.name] = tuple(op.options) + elif isinstance(op, transform_tune.TunePickOp): + if op.name in choices: + raise RuntimeError(f"options name collision: {op.name} used twice") + choices[op.name] = tuple(op.options) + + choices_finder(schedule.operation) + + selected = autotune(choices) + + @walker + def selected_rewriter(op: Union[ir.OpView, ir.Operation]): + if isinstance(op, transform_tune.TuneSelectOp): + with ir.InsertionPoint(op): + param = transform.param_constant( + transform.AnyParamType.get(), selected[op.name] + ) + for use in op.result.uses: + use.owner.operands[use.operand_number] = param + elif isinstance(op, transform_tune.TunePickOp): + op.attributes["selected"] = selected[op.name] + + selected_rewriter(schedule.operation) + + print(schedule)