10 changes: 10 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.h
@@ -120,6 +120,16 @@ bool isaConvolutionOpInterface(LinalgOp linalgOp,
/// Checks whether `linalgOp` is semantically equivalent to a `linalg.copyOp`.
bool isaCopyOpInterface(LinalgOp linalgOp);

/// Checks whether `genericOp` is semantically equivalent to a
/// `linalg.broadcast`. Returns broadcast dimensions if true.
std::optional<SmallVector<int64_t>>
isaBroadcastOpInterface(GenericOp genericOp);

/// Checks whether `genericOp` is semantically equivalent to a
/// `linalg.transpose`. Returns permuted dimensions if true.
std::optional<SmallVector<int64_t>>
isaTransposeOpInterface(GenericOp genericOp);

/// Checks whether a given `genericOp` is semantically equivalent to a single
/// linalg elementwise unary op, e.g. linalg.exp.
/// A linalg.generic body could be a series of unary elementwise ops e.g.
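
For reference, a minimal sketch (hypothetical names and shapes, not taken from this patch) of a `linalg.generic` that `isaBroadcastOpInterface` is meant to recognize — the input map drops d0 and the output map is the identity, so the returned broadcast dimensions would be [0]:

// Hypothetical example: a generic that is semantically a linalg.broadcast.
#map_in  = affine_map<(d0, d1, d2) -> (d1, d2)>
#map_out = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
func.func @generic_broadcast(%A: tensor<4x8xf32>, %Out: tensor<2x4x8xf32>) -> tensor<2x4x8xf32> {
  %0 = linalg.generic
         {indexing_maps = [#map_in, #map_out],
          iterator_types = ["parallel", "parallel", "parallel"]}
         ins(%A : tensor<4x8xf32>) outs(%Out : tensor<2x4x8xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<2x4x8xf32>
  return %0 : tensor<2x4x8xf32>
}
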
115 changes: 100 additions & 15 deletions mlir/lib/Dialect/Linalg/IR/LinalgInterfaces.cpp
@@ -22,6 +22,7 @@
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/SmallVector.h"
#include <algorithm>
#include <numeric>

using namespace mlir;
using namespace mlir::linalg;
@@ -49,18 +50,41 @@ bool linalg::detail::canOpOperandsBeDroppedImpl(
return inversePermutation(concatAffineMaps(indexingMaps)) != AffineMap();
}

// Returns true if all loops of the linalgOp are parallel
static bool isAllParallel(LinalgOp op) {
return op.getNumParallelLoops() == op.getNumLoops();
}

// Returns true if and only if linalgOp takes one input and one init.
static bool isSingleInputOutput(LinalgOp op) {
return op.getNumDpsInputs() == 1 && op.getNumDpsInits() == 1;
}

// Returns true if the body of genericOp is just a yieldOp that yields the
// input operand as its result.
static bool isSingleYieldOp(GenericOp op) {
if (op.getNumDpsInputs() != 1 || op.getNumDpsInits() != 1)
return false;

Block *body = op.getBody();
if (body->getOperations().size() != 1)
return false;

auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
if (!yieldOp || yieldOp.getNumOperands() != 1 ||
yieldOp->getOperand(0) != body->getArgument(0))
return false;
return true;
}
Contributor

Do these belong here? IIUC, the comment above ("Interface utility functions") refers to ODS/TableGen "interfaces" (i.e. none of these is an InterfaceMethod).

Having said that, why not add them to the interface?


//===----------------------------------------------------------------------===//
// CopyOpInterface implementation
//===----------------------------------------------------------------------===//

bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
// Structural.
if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops())
// Structural and operands
Contributor

[nit] I don't understand this comment :) appreciate that you are effectively inheriting this, but let's clarify. Does "Structural and operands" mean "Check the structure (no parallel dims?) and the operands (single input/output?)".

Also trying to make sure I understand 😅

if (!isAllParallel(linalgOp) || !isSingleInputOutput(linalgOp))
return false;

// Operands and maps.
if (linalgOp.getNumDpsInputs() != 1 || linalgOp.getNumDpsInits() != 1)
return false;
auto mapRange = linalgOp.getIndexingMapsArray();
if (mapRange.size() != 2 || !mapRange.front().isIdentity() ||
!mapRange.back().isIdentity()) {
@@ -75,8 +99,8 @@ bool linalg::isaCopyOpInterface(LinalgOp linalgOp) {
//===----------------------------------------------------------------------===//
std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
// Structural.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
genericOp.getNumDpsInputs() != 1 || genericOp.getNumDpsInits() != 1)
if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
!isSingleYieldOp(genericOp))
return std::nullopt;

// Input should be referenced and init should not.
@@ -87,16 +111,78 @@ std::optional<Value> linalg::isaFillOpInterface(GenericOp genericOp) {
OpOperand *value = genericOp.getDpsInputOperand(0);
if (!genericOp.isScalar(value))
return std::nullopt;
return value->get();
}

Block *body = genericOp.getBody();
if (body->getOperations().size() != 1)
//===----------------------------------------------------------------------===//
// BroadcastOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
linalg::isaBroadcastOpInterface(GenericOp genericOp) {
// Structural.
if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
!isSingleYieldOp(genericOp))
return std::nullopt;

auto yieldOp = dyn_cast<linalg::YieldOp>(body->back());
if (!yieldOp || yieldOp.getNumOperands() != 1 ||
yieldOp->getOperand(0) != body->getArgument(0))
auto t0 = genericOp.getDpsInputOperand(0)->get().getType();
auto t1 = genericOp.getDpsInitOperand(0)->get().getType();
if (!isa<MemRefType, RankedTensorType>(t0) ||
!isa<MemRefType, RankedTensorType>(t1))
Contributor

What else could it be? Perhaps checking for ShapedType would be enough?

return std::nullopt;
return value->get();

// Check that the output map is the identity. An injective map could also be
// a permutation of indices, which is expressible in linalg.generic but not
// in the named broadcast op.
auto dstMap = genericOp.getIndexingMapsArray()[1];
if (!dstMap.isIdentity())
return std::nullopt;

SmallVector<int64_t> position;
auto srcMap = genericOp.getIndexingMapsArray()[0];

// Check that the input map's results are strictly increasing AffineDimExprs.
for (unsigned i = 0; i < srcMap.getNumResults(); ++i) {
auto expr = llvm::dyn_cast<AffineDimExpr>(srcMap.getResults()[i]);
if (!expr)
return std::nullopt;
int64_t pos = expr.getPosition();
if (i > 0 && pos <= position[i - 1])
return std::nullopt;
position.push_back(expr.getPosition());
}

SmallVector<int64_t> broadcastedDims;
auto numDims = srcMap.getNumDims();
for (auto dim : llvm::seq<int64_t>(0, numDims)) {
if (!llvm::is_contained(position, dim))
broadcastedDims.push_back(dim);
}
return broadcastedDims;
}

//===----------------------------------------------------------------------===//
// TransposeOpInterface implementation
//===----------------------------------------------------------------------===//
std::optional<SmallVector<int64_t>>
linalg::isaTransposeOpInterface(GenericOp genericOp) {
// Structural.
if (!isAllParallel(genericOp) || !isSingleInputOutput(genericOp) ||
!isSingleYieldOp(genericOp))
return std::nullopt;

// mapping checks.
Contributor

[nit]

Suggested change
- // mapping checks.
+ // Check the maps.

auto mapRange = genericOp.getIndexingMapsArray();
if (mapRange.size() != 2 || !mapRange.back().isIdentity() ||
!mapRange.front().isPermutation())
return std::nullopt;

SmallVector<int64_t> permutation;
auto map = mapRange.front();
for (unsigned i = 0; i < map.getNumResults(); ++i) {
auto expr = llvm::cast<AffineDimExpr>(map.getResults()[i]);
permutation.push_back(expr.getPosition());
}
return permutation;
}

//===----------------------------------------------------------------------===//
@@ -106,8 +192,7 @@ static bool
isaElemwiseSingleUnaryOrBinaryOpInterface(linalg::GenericOp genericOp,
unsigned arity) {
// Check all loops are parallel.
if (genericOp.getNumParallelLoops() != genericOp.getNumLoops() ||
genericOp.getNumLoops() < 1)
if (!isAllParallel(genericOp) || genericOp.getNumLoops() < 1)
return false;

// Check there are arity-inputs, 1-output and all are identity-maps.
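
Similarly, a minimal sketch (again with hypothetical names and shapes) of a generic that `isaTransposeOpInterface` would match: the input map is a permutation, the output map is the identity, and the returned permutation would be [1, 0]. By contrast, an input map such as affine_map<(d0, d1, d2) -> (d2, d1)> is not monotonically increasing, so `isaBroadcastOpInterface` rejects it rather than fold the implied transpose into a named broadcast.

// Hypothetical example: a generic that is semantically a linalg.transpose.
#perm     = affine_map<(d0, d1) -> (d1, d0)>
#identity = affine_map<(d0, d1) -> (d0, d1)>
func.func @generic_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
  %0 = linalg.generic
         {indexing_maps = [#perm, #identity],
          iterator_types = ["parallel", "parallel"]}
         ins(%A : tensor<16x64xf32>) outs(%Out : tensor<64x16xf32>) {
  ^bb0(%in: f32, %out: f32):
    linalg.yield %in : f32
  } -> tensor<64x16xf32>
  return %0 : tensor<64x16xf32>
}
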
27 changes: 27 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -259,18 +259,43 @@ static FailureOr<LinalgOp> specializeLinalgContractions(RewriterBase &rewriter,
//===----------------------------------------------------------------------===//
FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
GenericOp genericOp) {
// Copy
if (isaCopyOpInterface(genericOp)) {
LinalgOp namedOp = rewriter.replaceOpWithNewOp<CopyOp>(
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
return namedOp;
}

// Fill
if (isaFillOpInterface(genericOp)) {
LinalgOp namedOp = rewriter.replaceOpWithNewOp<FillOp>(
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0]);
return namedOp;
}

// Broadcast
std::optional<SmallVector<int64_t>> equivalentToBroadcast =
isaBroadcastOpInterface(genericOp);
if (equivalentToBroadcast) {
auto dims = *equivalentToBroadcast;
LinalgOp namedOp = rewriter.replaceOpWithNewOp<BroadcastOp>(
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
dims);
return namedOp;
}

// Transpose
std::optional<SmallVector<int64_t>> equivalentToTranspose =
isaTransposeOpInterface(genericOp);
if (equivalentToTranspose) {
auto permutation = *equivalentToTranspose;
LinalgOp namedOp = rewriter.replaceOpWithNewOp<TransposeOp>(
genericOp, genericOp.getDpsInputs()[0], genericOp.getDpsInits()[0],
permutation);
return namedOp;
}

// Elementwise Unary
if (isaElemwiseSingleUnaryOpInterface(genericOp)) {
Operation *op = &genericOp.getBody()->front();
if (isa<math::ExpOp>(op)) {
@@ -279,6 +304,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
}
}

// Elementwise Binary
if (isaElemwiseSingleBinaryOpInterface(genericOp)) {
bool swap = areBinOpsSwapped(genericOp);
Operation *op = &genericOp.getBody()->front();
Expand All @@ -300,6 +326,7 @@ FailureOr<LinalgOp> mlir::linalg::specializeGenericOp(RewriterBase &rewriter,
}
}

// Contraction - e.g. matmul
if (isaContractionOpInterface(genericOp)) {
return specializeLinalgContractions(rewriter, genericOp);
}
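
Assuming the two sketches above, running this specialization (e.g. via mlir-opt --linalg-specialize-generic-ops) would be expected to rewrite those generics into the named ops, roughly:

// Broadcast case: dimensions come from isaBroadcastOpInterface.
%broadcasted = linalg.broadcast ins(%A : tensor<4x8xf32>) outs(%Out : tensor<2x4x8xf32>) dimensions = [0]

// Transpose case: permutation comes from isaTransposeOpInterface.
%transposed = linalg.transpose ins(%A : tensor<16x64xf32>) outs(%Out : tensor<64x16xf32>) permutation = [1, 0]
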
32 changes: 32 additions & 0 deletions mlir/test/Dialect/Linalg/roundtrip-broadcast.mlir
Contributor

Why a new file instead of re-using roundtrip.mlir? Note that this file is called "roundtrip-broadcast.mlir", but it tests both broadcasts and transposes.

Contributor Author

I double-checked that transpose and broadcast are in separate files, e.g. the linalg.transpose tests are in roundtrip-transpose.mlir. It may be that in the browser it is appearing mashed up across comments.

@@ -0,0 +1,32 @@
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s

// CHECK-LABEL: broadcast_first_dimension
// CHECK-SAME: %[[A:.+]]: tensor<?x?xf32>, %[[Out:.+]]: tensor<?x?x?xf32>)
// CHECK-NOT: linalg.generic
// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<?x?xf32>) outs(%[[Out]] : tensor<?x?x?xf32>) dimensions = [0]
//
func.func @broadcast_first_dimension(%A: tensor<?x?xf32>, %Out: tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
%res = linalg.broadcast ins(%A: tensor<?x?xf32>) outs(%Out: tensor<?x?x?xf32>) dimensions = [0]
return %res : tensor<?x?x?xf32>
}

// CHECK-LABEL: broadcast_mid_dimension
// CHECK-SAME: %[[A:.+]]: tensor<3x5xf32>, %[[Out:.+]]: tensor<3x4x5xf32>)
// CHECK-NOT: linalg.generic
// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<3x5xf32>) outs(%[[Out]] : tensor<3x4x5xf32>) dimensions = [1]
//
func.func @broadcast_mid_dimension(%A: tensor<3x5xf32>, %Out: tensor<3x4x5xf32>) -> tensor<3x4x5xf32> {
%res = linalg.broadcast ins(%A: tensor<3x5xf32>) outs(%Out: tensor<3x4x5xf32>) dimensions = [1]
return %res : tensor<3x4x5xf32>
}


// CHECK-LABEL: broadcast_multiple_dimensions
// CHECK-SAME: %[[A:.+]]: tensor<4x5x7xf32>, %[[Out:.+]]: tensor<3x4x5x6x7x8x9xf32>)
// CHECK-NOT: linalg.generic
// CHECK: %broadcasted = linalg.broadcast ins(%[[A]] : tensor<4x5x7xf32>) outs(%[[Out]] : tensor<3x4x5x6x7x8x9xf32>) dimensions = [0, 3, 5, 6]
//
func.func @broadcast_multiple_dimensions(%A: tensor<4x5x7xf32>, %Out: tensor<3x4x5x6x7x8x9xf32>) -> tensor<3x4x5x6x7x8x9xf32> {
%res = linalg.broadcast ins(%A: tensor<4x5x7xf32>) outs(%Out: tensor<3x4x5x6x7x8x9xf32>) dimensions = [0,3,5,6]
return %res : tensor<3x4x5x6x7x8x9xf32>
}
11 changes: 11 additions & 0 deletions mlir/test/Dialect/Linalg/roundtrip-transpose.mlir
Contributor

Perhaps add 3d and 1d cases? And identity?

Contributor Author

Added a 3D test. Thanks for the suggestion.
A 1D transpose test would just get DCE'd out.

@@ -0,0 +1,11 @@
// RUN: mlir-opt %s -linalg-generalize-named-ops | mlir-opt --linalg-specialize-generic-ops | FileCheck %s

// CHECK-LABEL: linalg_transpose
// CHECK-SAME: %[[A:.+]]: tensor<16x64xf32>, %[[Out:.+]]: tensor<64x16xf32>
// CHECK-NOT: linalg.generic
// CHECK: %transposed = linalg.transpose ins(%[[A]] : tensor<16x64xf32>) outs(%[[Out]] : tensor<64x16xf32>) permutation = [1, 0]
//
func.func @linalg_transpose(%A: tensor<16x64xf32>, %Out: tensor<64x16xf32>) -> tensor<64x16xf32> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nit] Skip linalg in func name (repeating info already available)

%res = linalg.transpose ins(%A: tensor<16x64xf32>) outs(%Out: tensor<64x16xf32>) permutation = [1,0]
return %res : tensor<64x16xf32>
}
12 changes: 0 additions & 12 deletions mlir/test/Dialect/Linalg/transform-op-specialize.mlir
@@ -4,18 +4,6 @@
#map1 = affine_map<(d0, d1) -> (d0)>
#map2 = affine_map<(d0, d1) -> (d1, d0)>

func.func @broadcast_copy_expect_no_match(%arg0: memref<?xf32>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {
indexing_maps = [#map1, #map],
iterator_types = ["parallel", "parallel"]}
ins(%arg0 : memref<?xf32>) outs(%arg1 : memref<?x?xf32>) {
^bb0(%in: f32, %out: f32):
linalg.yield %in : f32
}
return
}

func.func @not_a_copy_expect_no_match(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>) {
// expected-note @below {{when applied to this op}}
linalg.generic {