Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 41 additions & 8 deletions mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,13 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/DestinationStyleOpInterface.h"
#include "mlir/Interfaces/TilingInterface.h"
#include "mlir/Rewrite/FrozenRewritePatternSet.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
Expand Down Expand Up @@ -186,18 +188,49 @@ static void checkSafeToTileToForall(TilingInterface op,
}
}

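/// Collect the dividers of `ofr` into `dividers`: `ofr` itself and, when it is
/// the result of an `arith.muli`, the dividers of both multiplication operands.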
static void collectDividers(OpFoldResult ofr,
SmallVector<OpFoldResult> &dividers) {
dividers.push_back(ofr);
if (ofr.is<Attribute>())
return;
auto mulOp = cast<Value>(ofr).getDefiningOp<arith::MulIOp>();
if (!mulOp)
return;

// Given `ofr` = `x` * `y`, all dividers of `x` and `y` are dividers of `ofr`.
collectDividers(mulOp.getLhs(), dividers);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Recursion is generally discouraged.

collectDividers(mulOp.getRhs(), dividers);
}

/// Check if `stride` evenly divides the trip count `size - offset`.
static bool tileDividesIterationDomain(Range loopRange) {
std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
std::optional<int64_t> offsetAsInt = getConstantIntValue(loopRange.offset);
if (!offsetAsInt)
return false;
std::optional<int64_t> sizeAsInt = getConstantIntValue(loopRange.size);
if (!sizeAsInt)
return false;
std::optional<int64_t> strideAsInt = getConstantIntValue(loopRange.stride);
if (!strideAsInt)
return false;
return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() == 0);
if (strideAsInt && offsetAsInt && sizeAsInt)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Please add { } around multi-line statements.

// `stride`, `size` and `offset` are all static: check that
// (size - offset) % stride == 0.
return ((sizeAsInt.value() - offsetAsInt.value()) % strideAsInt.value() ==
0);

// At least one of `stride`/`size`/`offset` is dynamic.
SmallVector<OpFoldResult> dividersOfSize, dividersOfOffset;
collectDividers(loopRange.size, dividersOfSize);
collectDividers(loopRange.offset, dividersOfOffset);

// Returns true if `stride` divides the given `divider`. The result below is
// true when this holds for at least one divider of `size` and at least one
// divider of `offset`.
auto isStrideDividesDivider = [&](OpFoldResult divider) {
if (!strideAsInt)
// `stride` is dynamic.
return divider == loopRange.stride;

std::optional<int64_t> dividerAsInt = getConstantIntValue(divider);
return dividerAsInt && *dividerAsInt % *strideAsInt == 0;
};
return llvm::any_of(dividersOfSize, isStrideDividesDivider) &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't you think this logic is more complicated than you want it to be? I'd like to scope this a bit more narrowly. Just look for the immediate operation being an arith.mul and avoid the recursion.

The more general solution should be to use something like the IntegerDivisibilityAnalysis (here), which is very similar to the range inference analysis and can be used to fold this away. Doing this in transformations like this becomes unmaintainable in the long run.

llvm::any_of(dividersOfOffset, isStrideDividesDivider);
}

/// Returns the bounded tile size given the current `offset`, `loopRange` and
Expand Down
47 changes: 47 additions & 0 deletions mlir/test/Dialect/Linalg/transform-op-tile.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,50 @@ func.func @tile_linalg_matmul(
-> tensor<128x128xf32>
return %0 : tensor<128x128xf32>
}

// -----

#map = affine_map<(d0) -> (d0)>

// CHECK-LABEL: splited_dynamic_linalg_generic
// Test input: an element-wise i16 addition over dynamically sized 1-D tensors,
// already split into two linalg.generic ops. The first one works on slices of
// size %2 = (%dim / 80) * 80, the second one on slices of size
// %3 = %dim mod 80 that start at offset %2.
// The FileCheck directives in the body expect one scf.for in the output,
// followed by neither an affine.min nor another scf.for.
func.func @splited_dynamic_linalg_generic(%arg0: tensor<?xi16>, %arg1: tensor<?xi16>) -> tensor<?xi16> {
%c80 = arith.constant 80 : index
%c0 = arith.constant 0 : index
// The dynamic extent is read from %arg1 and also sizes the init tensor %0.
%dim = tensor.dim %arg1, %c0 : tensor<?xi16>
%0 = tensor.empty(%dim) : tensor<?xi16>
// %2 is a multiple of 80 by construction (unsigned divide, then multiply by
// the same constant); %3 is the unsigned remainder of %dim by 80.
%1 = arith.divui %dim, %c80 : index
%2 = arith.muli %1, %c80 : index
%3 = arith.remui %dim, %c80 : index
// Main part: slices of size %2 taken at offset 0 from both inputs and %0.
%extracted_slice = tensor.extract_slice %arg0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
%extracted_slice_0 = tensor.extract_slice %arg1[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
%extracted_slice_1 = tensor.extract_slice %0[0] [%2] [1] : tensor<?xi16> to tensor<?xi16>
// CHECK: scf.for
// CHECK-NOT: affine.min
%4 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice, %extracted_slice_0 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_1 : tensor<?xi16>) {
^bb0(%in_1: i16, %in_2: i16, %out: i16):
%6 = arith.addi %in_1, %in_2 : i16
linalg.yield %6 : i16
} -> tensor<?xi16>
// NOTE(review): the result is inserted at offset %2, while the slices above
// were extracted at offset 0 - confirm this is intended for the test.
%inserted_slice = tensor.insert_slice %4 into %0[%2] [%2] [1] : tensor<?xi16> into tensor<?xi16>
// Remainder part: slices of size %3 taken at offset %2.
%extracted_slice_2 = tensor.extract_slice %arg0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
%extracted_slice_3 = tensor.extract_slice %arg1[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
%extracted_slice_4 = tensor.extract_slice %0[%2] [%3] [1] : tensor<?xi16> to tensor<?xi16>
// CHECK-NOT: scf.for
%5 = linalg.generic {indexing_maps = [#map, #map, #map], iterator_types = ["parallel"]} ins(%extracted_slice_2, %extracted_slice_3 : tensor<?xi16>, tensor<?xi16>) outs(%extracted_slice_4 : tensor<?xi16>) {
^bb0(%in_1: i16, %in_2: i16, %out: i16):
%7 = arith.addi %in_1, %in_2 : i16
linalg.yield %7 : i16
} -> tensor<?xi16>
%inserted_slice_0 = tensor.insert_slice %5 into %inserted_slice[%2] [%3] [1] : tensor<?xi16> into tensor<?xi16>
return %inserted_slice_0 : tensor<?xi16>
}


// Transform script for the test above: it matches the linalg.generic ops and
// the arith.constant ops in the payload, then tiles the matched generics with
// scf.for loops, passing the matched constant handle as the (dynamic) tile
// size.
module attributes {transform.with_named_sequence} {
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
// Handle to every linalg.generic found under %arg1.
%0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// Handle to the arith.constant ops found under %arg1.
// NOTE(review): the payload function defines two constants (80 and 0);
// confirm which of them is meant to act as the tile size.
%const = transform.structured.match ops{["arith.constant"]} in %arg1 : (!transform.any_op) -> !transform.any_op
// Tile with a single tile size taken from the %const handle; the results are
// the tiled op handle and the generated loop handle.
%1, %loop = transform.structured.tile_using_for %0 tile_sizes [%const] : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
transform.yield
}
}
Loading