Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,4 @@ add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
set(LLVM_TARGET_DEFINITIONS TuneExtensionOps.td)
mlir_tablegen(TuneExtensionOps.h.inc -gen-op-decls)
mlir_tablegen(TuneExtensionOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRTransformDialectTuneExtensionOpsIncGen)

add_mlir_doc(TuneExtensionOps TuneExtensionOps Dialects/ -gen-op-doc)
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/Transform/TuneExtension/TuneExtension.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
//===- TuneExtension.h - Tune extension for Transform dialect ---*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H

namespace mlir {
class DialectRegistry;

namespace transform {
/// Registers the tune extension of the Transform dialect in the given registry.
void registerTuneExtension(DialectRegistry &dialectRegistry);
} // namespace transform
} // namespace mlir

#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSION_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
//===- TuneExtensionOps.h - Tune ext. for Transform dialect -----*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H

#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/OpDefinition.h"

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h.inc"

#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS_H
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
//===- TuneExtensionOps.td - Transform dialect operations --*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
#define MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS

include "mlir/Dialect/Transform/IR/TransformDialect.td"
include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/BuiltinAttributes.td"
include "mlir/IR/CommonAttrConstraints.td"

def KnobOp : Op<Transform_Dialect, "tune.knob", [
DeclareOpInterfaceMethods<TransformOpInterface>,
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
]> {
let summary = "Represents a tunable parameter with a set of options";

let description = [{
Provides a representation for "tunables" within schedules.

Each op represents a single tunable, which has a `name` and a set
of valid `options` described by an attribute. Without a specified
`selected` option, this op represents a non-deterministic choice
that has yet to be resolved -- as such, the interpreter runtime
semantics is to raise a failure.

The non-deterministic choice is resolved through providing a
`selected` attribute. When provided, the interpreter runtime
semantics are to return the `selected` attribute as a param through
the op's result.

-----

In case the `options` attribute is an `ArrayAttr`, the verifier checks that the provided `selected` attribute occurs in `options`.
}];
let cppNamespace = [{ mlir::transform::tune }];
let hasVerifier = 1;

let arguments = (ins Builtin_StringAttr:$name,
AnyAttr:$options,
OptionalAttr<AnyAttr>:$selected);
let results = (outs TransformParamTypeInterface:$result);

let assemblyFormat =
"`<` $name `>` (`=` $selected^ `from`)? `options` `=` $options attr-dict `->` type(results)";
}

#endif // MLIR_DIALECT_TRANSFORM_TUNEEXTENSION_TUNEEXTENSIONOPS
2 changes: 2 additions & 0 deletions mlir/include/mlir/InitAllExtensions.h
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
#include "mlir/Dialect/Transform/IRDLExtension/IRDLExtension.h"
#include "mlir/Dialect/Transform/LoopExtension/LoopExtension.h"
#include "mlir/Dialect/Transform/PDLExtension/PDLExtension.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"
#include "mlir/Dialect/Vector/TransformOps/VectorTransformOps.h"
#include "mlir/Target/LLVMIR/Dialect/Builtin/BuiltinToLLVMIRTranslation.h"
#include "mlir/Target/LLVMIR/Dialect/GPU/GPUToLLVMIRTranslation.h"
Expand Down Expand Up @@ -107,6 +108,7 @@ inline void registerAllExtensions(DialectRegistry &registry) {
transform::registerIRDLExtension(registry);
transform::registerLoopExtension(registry);
transform::registerPDLExtension(registry);
transform::registerTuneExtension(registry);
vector::registerTransformDialectExtension(registry);
arm_neon::registerTransformDialectExtension(registry);
arm_sve::registerTransformDialectExtension(registry);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Transform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@ add_subdirectory(IRDLExtension)
add_subdirectory(LoopExtension)
add_subdirectory(PDLExtension)
add_subdirectory(Transforms)
add_subdirectory(TuneExtension)
add_subdirectory(Utils)
12 changes: 12 additions & 0 deletions mlir/lib/Dialect/Transform/TuneExtension/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
add_mlir_dialect_library(MLIRTransformTuneExtension
TuneExtension.cpp
TuneExtensionOps.cpp

DEPENDS
MLIRTransformDialectTuneExtensionOpsIncGen

LINK_LIBS PUBLIC
MLIRIR
MLIRTransformDialect
MLIRTransforms
)
32 changes: 32 additions & 0 deletions mlir/lib/Dialect/Transform/TuneExtension/TuneExtension.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//===- TuneExtension.cpp - Tune extension for the Transform dialect -------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/TuneExtension/TuneExtension.h"

#include "mlir/Dialect/Transform/IR/TransformDialect.h"
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"
#include "mlir/IR/DialectRegistry.h"

using namespace mlir;

class TuneExtension
: public transform::TransformDialectExtension<TuneExtension> {
public:
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TuneExtension)

void init() {
registerTransformOps<
#define GET_OP_LIST
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"
>();
}
};

void mlir::transform::registerTuneExtension(DialectRegistry &dialectRegistry) {
dialectRegistry.addExtensions<TuneExtension>();
}
62 changes: 62 additions & 0 deletions mlir/lib/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
//===- TuneExtensionOps.cpp - Tune extension for the Transform dialect ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Transform/IR/TransformOps.h"
#include "mlir/Dialect/Transform/Interfaces/TransformInterfaces.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "llvm/Support/Debug.h"

