2 changes: 1 addition & 1 deletion mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
@@ -252,7 +252,7 @@ def LinalgStructuredInterface
/*args=*/(ins),
/*methodBody=*/"",
/*defaultImplementation=*/[{
-        return getNumParallelLoops() == getNumParallelLoops();
+        return getNumParallelLoops() == getNumLoops();
}]
>,
InterfaceMethod<
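For context: the removed line compared `getNumParallelLoops()` with itself, so the predicate (presumably `isAllParallelLoops`, judging by its use later in this patch) was trivially true. A minimal, illustrative sketch of the kind of guard that now behaves as intended (mirroring the check added later in this PR; `genericOp` is a placeholder name):

```cpp
// Sketch only: with the corrected default implementation, this guard actually
// rejects linalg.generic ops that contain reduction loops.
if (!genericOp.isAllParallelLoops())
  return failure();
```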
4 changes: 4 additions & 0 deletions mlir/include/mlir/Dialect/Linalg/Transforms/Transforms.h
@@ -1786,6 +1786,10 @@ void populateBubbleUpExtractSliceOpPatterns(RewritePatternSet &patterns);
/// linalg.fill(%cst, tensor.extract_slice(%init)).
void populateSwapExtractSliceWithFillPatterns(RewritePatternSet &patterns);

/// Add patterns to make explicit broadcasts and transposes in the
/// input operands of a genericOp.
void populateUnfoldProjectedPermutationPatterns(RewritePatternSet &patterns);

/// Patterns to apply `splitReduction` below.
void populateSplitReductionPattern(
RewritePatternSet &patterns,
1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/CMakeLists.txt
@@ -38,6 +38,7 @@ add_mlir_dialect_library(MLIRLinalgTransforms
TilingInterfaceImpl.cpp
Transforms.cpp
TransposeConv2D.cpp
UnfoldProjectedPermutation.cpp
Vectorization.cpp
WinogradConv2D.cpp

1 change: 1 addition & 0 deletions mlir/lib/Dialect/Linalg/Transforms/Specialize.cpp
@@ -347,6 +347,7 @@ struct LinalgSpecializeGenericOpsPass
void LinalgSpecializeGenericOpsPass::runOnOperation() {
RewritePatternSet patterns(&getContext());
populateLinalgGenericOpsSpecializationPatterns(patterns);
populateUnfoldProjectedPermutationPatterns(patterns);

if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns))))
signalPassFailure();
243 changes: 243 additions & 0 deletions mlir/lib/Dialect/Linalg/Transforms/UnfoldProjectedPermutation.cpp
@@ -0,0 +1,243 @@
//===- UnfoldProjectedPermutation.cpp - unfold projected permutations ----===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <utility>

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include <map>
#include <optional>
#include <vector>
Contributor:
Is vector needed?

You should also move <utility> to this section. I think if you remove the blank line (11), clang-format would sort the includes in the right order for you :)

Author:
done
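For illustration only, this is roughly what the single sorted include block might look like once the blank line is removed and clang-format reorders it (assuming `<vector>` is indeed unused and can be dropped, as asked above):

```cpp
#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Linalg/Transforms/Transforms.h"
#include <map>
#include <optional>
#include <utility>
```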


using namespace mlir;
using namespace mlir::linalg;

namespace {

/// This file implements pattern to decompose the input operand(s) of a
Contributor:
Suggested change:
- /// This file implements pattern to decompose the input operand(s) of a
+ /// This pattern decomposes the input operand(s) of a

? Isn't this comment attached to the pattern itself?

Author:
I just followed the style of [1], which I really liked, as right at the beginning it explains the algorithm the file implements.
[1] https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Linalg/Transforms/DecomposeLinalgOps.cpp

/// linalg.generic that has a `transpose`, `broadcast` or a mixture of two,
Contributor:
Suggested change:
- /// linalg.generic that has a `transpose`, `broadcast` or a mixture of two,
+ /// linalg.generic that has a `transpose`, `broadcast`, or a mixture of the two,

/// into explicit transpose and broadcast. Having them folded into the
/// linalg.generic is a good optimization but sometimes we may want to unwrap
/// i.e. `unfold` them as explicit transpose and broadcast. This rewrite
Contributor:
Suggested change:
- /// linalg.generic is a good optimization but sometimes we may want to unwrap
- /// i.e. `unfold` them as explicit transpose and broadcast. This rewrite
+ /// linalg.generic is a good optimization but sometimes we may want to unwrap,
+ /// i.e., `unfold` them as explicit transpose and broadcast. This rewrite

/// pattern helps do it for each input operand. This is useful, for instance,
/// when trying to recognize named ops.
///
/// The transpose, broadcast, or mixture of both is expressed in the affine
/// map of the operand. Technically, it is a `projected permutation`.
///
/// Example
///
/// ```mlir
///
/// #projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
/// #identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>
/// ...
/// %res = linalg.generic
/// { indexing_maps = [#projection, #identity, #identity],
/// iterator_types = ["parallel", "parallel", "parallel",
/// "parallel", "parallel"]}
/// ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>)
/// outs(%z : tensor<5x9x7x8x10xf32>) {
/// ^bb0(%in: f32, %in_1: f32, %out: f32):
/// %div = arith.divf %in, %in_1 : f32
/// linalg.yield %div : f32
/// } -> tensor<5x9x7x8x10xf32>
/// ```
///
/// In the above IR, the map of operand `%x` is a projected permutation. This
/// can be unfolded as:
///
/// ```mlir
/// ...
/// %x_trans = linalg.transpose
/// ins(%x : tensor<7x8x9xf32>)
/// outs(%e1 : tensor<9x7x8xf32>) permutation = [2, 0, 1]
/// ...
/// %x_trans_bc = linalg.broadcast
/// ins(%x_trans : tensor<9x7x8xf32>)
/// outs(%e2 : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
/// %2 = linalg.div
/// ins(%x_trans_bc, %y :
/// tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>)
/// outs(%arg2 : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
///
/// Note that linalg.generic has been 'specialized' to linalg.div.
///
/// To unfold, it is more effective to transpose first and then do the
/// broadcast. However, if the transpose is done first, the permutation map
/// needs to be expressed in terms of the reduced dimensions (as the broadcast
/// hasn't happened yet). Also, the broadcast dimensions in a linalg.generic
/// come from other operands (those not broadcast along that particular
/// dimension). We work this out by computing the convex-polyhedron shape of
/// the linalg.generic iteration space from the shapes of all the operands
/// (inputs and outputs).
///
struct UnfoldProjectedPermutation : public OpRewritePattern<GenericOp> {
using OpRewritePattern<GenericOp>::OpRewritePattern;

LogicalResult matchAndRewrite(GenericOp genericOp,
PatternRewriter &rewriter) const override;
};

/// For the given `map` determine what dimensions are transposed
/// and what dimensions are broadcasted.
/// Returns:
///   `isTransposed, isBroadcast, transpose-permutation, broadcast-dimensions`
///
std::tuple<bool, bool, SmallVector<int64_t>, SmallVector<int64_t>>
computeTransposeBroadcast(AffineMap &map) {
assert(map.isProjectedPermutation(false) && "not a projection");
int64_t minorSize = map.getNumResults();

SmallVector<int64_t> minorResult;
for (int64_t i = 0; i < minorSize; ++i) {
auto expr = cast<AffineDimExpr>(map.getResults()[i]);
minorResult.push_back(expr.getPosition());
}

// If dims are not monotonically increasing then transpose is present.
SmallVector<int64_t> sortedResMap(minorResult);
std::sort(sortedResMap.begin(), sortedResMap.end());
bool hasTranspose = !std::equal(minorResult.begin(), minorResult.end(),
sortedResMap.begin(), sortedResMap.end());

// Walk the sorted map result to determine which dimensions are broadcasted.
SmallVector<int64_t> broadcast;
for (int64_t i = 0, j = 0; i < map.getNumInputs(); ++i) {
if (j < minorSize && sortedResMap[j] == i) {
j++;
continue;
}
broadcast.push_back(i);
}
bool hasBroadcast = !broadcast.empty();

/// Consider an operand `x : tensor<7x8x9>` of a genericOp that has
/// affine map `affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>`.
/// `x`'s access is both transposed and broadcast. But when specifying
/// the `linalg.transpose(x : tensor<7x8x9>)` the dimensions need to be
/// specified as `affine_map<(d0,d1,d2) -> (d1, d2, d0)` instead of
/// referring to d3, d4. Therefore, re-base the transpose dimensions so
/// that they start from d0.
std::map<int64_t, int64_t> minorMap;
for (int64_t i = 0; i < minorSize; ++i)
minorMap.insert({sortedResMap[i], i});

// Re-map the dimensions.
SmallVector<int64_t> remappedResult(minorSize);
for (int64_t i = 0; i < minorSize; ++i)
remappedResult[i] = minorMap[minorResult[i]];

/// Calculate the permutation for the transpose.
SmallVector<int64_t> permutation(minorSize);
for (unsigned i = 0; i < minorSize; ++i) {
permutation[remappedResult[i]] = i;
}

return {hasTranspose, hasBroadcast, permutation, broadcast};
}
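To make the re-basing step concrete, here is a small standalone sketch (plain C++, no MLIR dependencies; names are illustrative and not from the patch) that mirrors the logic of `computeTransposeBroadcast` for the documented example map `(d0, d1, d2, d3, d4) -> (d2, d3, d1)` and reproduces `permutation = [2, 0, 1]` and broadcast dims `[0, 4]`:

```cpp
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <map>
#include <vector>

int main() {
  // affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>: 5 loops, 3 results.
  const int64_t numLoops = 5;
  const std::vector<int64_t> results = {2, 3, 1};

  // A transpose is present iff the result dims are not monotonically increasing.
  std::vector<int64_t> sorted(results);
  std::sort(sorted.begin(), sorted.end());
  bool hasTranspose = results != sorted;

  // Loop dims that never appear in the results are the broadcast dims.
  std::vector<int64_t> broadcastDims;
  for (int64_t i = 0, j = 0; i < numLoops; ++i) {
    if (j < static_cast<int64_t>(sorted.size()) && sorted[j] == i) {
      ++j;
      continue;
    }
    broadcastDims.push_back(i);
  }

  // Re-base the result dims so they start from 0 (the pre-broadcast space),
  // then invert to obtain the linalg.transpose permutation.
  std::map<int64_t, int64_t> rebase;
  for (size_t i = 0; i < sorted.size(); ++i)
    rebase[sorted[i]] = i;
  std::vector<int64_t> permutation(results.size());
  for (size_t i = 0; i < results.size(); ++i)
    permutation[rebase[results[i]]] = i;

  std::cout << "hasTranspose = " << hasTranspose << "\npermutation = ";
  for (int64_t p : permutation)
    std::cout << p << ' '; // 2 0 1
  std::cout << "\nbroadcast dims = ";
  for (int64_t b : broadcastDims)
    std::cout << b << ' '; // 0 4
  std::cout << '\n';
}
```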

LogicalResult
UnfoldProjectedPermutation::matchAndRewrite(GenericOp op,
PatternRewriter &rewriter) const {
if (!op.hasPureTensorSemantics() || op.isSingleInputOutput() ||
op.isSingleYieldOp() || !op.isAllParallelLoops())
return failure();

// All maps need to be projected permutations.
Contributor:
Intuitively this makes sense, but ... why? 😅 Which part would break?

Contributor:
Ping

for (auto &opOperand : op->getOpOperands()) {
auto map = op.getMatchingIndexingMap(&opOperand);
if (!map.isProjectedPermutation(false))
return failure();
}

// Currently we handle only static shapes.
Contributor:
Could this work at all for dynamic shapes?

Author:
For a start this will assert when trying to create tensor.empty with dynamic shape. https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp#L874

Contributor:
OK, rather than documenting what the code does, could you add a comment saying "why"? Or what's missing? From what you are saying, we'd need to add logic to compute dynamic sizes of the input tensors for ops like EmptyOp? And probably something else as well?
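For what it's worth, a rough, untested sketch of what just the transpose leg might need for dynamic shapes (this assumes `tensor::getMixedSizes` and the `OpFoldResult`-based `tensor::EmptyOp` builder; it is not part of this patch, and the broadcast destination would additionally need the iteration-space sizes recovered from the other operands):

```cpp
// Untested sketch, not part of the patch: build the transpose destination from
// mixed static/dynamic sizes queried off the input operand.
SmallVector<OpFoldResult> inputSizes =
    tensor::getMixedSizes(rewriter, loc, newInitValues[i]);
SmallVector<OpFoldResult> transposedSizes(inputSizes.size());
for (auto en : llvm::enumerate(permutation))
  transposedSizes[en.index()] = inputSizes[en.value()];
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
    loc, transposedSizes, inputRTType.getElementType());
```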

for (auto &operand : op->getOpOperands()) {
auto opType = cast<RankedTensorType>(operand.get().getType());
for (auto size : opType.getShape())
if (size == ShapedType::kDynamic)
return failure();
}
Contributor:
Suggested change:
- for (auto &operand : op->getOpOperands()) {
-   auto opType = cast<RankedTensorType>(operand.get().getType());
-   for (auto size : opType.getShape())
-     if (size == ShapedType::kDynamic)
-       return failure();
- }
+ if (!llvm::all_of(packOp->getOpOperands(), [](OpOperand &oper) {
+   auto opType = cast<RankedTensorType>(oper.get().getType());
+   return !ShapedType::isDynamicShape(opType.getShape());
+ }) return failure();

This way it's easier to emphasize the key logic rather than all the loops and if statements. Please double check syntax 😅

Contributor:
Ping

Author:
if (!llvm::all_of(packOp->getOpOperands(), [](OpOperand &oper) {
  auto opType = cast<RankedTensorType>(oper.get().getType());
  return !ShapedType::isDynamicShape(opType.getShape());
}) return failure();

Thanks for this more elegant code. I guess you mean 'any_of'.
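For reference, a syntactically complete form of that suggestion as a hedged sketch (using `any_of` as discussed; `op` is the matched GenericOp here, not `packOp`):

```cpp
// Bail out if any operand has a dynamic dimension; only static shapes are
// handled for now.
if (llvm::any_of(op->getOpOperands(), [](OpOperand &operand) {
      auto operandType = cast<RankedTensorType>(operand.get().getType());
      return ShapedType::isDynamicShape(operandType.getShape());
    }))
  return failure();
```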


auto outputShape = op.getStaticLoopRanges();

auto loc = op.getLoc();
bool isChanged = false;
SmallVector<Value> newInitValues = op.getDpsInputs();
SmallVector<AffineMap> newMap = op.getIndexingMapsArray();

// Walk over each input operand and unfold it if it is transposed, broadcast,
// or a mix of the two, as expressed by the operand's affine map.
for (int64_t i = 0; i < op.getNumDpsInputs(); ++i) {
auto &map = newMap[i];
auto inputRTType = cast<RankedTensorType>(newInitValues[i].getType());
auto elType = inputRTType.getElementType();

/// Nothing to do if map is already an identity.
if (map.isIdentity())
continue;

auto [hasTranspose, hasBroadcast, permutation, broadcastedDims] =
computeTransposeBroadcast(map);

if (hasTranspose) {
/// linalg.transpose permutes the dimensions of input using
/// rule: dim(result, i) = dim(input, permutation[i])
SmallVector<int64_t> transposedShape(map.getNumResults());
for (int64_t i = 0; i < map.getNumResults(); ++i)
transposedShape[i] = inputRTType.getShape()[permutation[i]];

Value emptyTensor =
rewriter.create<tensor::EmptyOp>(loc, transposedShape, elType);

auto transposeOp = rewriter.create<TransposeOp>(loc, newInitValues[i],
emptyTensor, permutation);
newInitValues[i] = transposeOp->getResult(0);
isChanged = true;
}

// Does it require a broadcast?
if (hasBroadcast) {
assert(!broadcastedDims.empty() && "should have non-empty broadcast dimensions");
Value emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, outputShape, inputRTType.getElementType());

auto broadcastOp = rewriter.create<linalg::BroadcastOp>(
loc, newInitValues[i], emptyTensor, broadcastedDims);

newInitValues[i] = broadcastOp->getResult(0);
isChanged = true;
}
newMap[i] = rewriter.getMultiDimIdentityMap(map.getNumDims());
}

if (isChanged) {
SmallVector<Value> operands = op->getOperands();
ValueRange operandsRef(operands);

auto newOp = rewriter.create<linalg::GenericOp>(
/*location=*/op.getLoc(),
/*resultTensorTypes=*/op->getResultTypes(),
/*inputs=*/newInitValues,
/*outputs=*/operandsRef.drop_front(op.getNumDpsInputs()),
/*indexingMaps=*/newMap,
/*iteratorTypes=*/op.getIteratorTypesArray());

newOp.getRegion().takeBody(op->getRegion(0));
rewriter.replaceOp(op, newOp->getResults());
}
return success();
}

} // namespace

void mlir::linalg::populateUnfoldProjectedPermutationPatterns(
RewritePatternSet &patterns) {
patterns.insert<UnfoldProjectedPermutation>(patterns.getContext());
}
71 changes: 71 additions & 0 deletions mlir/test/Dialect/Linalg/unfold-projected-permutation.mlir
@@ -0,0 +1,71 @@
// RUN: mlir-opt %s -split-input-file --linalg-specialize-generic-ops | FileCheck %s

#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>
#identity = affine_map<(d0, d1, d2, d3, d4) -> (d0, d1, d2, d3, d4)>

func.func @transpose_and_broadcast(%x : tensor<7x8x9xf32>, %y: tensor<5x9x7x8x10xf32>, %z : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
%res = linalg.generic
{ indexing_maps = [#projection, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]}
ins(%x, %y : tensor<7x8x9xf32>, tensor<5x9x7x8x10xf32>) outs(%z : tensor<5x9x7x8x10xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%div = arith.divf %in, %in_1 : f32
linalg.yield %div : f32
} -> tensor<5x9x7x8x10xf32>
return %res : tensor<5x9x7x8x10xf32>
}

// CHECK-LABEL: transpose_and_broadcast
// CHECK-SAME: %[[X:.+]]: tensor<7x8x9xf32>, %[[Y:.+]]: tensor<5x9x7x8x10xf32>, %[[Z:.+]]: tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<9x7x8xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<7x8x9xf32>) outs(%[[E0]] : tensor<9x7x8xf32>) permutation = [2, 0, 1]
Contributor:
I would have expected this to be [1, 2, 0] but I am also assuming transpose keeps the input as the basis and describes how to permute the inputs.

Contributor:
Hm, (7, 8, 9) -> (9, 7, 8), which corresponds to (d2, d3, d1) -> (d1, d2, d3), which gives permutation = [2, 0, 1].

Now, I managed to convince myself that this is correct, but please double check for yourself 😅

@MaheshRavishankar , you might be skewed by:

#projection = affine_map<(d0, d1, d2, d3, d4) -> (d2, d3, d1)>

I think this is the trick (IIUC, this is the actual mapping here):

  • 7 -> d2,
  • 8 -> d3,
  • 9 -> d1.

Whereas you assume that:

  • 7 -> d1,
  • 8 -> d2,
  • 9 -> d3.

Does it make sense?

Author (@javedabsar1, Nov 7, 2024):
Yes of course you are right! Two parts to this.

  1. Transpose semantics is:
    dim(result, i) == dim(input, permutation[i])

Therefore, for input 7x8x9 and permutation [2,0,1] the output works out to 9x7x8.

  2. Working out the permutation from the affine map, e.g. (d2, d3, d1) -> (d1, d2, d3), and vice-versa.
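As a quick sanity check of that semantics (plain C++, illustrative only), applying dim(result, i) = dim(input, permutation[i]) to the example indeed yields 9x7x8:

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  const std::vector<int64_t> inputShape = {7, 8, 9};
  const std::vector<int64_t> permutation = {2, 0, 1};
  std::vector<int64_t> resultShape(inputShape.size());
  // linalg.transpose semantics: dim(result, i) = dim(input, permutation[i]).
  for (size_t i = 0; i < permutation.size(); ++i)
    resultShape[i] = inputShape[permutation[i]];
  for (int64_t d : resultShape)
    std::cout << d << ' '; // 9 7 8
  std::cout << '\n';
}
```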

// CHECK: %[[E1:.+]] = tensor.empty() : tensor<5x9x7x8x10xf32>
// CHECK: %[[X_trans_bc:.+]] = linalg.broadcast ins(%[[X_trans]] : tensor<9x7x8xf32>) outs(%[[E1]] : tensor<5x9x7x8x10xf32>) dimensions = [0, 4]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans_bc]], %[[Y]] : tensor<5x9x7x8x10xf32>, tensor<5x9x7x8x10xf32>) outs(%[[Z]] : tensor<5x9x7x8x10xf32>) -> tensor<5x9x7x8x10xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#transposed = affine_map<(d0, d1, d2) -> (d2, d0, d1)>

func.func @transpose_only(%x : tensor<32x2x16xf32>, %y: tensor<2x16x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
%res = linalg.generic
{ indexing_maps = [#transposed, #identity, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%x, %y : tensor<32x2x16xf32>, tensor<2x16x32xf32>)
outs(%z : tensor<2x16x32xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%div = arith.divf %in, %in_1 : f32
linalg.yield %div : f32
} -> tensor<2x16x32xf32>
return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: transpose_only
// CHECK-SAME: %[[X:.+]]: tensor<32x2x16xf32>, %[[Y:.+]]: tensor<2x16x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[X_trans:.+]] = linalg.transpose ins(%[[X]] : tensor<32x2x16xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) permutation = [1, 2, 0]
// CHECK: {{.*}} = linalg.div ins(%[[X_trans]], %[[Y]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%[[Z]] : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic

// -----

#identity = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#broadcast = affine_map<(d0, d1, d2) -> (d0, d2)>
func.func @broadcast_only(%x : tensor<2x16x32xf32>, %y: tensor<2x32xf32>, %z : tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
%res = linalg.generic
{ indexing_maps = [#identity, #broadcast, #identity], iterator_types = ["parallel", "parallel", "parallel"]}
ins(%x, %y : tensor<2x16x32xf32>, tensor<2x32xf32>)
outs(%z : tensor<2x16x32xf32>) {
^bb0(%in: f32, %in_1: f32, %out: f32):
%div = arith.divf %in, %in_1 : f32
linalg.yield %div : f32
} -> tensor<2x16x32xf32>
return %res : tensor<2x16x32xf32>
}

// CHECK-LABEL: broadcast_only
// CHECK-SAME: %[[X:.+]]: tensor<2x16x32xf32>, %[[Y:.+]]: tensor<2x32xf32>, %[[Z:.+]]: tensor<2x16x32xf32>) -> tensor<2x16x32xf32> {
// CHECK: %[[E0:.+]] = tensor.empty() : tensor<2x16x32xf32>
// CHECK: %[[X_bc:.+]] = linalg.broadcast ins(%[[Y]] : tensor<2x32xf32>) outs(%[[E0]] : tensor<2x16x32xf32>) dimensions = [1]
// CHECK: {{.*}} = linalg.div ins(%[[X]], %[[X_bc]] : tensor<2x16x32xf32>, tensor<2x16x32xf32>) outs(%arg2 : tensor<2x16x32xf32>) -> tensor<2x16x32xf32>
// CHECK-NOT: linalg.generic