1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -88,6 +88,7 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,

let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
let hasVerifier = 1;
let hasFolder = 1;
}

def Linalg_SoftmaxOp : Linalg_Op<"softmax",
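Setting `hasFolder = 1` asks ODS to declare a fold hook on the generated op class; because `linalg.index` has a single result, the hook takes the single-result form. A sketch of the declaration ODS adds to the generated `IndexOp` class:

    // Sketch; the generated class also carries the FoldAdaptor typedef.
    ::mlir::OpFoldResult fold(FoldAdaptor adaptor);

The change below supplies the body.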
29 changes: 29 additions & 0 deletions mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -2283,6 +2283,35 @@ LogicalResult IndexOp::verify() {
return success();
}

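/// Fold the result of `linalg.index` to the constant 0 when the loop
/// dimension it queries is tied to an operand dimension of static size 1,
/// since the only valid index is then 0. `getShapesToLoopsMap` inverts the
/// concatenation of all indexing maps, so its result at position `getDim()`
/// is the flattened operand-shape dimension this loop dimension is
/// recovered from.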
OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
auto linalgOp = cast<LinalgOp>((*this)->getParentOp());
int64_t flatDimPos =
cast<AffineDimExpr>(linalgOp.getShapesToLoopsMap().getResult(getDim()))
.getPosition();

// Find the flat dimension position among the operands.
int64_t flatPosOffset = 0;
for (Value operand : linalgOp->getOperands()) {
assert(flatDimPos >= flatPosOffset && "invalid position");
auto shapedType = dyn_cast<ShapedType>(operand.getType());
if (!shapedType)
break;

int64_t rank = shapedType.getRank();
if (flatDimPos < flatPosOffset + rank) {
// Found the dimension within this shape. Now we can either fold if the
// dim size is 1, or bail out otherwise.
int64_t pos = flatDimPos - flatPosOffset;
if (shapedType.getDimSize(pos) != 1)
break;

return IntegerAttr::get(IndexType::get(getContext()), 0);
}
flatPosOffset += rank;
}
return OpFoldResult{};
}

/////// Operations corresponding to library calls defined with Tablegen ////////

#include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
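The only subtle step in the folder is translating the flattened shape dimension returned by `getShapesToLoopsMap` back into a concrete (operand, dimension) pair. A minimal standalone sketch of that bookkeeping, assuming plain C++ with no MLIR dependencies (`locateFlatDim` is a hypothetical name for illustration, not part of the patch):

    #include <cassert>
    #include <cstdint>
    #include <utility>
    #include <vector>

    // Walk the operand ranks, accumulating an offset into the concatenation
    // of all operand shapes, until the flat position lands inside one operand.
    std::pair<int64_t, int64_t> locateFlatDim(const std::vector<int64_t> &ranks,
                                              int64_t flatDimPos) {
      int64_t offset = 0;
      for (int64_t i = 0, e = ranks.size(); i < e; ++i) {
        if (flatDimPos < offset + ranks[i])
          return {i, flatDimPos - offset};
        offset += ranks[i];
      }
      assert(false && "flat position out of range");
      return {-1, -1};
    }

    int main() {
      // Shapes from the first test below: 4x16, 1x16, 4x1 (all rank 2), so
      // the flat positions are 0..5. Loop d1 first appears at flat position
      // 2, which resolves to operand 1, dim 0: static size 1, so the
      // corresponding linalg.index folds to 0.
      auto [operandIdx, dimIdx] = locateFlatDim({2, 2, 2}, 2);
      assert(operandIdx == 1 && dimIdx == 0);
      return 0;
    }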
80 changes: 80 additions & 0 deletions mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -305,6 +305,86 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
}

// -----

// CHECK: func @fold_linalg_index_tensor_static
func.func @fold_linalg_index_tensor_static(%0: tensor<4x16xi32>, %1: tensor<1x16xi32>,
%2: tensor<4x1xi32>) -> tensor<4x1xi32> {
// CHECK-NEXT: linalg.generic
// CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
// CHECK-NOT: linalg.index 1
// CHECK: %[[IDX_2:.+]] = linalg.index 2 : index
// CHECK: %[[ADD:.+]] = arith.addi %[[IDX_0]], %[[IDX_2]]
// CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]]
// CHECK: linalg.yield %[[CAST]]
%res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
affine_map<(d0, d1, d2) -> (d1, d2)>,
affine_map<(d0, d1, d2) -> (d0, d1)>],
iterator_types = ["parallel", "parallel", "reduction"]}
ins(%0, %1 : tensor<4x16xi32>, tensor<1x16xi32>)
outs(%2 : tensor<4x1xi32>) {
^bb0(%lhs: i32, %rhs: i32, %out: i32):
%idx0 = linalg.index 0 : index
%idx1 = linalg.index 1 : index
%idx2 = linalg.index 2 : index
%add0 = arith.addi %idx0, %idx1 : index
%add1 = arith.addi %add0, %idx2 : index
%int = arith.index_cast %add1 : index to i32
linalg.yield %int : i32
} -> tensor<4x1xi32>
return %res : tensor<4x1xi32>
}
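Note that both adds collapse into one: the folder rewrites %idx1 to the constant 0, and the existing arith.addi folder then elides the add of zero, so only %idx0 + %idx2 survives.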

// -----

// CHECK: func @fold_linalg_index_tensor_dynamic
func.func @fold_linalg_index_tensor_dynamic(%0: tensor<?x1xi32>,
%1: tensor<?x1xi32>) -> tensor<?x1xi32> {
// CHECK-NEXT: linalg.generic
// CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
// CHECK-NOT: linalg.index 1
// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_0]]
// CHECK: linalg.yield %[[CAST]]
%res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d1, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%0 : tensor<?x1xi32>)
outs(%1 : tensor<?x1xi32>) {
^bb0(%lhs: i32, %out: i32):
%idx0 = linalg.index 0 : index
%idx1 = linalg.index 1 : index
%add = arith.addi %idx0, %idx1 : index
%int = arith.index_cast %add : index to i32
linalg.yield %int : i32
} -> tensor<?x1xi32>
return %res : tensor<?x1xi32>
}

// -----

// CHECK: func @fold_linalg_index_memref
func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) {
// CHECK-NEXT: linalg.generic
// CHECK-NOT: linalg.index 0
// CHECK: %[[IDX_1:.+]] = linalg.index 1 : index
// CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_1]]
// CHECK: linalg.yield %[[CAST]]
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
affine_map<(d0, d1) -> (d1, d1)>],
iterator_types = ["parallel", "parallel"]}
ins(%0 : memref<1x?xi32>)
outs(%1 : memref<1x?xi32>) {
^bb0(%lhs: i32, %out: i32):
%idx0 = linalg.index 0 : index
%idx1 = linalg.index 1 : index
%add = arith.addi %idx0, %idx1 : index
%int = arith.index_cast %add : index to i32
linalg.yield %int : i32
}
return
}

// -----

// CHECK-LABEL: func @fold_fill_reshape()
func.func @fold_fill_reshape() -> tensor<6x4xf32> {
%zero = arith.constant 0.0 : f32