#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.h"

using namespace mlir;

#define GET_OP_CLASSES
#include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.cpp.inc"

#define DEBUG_TYPE "transform-tune"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "] ")

//===----------------------------------------------------------------------===//
// KnobOp
//===----------------------------------------------------------------------===//

void transform::tune::KnobOp::getEffects(
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
producesHandle(getOperation()->getOpResults(), effects);
onlyReadsPayload(effects);
}

DiagnosedSilenceableFailure
transform::tune::KnobOp::apply(transform::TransformRewriter &rewriter,
transform::TransformResults &results,
transform::TransformState &state) {
if (getSelected()) {
results.setParams(getOperation()->getOpResults()[0], *getSelected());
return DiagnosedSilenceableFailure::success();
}

return emitDefiniteFailure()
<< "non-deterministic choice " << getName()
<< " is only resolved through providing a `selected` attr";
}

LogicalResult transform::tune::KnobOp::verify() {
if (auto selected = getSelected()) {
if (auto optionsArray = dyn_cast<ArrayAttr>(getOptions())) {
if (!llvm::is_contained(optionsArray, selected))
return emitOpError("provided `selected` attribute is not an element of "
"`options` array of attributes");
} else
LLVM_DEBUG(DBGS() << "cannot verify `selected` attribute " << selected
<< " is an element of `options` attribute "
<< getOptions());
}

return success();
}
9 changes: 9 additions & 0 deletions mlir/python/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,15 @@ ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
DIALECT_NAME transform
EXTENSION_NAME transform_debug_extension)

declare_mlir_dialect_extension_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
TD_FILE dialects/TransformTuneExtensionOps.td
SOURCES
dialects/transform/tune.py
DIALECT_NAME transform
EXTENSION_NAME transform_tune_extension)

declare_mlir_dialect_python_bindings(
ADD_TO_PARENT MLIRPythonSources.Dialects
ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/mlir"
Expand Down
18 changes: 18 additions & 0 deletions mlir/python/mlir/dialects/TransformTuneExtensionOps.td
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
//===-- TransformTuneExtensionOps.td - Binding entry point -*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Entry point of the generated Python bindings for the Tune extension of the
// Transform dialect.
//===----------------------------------------------------------------------===//

#ifndef PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
#define PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS

include "mlir/Dialect/Transform/TuneExtension/TuneExtensionOps.td"

#endif // PYTHON_BINDINGS_TRANSFORM_DEBUG_EXTENSION_OPS
82 changes: 82 additions & 0 deletions mlir/python/mlir/dialects/transform/tune.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from typing import Optional, Sequence

from ...ir import (
Type,
Attribute,
ArrayAttr,
StringAttr,
F64Type,
IntegerType,
IntegerAttr,
FloatAttr,
BoolAttr,
)
from .._transform_tune_extension_ops_gen import *
from .._transform_tune_extension_ops_gen import _Dialect

try:
from .._ods_common import _cext as _ods_cext
except ImportError as e:
raise RuntimeError("Error loading imports from extension module") from e

from typing import Union


@_ods_cext.register_operation(_Dialect, replace=True)
class KnobOp(KnobOp):
def __init__(
self,
result: Type, # !transform.any_param or !transform.param<Type>
name: Union[StringAttr, str],
options: Union[
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
selected: Optional[Attribute] = None,
loc=None,
ip=None,
):
if isinstance(name, str):
name = StringAttr.get(name)

def map_to_attr(value):
if isinstance(value, bool):
return BoolAttr.get(value)
if isinstance(value, int):
return IntegerAttr.get(IntegerType.get_signless(64), value)
if isinstance(value, float):
return FloatAttr.get(F64Type.get(), value)
if isinstance(value, str):
return StringAttr.get(value)
assert isinstance(value, Attribute)
return value

if isinstance(options, Sequence) and not isinstance(options, ArrayAttr):
options = ArrayAttr.get([map_to_attr(opt) for opt in options])

super().__init__(
result,
name,
options,
selected=selected and map_to_attr(selected),
loc=loc,
ip=ip,
)


def knob(
result: Type, # !transform.any_param or !transform.param<Type>
name: Union[StringAttr, str],
options: Union[
ArrayAttr, Sequence[Union[Attribute, bool, int, float, str]], Attribute
],
*,
selected: Optional[Attribute] = None,
loc=None,
ip=None,
):
return KnobOp(result, name, options, selected=selected, loc=loc, ip=ip)
21 changes: 21 additions & 0 deletions mlir/test/Dialect/Transform/test-tune-extension-invalid.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// RUN: mlir-opt %s --transform-interpreter --split-input-file --verify-diagnostics

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
// expected-error@below {{provided `selected` attribute is not an element of `options` array of attributes}}
%heads_or_tails = transform.tune.knob<"coin"> = 1 from options = [true, false] -> !transform.any_param
transform.yield
}
}

// -----

func.func private @f()

module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
// expected-error@below {{non-deterministic choice "coin" is only resolved through providing a `selected` attr}}
%heads_or_tails = transform.tune.knob<"coin"> options = [true, false] -> !transform.any_param
transform.yield
}
}
Loading