Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/Linalg.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinDialect.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/ImplicitLocOpBuilder.h"
#include "mlir/IR/TypeUtilities.h"
Expand All @@ -26,6 +27,9 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"

#include "llvm/ADT/STLFunctionalExtras.h"

#include <optional>

namespace mlir {
Expand Down
3 changes: 2 additions & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ def Linalg_Dialect : Dialect {
kMemoizedIndexingMapsAttrName = "linalg.memoized_indexing_maps";

using RegionBuilderFunType = llvm::function_ref<
void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>)>;
void(ImplicitLocOpBuilder &b, Block &, ArrayRef<NamedAttribute>,
function_ref<InFlightDiagnostic()>)>;
RegionBuilderFunType getRegionBuilder(StringRef name) {
return namedStructuredOpRegionBuilders.lookup(name);
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
Original file line number Diff line number Diff line change
Expand Up @@ -720,7 +720,7 @@ def LinalgStructuredInterface
Returns a null function if this named op does not define a region
builder.
}],
/*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>)>",
/*retTy=*/"std::function<void(ImplicitLocOpBuilder &, Block &, ArrayRef<NamedAttribute>, function_ref<InFlightDiagnostic()>)>",
/*methodName=*/"getRegionBuilder",
(ins),
[{ return ConcreteOp::getRegionBuilder(); }]
Expand Down
50 changes: 33 additions & 17 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,8 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
}

static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
Block &, ArrayRef<NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return nullptr;
}
Expand Down Expand Up @@ -300,7 +301,8 @@ def MapOp : LinalgStructuredBase_Op<"map", [
}

static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
mlir::ArrayRef<mlir::NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return nullptr;
}
Expand Down Expand Up @@ -380,7 +382,8 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [

// Implement functions necessary for DestinationStyleOpInterface.
static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
mlir::ArrayRef<mlir::NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return nullptr;
}
Expand Down Expand Up @@ -449,13 +452,14 @@ def TransposeOp : LinalgStructuredBase_Op<"transpose", [
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
mlir::ArrayRef<mlir::NamedAttribute>) {
mlir::ArrayRef<mlir::NamedAttribute>, function_ref<InFlightDiagnostic()> emitError) {
OpBuilder::InsertionGuard guard(b);
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
}

static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
mlir::ArrayRef<mlir::NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
Expand Down Expand Up @@ -521,13 +525,15 @@ def BroadcastOp : LinalgStructuredBase_Op<"broadcast", [
MutableOperandRange getDpsInitsMutable() { return getInitMutable(); }

static void regionBuilder(mlir::ImplicitLocOpBuilder &b, mlir::Block &block,
mlir::ArrayRef<mlir::NamedAttribute>) {
mlir::ArrayRef<mlir::NamedAttribute>,
function_ref<InFlightDiagnostic()> emitError) {
OpBuilder::InsertionGuard guard(b);
b.create<linalg::YieldOp>(b.getLoc(), block.getArgument(0));
}

static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
mlir::ArrayRef<mlir::NamedAttribute>)>
mlir::ArrayRef<mlir::NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
Expand Down Expand Up @@ -631,10 +637,12 @@ def ElementwiseOp : LinalgStructuredBase_Op<"elementwise", [
/// Implements the block region builder for the elementwiseOp. This is
/// called by the 'fillStructuredOpRegion'.
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
Block &block, ArrayRef<NamedAttribute> attrs,
function_ref<InFlightDiagnostic()> emitError);

static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
Block &, ArrayRef<NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
Expand Down Expand Up @@ -771,7 +779,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [

/// Implements the block region builder.
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
Block &block, ArrayRef<NamedAttribute> attrs,
function_ref<InFlightDiagnostic()> emitError);

/// Returns a list of AffineMap with the default matmul indexing charactristic.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
Expand All @@ -780,7 +789,8 @@ def MatmulOp : LinalgStructuredBase_Op<"matmul", [
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap);

static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
Block &, ArrayRef<NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
Expand Down Expand Up @@ -916,10 +926,12 @@ def ContractOp : LinalgStructuredBase_Op<"contract", [
static unsigned getNumRegionArgs();

static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
Block &block, ArrayRef<NamedAttribute> attrs,
function_ref<InFlightDiagnostic()> emitError);

static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
Block &, ArrayRef<NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
Expand Down Expand Up @@ -1033,9 +1045,11 @@ def BatchMatmulOp : LinalgStructuredBase_Op<"batch_matmul", !listconcat([AttrSiz

SmallVector<utils::IteratorType> getIteratorTypesArray();
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
Block &block, ArrayRef<NamedAttribute> attrs,
function_ref<InFlightDiagnostic()> emitError);
static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
Block &, ArrayRef<NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
Expand Down Expand Up @@ -1161,7 +1175,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [

/// Implements the block region builder.
static void regionBuilder(ImplicitLocOpBuilder &b,
Block &block, ArrayRef<NamedAttribute> attrs);
Block &block, ArrayRef<NamedAttribute> attrs,
function_ref<InFlightDiagnostic()> emitError);

/// Returns a list of AffineMap with the default batch_reduce_matmul indexing charactristic.
static SmallVector<AffineMap> getDefaultIndexingMaps(MLIRContext *context);
Expand All @@ -1170,7 +1185,8 @@ def BatchReduceMatmulOp : LinalgStructuredBase_Op<"batch_reduce_matmul", [
bool isValidLhsRhsBroadcastMap(AffineMap bcastMap, bool isLHS = true);

static std::function<void(ImplicitLocOpBuilder &,
Block &, ArrayRef<NamedAttribute>)>
Block &, ArrayRef<NamedAttribute>,
function_ref<InFlightDiagnostic()>)>
getRegionBuilder() {
return regionBuilder;
}
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/CAPI/Dialect/Linalg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ void mlirLinalgFillBuiltinNamedOpRegion(MlirOperation mlirOp) {
Region &region = op->getRegion(0);
Block *body = b.createBlock(&region, /*insertPt=*/{}, argTypes, argLocs);
b.setInsertionPointToStart(body);
fun(b, *body, op->getAttrs());
fun(b, *body, op->getAttrs(), /*emitError=*/{});
}

MLIR_CAPI_EXPORTED bool mlirLinalgIsAContractionOp(MlirOperation op) {
Expand Down
Loading