
Commit 492439a (parent: 2cdf474)

Browse files
committed
[mlir][linalg] Add folder for linalg.index
We know that the index of unit dims is always 0.
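
To illustrate the fold: when a loop dimension of a linalg op only ever indexes into unit-sized operand dimensions, `linalg.index` for that dimension can only produce 0. A hypothetical minimal example, not taken from the commit (`%init` is a placeholder output; the constant is what a canonicalization driver would materialize from the folded attribute):

  // Before: d1 only maps to the unit dimension of tensor<4x1xi32>.
  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>],
                         iterator_types = ["parallel", "parallel"]}
      outs(%init : tensor<4x1xi32>) {
    ^bb0(%out: i32):
      %idx1 = linalg.index 1 : index
      %int = arith.index_cast %idx1 : index to i32
      linalg.yield %int : i32
  } -> tensor<4x1xi32>

  // After folding, uses of %idx1 are replaced by a constant zero:
  //   %c0 = arith.constant 0 : index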

File tree: 3 files changed (+110, -0 lines)
mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td (1 addition, 0 deletions)

@@ -88,6 +88,7 @@ def Linalg_IndexOp : Linalg_Op<"index", [Pure]>,
 
   let assemblyFormat = [{ $dim attr-dict `:` type($result) }];
   let hasVerifier = 1;
+  let hasFolder = 1;
 }
 
 def Linalg_SoftmaxOp : Linalg_Op<"softmax",

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp (29 additions, 0 deletions)

@@ -2283,6 +2283,35 @@ LogicalResult IndexOp::verify() {
   return success();
 }
 
+OpFoldResult IndexOp::fold(FoldAdaptor adaptor) {
+  auto linalgOp = cast<LinalgOp>((*this)->getParentOp());
+  int64_t flatDimPos =
+      cast<AffineDimExpr>(linalgOp.getShapesToLoopsMap().getResult(getDim()))
+          .getPosition();
+
+  // Find the flat dimension position among the operands.
+  int64_t flatPosOffset = 0;
+  for (Value operand : linalgOp->getOperands()) {
+    assert(flatDimPos >= flatPosOffset && "invalid position");
+    auto shapedType = dyn_cast<ShapedType>(operand.getType());
+    if (!shapedType)
+      break;
+
+    int64_t rank = shapedType.getRank();
+    if (flatDimPos < flatPosOffset + rank) {
+      // Found the dimension within this shape. Now we can either fold if the
+      // dim size is 1, or bail out otherwise.
+      int64_t pos = flatDimPos - flatPosOffset;
+      if (shapedType.getDimSize(pos) != 1)
+        break;
+
+      return IntegerAttr::get(IndexType::get(getContext()), 0);
+    }
+    flatPosOffset += rank;
+  }
+  return OpFoldResult{};
+}
+
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
 #include "mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yamlgen.cpp.inc"
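
A reading aid for the position arithmetic in `fold`, not part of the commit: `getShapesToLoopsMap()` inverts the concatenation of all operand indexing maps, so each loop dimension corresponds to one position in the flattened list of operand dimensions; the loop above walks the operands to find that position and folds only if its static size is 1. Worked through on `@fold_linalg_index_tensor_static` from the test file below:

  // Operands and maps of @fold_linalg_index_tensor_static:
  //   %0 : tensor<4x16xi32>, map (d0, d1, d2) -> (d0, d2)
  //   %1 : tensor<1x16xi32>, map (d0, d1, d2) -> (d1, d2)
  //   %2 : tensor<4x1xi32>,  map (d0, d1, d2) -> (d0, d1)
  // Flattened operand dimensions:
  //   flat pos:  0   1   2   3   4   5
  //   dim size:  4  16   1  16   4   1
  //   loop dim: d0  d2  d1  d2  d0  d1
  // Whichever of d1's flat positions the inversion picks (2 or 5), its static
  // size is 1, so `linalg.index 1` folds to 0. The positions for d0 and d2
  // have sizes 4 and 16, so those linalg.index ops are left alone.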

mlir/test/Dialect/Linalg/canonicalize.mlir (80 additions, 0 deletions)

@@ -305,6 +305,86 @@ func.func @self_copy(%arg0 : memref<2x3x?x4xf32>) {
 }
 
 // -----
+
+// CHECK: func @fold_linalg_index_tensor_static
+func.func @fold_linalg_index_tensor_static(%0: tensor<4x16xi32>, %1: tensor<1x16xi32>,
+                                           %2: tensor<4x1xi32>) -> tensor<4x1xi32> {
+  // CHECK-NEXT: linalg.generic
+  // CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
+  // CHECK-NOT: linalg.index 1
+  // CHECK: %[[IDX_2:.+]] = linalg.index 2 : index
+  // CHECK: %[[ADD:.+]] = arith.addi %[[IDX_0]], %[[IDX_2]]
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[ADD]]
+  // CHECK: linalg.yield %[[CAST]]
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d1, d2)>,
+                                          affine_map<(d0, d1, d2) -> (d0, d1)>],
+                         iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%0, %1 : tensor<4x16xi32>, tensor<1x16xi32>)
+      outs(%2 : tensor<4x1xi32>) {
+    ^bb0(%lhs: i32, %rhs: i32, %out: i32):
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %idx2 = linalg.index 2 : index
+      %add0 = arith.addi %idx0, %idx1 : index
+      %add1 = arith.addi %add0, %idx2 : index
+      %int = arith.index_cast %add1 : index to i32
+      linalg.yield %int : i32
+  } -> tensor<4x1xi32>
+  return %res : tensor<4x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_tensor_dynamic
+func.func @fold_linalg_index_tensor_dynamic(%0: tensor<?x1xi32>,
+                                            %1: tensor<?x1xi32>) -> tensor<?x1xi32> {
+  // CHECK-NEXT: linalg.generic
+  // CHECK: %[[IDX_0:.+]] = linalg.index 0 : index
+  // CHECK-NOT: linalg.index 1
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_0]]
+  // CHECK: linalg.yield %[[CAST]]
+  %res = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                          affine_map<(d0, d1) -> (d1, d1)>],
+                         iterator_types = ["parallel", "parallel"]}
+      ins(%0 : tensor<?x1xi32>)
+      outs(%1 : tensor<?x1xi32>) {
+    ^bb0(%lhs: i32, %out: i32):
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %add = arith.addi %idx0, %idx1 : index
+      %int = arith.index_cast %add : index to i32
+      linalg.yield %int : i32
+  } -> tensor<?x1xi32>
+  return %res : tensor<?x1xi32>
+}
+
+// -----
+
+// CHECK: func @fold_linalg_index_memref
+func.func @fold_linalg_index_memref(%0: memref<1x?xi32>, %1: memref<1x?xi32>) {
+  // CHECK-NEXT: linalg.generic
+  // CHECK-NOT: linalg.index 0
+  // CHECK: %[[IDX_1:.+]] = linalg.index 1 : index
+  // CHECK: %[[CAST:.+]] = arith.index_cast %[[IDX_1]]
+  // CHECK: linalg.yield %[[CAST]]
+  linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>,
+                                   affine_map<(d0, d1) -> (d1, d1)>],
+                  iterator_types = ["parallel", "parallel"]}
+      ins(%0 : memref<1x?xi32>)
+      outs(%1 : memref<1x?xi32>) {
+    ^bb0(%lhs: i32, %out: i32):
+      %idx0 = linalg.index 0 : index
+      %idx1 = linalg.index 1 : index
+      %add = arith.addi %idx0, %idx1 : index
+      %int = arith.index_cast %add : index to i32
+      linalg.yield %int : i32
+  }
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @fold_fill_reshape()
 func.func @fold_fill_reshape() -> tensor<6x4xf32> {
   %zero = arith.constant 0.0 : f32
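
For reference, the CHECK lines of the first test describe a canonicalized body roughly like this sketch (SSA names here are illustrative, matching the FileCheck captures rather than literal output):

  ^bb0(%lhs: i32, %rhs: i32, %out: i32):
    %idx0 = linalg.index 0 : index
    %idx2 = linalg.index 2 : index
    // linalg.index 1 folded to 0, and the add-with-zero folded away.
    %add = arith.addi %idx0, %idx2 : index
    %cast = arith.index_cast %add : index to i32
    linalg.yield %cast : i32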