Skip to content

Commit 3288c3f

Browse files
authored
LinalgElementwiseFusion: fix linalg.index computation while fusing (#410)
Fix num of loops calculation
1 parent 18197f9 commit 3288c3f

File tree

2 files changed

+30
-2
lines changed

2 files changed

+30
-2
lines changed

mlir/lib/Dialect/Linalg/Transforms/ElementwiseOpFusion.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -223,8 +223,7 @@ static void generateFusedElementwiseOpRegion(
223223
// `consumerToProducerLoopsMap` to map the producer indices.
224224
if (producer.hasIndexSemantics()) {
225225
// Add an index operation for every fused loop dimension.
226-
unsigned numFusedOpLoops =
227-
std::max(producer.getNumLoops(), consumer.getNumLoops());
226+
unsigned numFusedOpLoops = consumerToProducerLoopsMap.getNumDims();
228227
SmallVector<Value> fusedIndices;
229228
fusedIndices.reserve(numFusedOpLoops);
230229
llvm::transform(llvm::seq<uint64_t>(0, numFusedOpLoops),

mlir/test/Dialect/Linalg/fusion-elementwise-ops.mlir

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -977,3 +977,32 @@ module {
977977
// CHECK-DAG: %[[T3:.+]] = arith.addf %[[T2]], %[[B1]]
978978
// CHECK: linalg.yield %[[T3]] : f32
979979
// CHECK: return %[[GENERIC]]
980+
981+
// -----
982+
983+
#map = affine_map<(d0, d1, d2) -> (d1, d2)>
984+
#map1 = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
985+
#map2 = affine_map<(d0, d1) -> (d0, d1 floordiv 4, d1 mod 4)>
986+
#map3 = affine_map<(d0, d1) -> (d0, d1)>
987+
988+
// CHECK-DAG: [[$MAP0:#[a-zA-Z0-9_]*]] = affine_map<(d0, d1) -> (d0, d1)>
989+
// CHECK-DAG: [[$MAP1:#[a-zA-Z0-9_]*]] = affine_map<(d0) -> (d0 floordiv 4)>
990+
991+
func.func @fuse_and_collapse(%arg0: tensor<3x4xindex>) -> tensor<2x12xindex> {
992+
%1 = tensor.empty() : tensor<2x3x4xindex>
993+
// CHECK: linalg.generic {
994+
// CHECK: %[[INDEX1:[a-zA-Z0-9_]+]] = linalg.index 1 : index
995+
// CHECK-NEXT: %[[MAP:[a-zA-Z0-9_]+]] = affine.apply #map1(%[[INDEX1]])
996+
// CHECK-NEXT: linalg.yield %[[MAP]] : index
997+
%2 = linalg.generic {indexing_maps = [#map, #map1], iterator_types = ["parallel", "parallel", "parallel"]} ins(%arg0: tensor<3x4xindex>) outs(%1 : tensor<2x3x4xindex>) {
998+
^bb0(%in: index, %out: index):
999+
%3 = linalg.index 1 : index
1000+
linalg.yield %3: index
1001+
} -> tensor<2x3x4xindex>
1002+
%7 = tensor.empty() : tensor<2x12xindex>
1003+
%8 = linalg.generic {indexing_maps = [#map2, #map3], iterator_types = ["parallel", "parallel"]} ins(%2 : tensor<2x3x4xindex>) outs(%7 : tensor<2x12xindex>) {
1004+
^bb0(%in: index, %out: index):
1005+
linalg.yield %in : index
1006+
} -> tensor<2x12xindex>
1007+
return %8 : tensor<2x12xindex>
1008+
}

0 commit comments

Comments
 (0)