From a370cd2d44b2715470c49dfb8b013d12dcff9826 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme
Date: Wed, 12 Mar 2025 13:22:14 +0000
Subject: [PATCH 1/8] [MLIR][Linalg] Fix insert_slice fusion with rank
 reduction

Insert_slice fusion with a linalg producer does not account for possible
rank-reduction in the insert_slice return type. When that happens, a
tensor.cast gets generated due to the type mismatch, which is invalid for
tensors of different ranks. This later trips other passes.
---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 36 ++++++++++-
 .../Dialect/Linalg/tile-and-fuse-tensors.mlir | 63 +++++++++++++++++++
 2 files changed, 97 insertions(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 223d728b0b27d..81b204df5a0aa 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -235,6 +237,31 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
 }
 
+/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
+/// `from`.
+tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
+                                   const llvm::SmallBitVector &dropDims) {
+  auto fromType = cast<RankedTensorType>(from.getType());
+  assert(fromType.getRank() == dropDims.size());
+  SmallVector<ReassociationIndices> reassocIdxsVec;
+  ReassociationIndices reassocIdxs;
+
+  bool foundKeptDim = false;
+  for (int dim = 0; dim < fromType.getRank(); dim++) {
+    if (!dropDims.test(dim)) {
+      if (foundKeptDim) {
+        reassocIdxsVec.push_back(reassocIdxs);
+        reassocIdxs.clear();
+      }
+      foundKeptDim = true;
+    }
+    reassocIdxs.push_back(dim);
+  }
+  if (!reassocIdxs.empty())
+    reassocIdxsVec.push_back(reassocIdxs);
+  return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
+}
+
 FailureOr<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    OpOperand &consumerOpOperand) {
@@ -255,6 +282,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                << "\nNot fusable, not an extract_slice op: " << inputTensor);
     return failure();
   }
+  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
 
   // If producer is already in the same block as consumer, we are done.
   if (consumerOpOperand.get().getParentBlock() ==
@@ -272,12 +300,16 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                  consumerOpOperand);
 
   // Replace use.
+  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+  Type consumerType = consumerOpOperand.get().getType();
+  // Rank-reduction occured as part of the extract_slice.
+  if (cast<RankedTensorType>(consumerType).getRank() !=
+      cast<RankedTensorType>(def.getType()).getRank())
+    def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
-  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
-  Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
   consumerOpOperand.set(def);
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 0f27a92c119cf..b4fbdfacde899 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -318,3 +318,66 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
   }
   return %for0 : tensor<64x128xf32>
 }
+
+// -----
+
+func.func @rank_reduced_extract_slice(%arg0: tensor<6x6x1x1x1x1xf32>, %arg1: tensor<6x6x1x1xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+  %cst = arith.constant 0.0 : f32
+  %init1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d6, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<6x6x1x1x1x1xf32>, tensor<6x6x1x1xf32>) outs(%fill1 : tensor<6x6x1x1x1x1xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<6x6x1x1x1x1xf32>
+  %init2 = tensor.empty() : tensor<4x6xf32>
+  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+    %2 = tensor.extract_slice %0[0, %arg4, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+    %init3 = tensor.empty() : tensor<4x2xf32>
+    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+    %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
+  }
+  return %1 : tensor<4x6xf32>
+}
+
+// CHECK: func @rank_reduced_extract_slice(
+// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<6x6x1x1x1x1xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<6x6x1x1xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[FILL_PROD:.*]] = linalg.fill ins({{%.*}} : f32)
+// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
+// CHECK: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK: %[[FILL_CONS:.*]] = linalg.fill ins({{%.*}} : f32)
+// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<4x2xf32>) -> tensor<4x2xf32>
+// CHECK: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+// CHECK-DAG: %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]], 0, 0] [6, 2, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
+// CHECK-DAG: %[[FILL_PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+
+// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : tensor<6x2x1x1x1x1xf32>, tensor<6x2x1x1xf32>)
+// CHECK-SAME: outs(%[[FILL_PROD_SLICE]] : tensor<6x2x1x1x1x1xf32>)
+// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
+// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK: return %[[FOR]] : tensor<4x6xf32>
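
Note: a minimal sketch of the failure mode being fixed, using hypothetical
IR (%prod is not an SSA name from the test above). A rank-reducing
extract_slice leaves the fused producer yielding a higher-rank tensor than
the consumer expects, and tensor.cast cannot change rank, so pre-patch
fusion would emit an op that fails verification:

  %bad = tensor.cast %prod : tensor<6x2x1x1x1x1xf32> to tensor<6x2xf32>

The collapseTo helper added above instead builds a valid rank-reducing
reshape, which is what the new CHECK lines expect:

  %good = tensor.collapse_shape %prod [[0], [1, 2, 3, 4, 5]]
      : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>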
From ce067327de33ee88397573ead73291764f15c627 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme
Date: Tue, 25 Mar 2025 22:49:26 +0000
Subject: [PATCH 2/8] Add more comments and simplify test

---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 12 ++-
 .../Dialect/Linalg/tile-and-fuse-tensors.mlir | 88 ++++++++-----------
 2 files changed, 45 insertions(+), 55 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 81b204df5a0aa..d18d6f7ff8dd8 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -239,14 +239,20 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
 
 /// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
 /// `from`.
-tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
-                                   const llvm::SmallBitVector &dropDims) {
+static tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
+                                          const llvm::SmallBitVector &dropDims) {
   auto fromType = cast<RankedTensorType>(from.getType());
-  assert(fromType.getRank() == dropDims.size());
+  assert(fromType.getRank() == dropDims.size() && "dropDims dimension does not match from tensor rank");
+  // Computed reassociation map for the corresponding tensor.collapse_shape.
   SmallVector<ReassociationIndices> reassocIdxsVec;
+  // Current reassociation indices to add dropped dimension to.
   ReassociationIndices reassocIdxs;
 
   bool foundKeptDim = false;
+  // Dropped dimensions might be at the beginning or end of the shape so
+  // combine all contiguous dimensions before and after a given non dropped
+  // dimension in reassocIdxs until another non dropped dimension is found.
+  // When that happens, add the reassociation indices to the map.
   for (int dim = 0; dim < fromType.getRank(); dim++) {
     if (!dropDims.test(dim)) {
       if (foundKeptDim) {
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index b4fbdfacde899..46b70a9c0edba 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -321,63 +321,47 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
 
 // -----
 
-func.func @rank_reduced_extract_slice(%arg0: tensor<6x6x1x1x1x1xf32>, %arg1: tensor<6x6x1x1xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
-  %c0 = arith.constant 0 : index
-  %c2 = arith.constant 2 : index
-  %c6 = arith.constant 6 : index
+func.func @rank_reduced_extract_slice(%cond : i1) -> tensor<6x2xf32> {
   %cst = arith.constant 0.0 : f32
-  %init1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
-  %0 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d6)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d6, d5)>, affine_map<(d0, d1, d2, d3, d4, d5, d6) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel", "reduction"]} ins(%arg0, %arg1 : tensor<6x6x1x1x1x1xf32>, tensor<6x6x1x1xf32>) outs(%fill1 : tensor<6x6x1x1x1x1xf32>) {
-  ^bb0(%in: f32, %in_1: f32, %out: f32):
-    %10 = arith.mulf %in, %in_1 : f32
-    %11 = arith.addf %out, %10 : f32
-    linalg.yield %11 : f32
+  %cst1 = arith.constant 1.0 : f32
+
+  %empty1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
+  %init1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%empty1 : tensor<6x6x1x1x1x1xf32>) {
+  ^bb0(%out: f32):
+    linalg.yield %cst : f32
   } -> tensor<6x6x1x1x1x1xf32>
-  %init2 = tensor.empty() : tensor<4x6xf32>
-  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
-    %2 = tensor.extract_slice %0[0, %arg4, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
-    %init3 = tensor.empty() : tensor<4x2xf32>
-    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
-    %3 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], iterator_types = ["parallel", "parallel", "reduction"]} ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
-    ^bb0(%in: f32, %in_1: f32, %out: f32):
-      %20 = arith.mulf %in, %in_1 : f32
-      %21 = arith.addf %out, %20 : f32
-      linalg.yield %21 : f32
-    } -> tensor<4x2xf32>
-    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
-    scf.yield %4 : tensor<4x6xf32>
+
+  %if = scf.if %cond -> tensor<6x2xf32> {
+    %extract0 = tensor.extract_slice %init1[0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+
+    %init2 = tensor.empty() : tensor<6x2xf32>
+    %add1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extract0 : tensor<6x2xf32>) outs(%init2 : tensor<6x2xf32>) {
+    ^bb0(%in: f32, %out: f32):
+      %add = arith.addf %in, %cst1 : f32
+      linalg.yield %add : f32
+    } -> tensor<6x2xf32>
+    scf.yield %add1 : tensor<6x2xf32>
+  } else {
+    %extract2 = tensor.extract_slice %init1[0, 2, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
+    scf.yield %extract2 : tensor<6x2xf32>
   }
-  return %1 : tensor<4x6xf32>
+
+  return %if : tensor<6x2xf32>
 }
 
 // CHECK: func @rank_reduced_extract_slice(
-// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<6x6x1x1x1x1xf32>
-// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<6x6x1x1xf32>
-// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[COND:[0-9a-z]*]]: i1
 
-// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
-// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
-// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
 // CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[FILL_PROD:.*]] = linalg.fill ins({{%.*}} : f32)
-// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>) -> tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
-// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
-// CHECK: %[[FILL_CONS:.*]] = linalg.fill ins({{%.*}} : f32)
-// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<4x2xf32>) -> tensor<4x2xf32>
-// CHECK: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
-// CHECK-DAG: %[[ARG0_SLICE:.*]] = tensor.extract_slice %[[ARG0]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, %[[I]], 0, 0] [6, 2, 1, 1] [1, 1, 1, 1] : tensor<6x6x1x1xf32> to tensor<6x2x1x1xf32>
-// CHECK-DAG: %[[FILL_PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, %[[I]], 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-
-// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0_SLICE]], %[[ARG1_SLICE]] : tensor<6x2x1x1x1x1xf32>, tensor<6x2x1x1xf32>)
-// CHECK-SAME: outs(%[[FILL_PROD_SLICE]] : tensor<6x2x1x1x1x1xf32>)
-// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
-// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
-// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
-// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
-// CHECK: return %[[FOR]] : tensor<4x6xf32>
+// CHECK: %[[FILL_PROD:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>)
+
+// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<6x2xf32>
+// CHECK: %[[EXTRACT_SLICE_CONS:.*]] = tensor.extract_slice %[[EMPTY_PROD]][0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
+
+// CHECK: %[[FILL_CONS:.*]] = linalg.generic
+// CHECK-SAME: outs(%[[EXTRACT_SLICE_CONS]] : tensor<6x2x1x1x1x1xf32>)
+// CHECK: %[[CONS_COLLAPSE:.*]] = tensor.collapse_shape %[[FILL_CONS]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
+// CHECK: %[[ADD1_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[CONS_COLLAPSE]] : tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<6x2xf32>)
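
Note: the grouping rule implemented by collapseTo can be read off a small
hypothetical case (shapes chosen for illustration, not taken from the
tests). Each dropped unit dimension joins the reassociation group of a
neighbouring kept dimension, with leading and trailing unit dimensions
absorbed by the first and last group. For dropDims = {0, 2, 4} on a
tensor<1x6x1x2x1xf32>, the computed reassociation is [[0, 1, 2], [3, 4]]:

  %collapsed = tensor.collapse_shape %t [[0, 1, 2], [3, 4]]
      : tensor<1x6x1x2x1xf32> into tensor<6x2xf32>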
From b51029d13fb67e2860fdf3832f6b66ea1b30544f Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme
Date: Wed, 26 Mar 2025 00:01:55 +0000
Subject: [PATCH 3/8] Fix clang-format

---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 8 +++++---
 1 file changed, 5 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index d18d6f7ff8dd8..bcb21263ee68f 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -239,10 +239,12 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
 
 /// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
 /// `from`.
-static tensor::CollapseShapeOp collapseTo(OpBuilder &b, Location loc, Value from,
-                                          const llvm::SmallBitVector &dropDims) {
+static tensor::CollapseShapeOp
+collapseTo(OpBuilder &b, Location loc, Value from,
+           const llvm::SmallBitVector &dropDims) {
   auto fromType = cast<RankedTensorType>(from.getType());
-  assert(fromType.getRank() == dropDims.size() && "dropDims dimension does not match from tensor rank");
+  assert(fromType.getRank() == dropDims.size() &&
+         "dropDims dimension does not match from tensor rank");
   // Computed reassociation map for the corresponding tensor.collapse_shape.
   SmallVector<ReassociationIndices> reassocIdxsVec;
   // Current reassociation indices to add dropped dimension to.
From b19b6490e71181c7e956223b3e4f33330fd0fde7 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme
Date: Wed, 7 May 2025 10:56:37 +0100
Subject: [PATCH 4/8] Address comments

- rename collapseTo to better reflect its usage
- assert it only collapses unit dimensions
- rename ReassociationIndices-using variables to reassocGroup and
  reassocMaps, the same terminology used in tensor.collapse_shape
  documentation
- use more representative test with comments to better explain what the
  patch does
---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp |  34 +++---
 .../Dialect/Linalg/tile-and-fuse-tensors.mlir | 107 +++++++++++------
 2 files changed, 87 insertions(+), 54 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index 3655c43940b06..c52f347135a9e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -236,37 +236,39 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
 }
 
-/// Create tensor.collapse_shape to drop dimensions in `dropDims` in tensor
+/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
 /// `from`.
 static tensor::CollapseShapeOp
-collapseTo(OpBuilder &b, Location loc, Value from,
-           const llvm::SmallBitVector &dropDims) {
+dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+                  const llvm::SmallBitVector &dropDims) {
   auto fromType = cast<RankedTensorType>(from.getType());
-  assert(fromType.getRank() == dropDims.size() &&
+  assert(fromType.getRank() == static_cast<int64_t>(dropDims.size()) &&
          "dropDims dimension does not match from tensor rank");
   // Computed reassociation map for the corresponding tensor.collapse_shape.
-  SmallVector<ReassociationIndices> reassocIdxsVec;
-  // Current reassociation indices to add dropped dimension to.
-  ReassociationIndices reassocIdxs;
+  SmallVector<ReassociationIndices> reassocMaps;
+  // Current reassociation group to add dropped dimension to.
+  ReassociationIndices reassocGroup;
 
   bool foundKeptDim = false;
   // Dropped dimensions might be at the beginning or end of the shape so
   // combine all contiguous dimensions before and after a given non dropped
-  // dimension in reassocIdxs until another non dropped dimension is found.
+  // dimension in reassocGroup until another non dropped dimension is found.
   // When that happens, add the reassociation indices to the map.
   for (int dim = 0; dim < fromType.getRank(); dim++) {
-    if (!dropDims.test(dim)) {
+    if (dropDims.test(dim))
+      assert(fromType.getShape()[dim] == 1 && "Dropping non unit dimension");
+    else {
       if (foundKeptDim) {
-        reassocIdxsVec.push_back(reassocIdxs);
-        reassocIdxs.clear();
+        reassocMaps.push_back(reassocGroup);
+        reassocGroup.clear();
       }
       foundKeptDim = true;
    }
-    reassocIdxs.push_back(dim);
+    reassocGroup.push_back(dim);
  }
-  if (!reassocIdxs.empty())
-    reassocIdxsVec.push_back(reassocIdxs);
-  return b.create<tensor::CollapseShapeOp>(loc, from, reassocIdxsVec);
+  if (!reassocGroup.empty())
+    reassocMaps.push_back(reassocGroup);
+  return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
 }
 
@@ -312,7 +314,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   // Rank-reduction occured as part of the extract_slice.
   if (cast<RankedTensorType>(consumerType).getRank() !=
       cast<RankedTensorType>(def.getType()).getRank())
-    def = collapseTo(b, fusedProducer.getLoc(), def, droppedDims);
+    def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 46b70a9c0edba..693a2bb29f76e 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -321,47 +321,78 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
 
 // -----
 
-func.func @rank_reduced_extract_slice(%cond : i1) -> tensor<6x2xf32> {
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
   %cst = arith.constant 0.0 : f32
-  %cst1 = arith.constant 1.0 : f32
-
-  %empty1 = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-  %init1 = linalg.generic {indexing_maps = [affine_map<(d0, d1, d2, d3, d4, d5) -> (d0, d1, d2, d3, d4, d5)>], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "parallel"]} outs(%empty1 : tensor<6x6x1x1x1x1xf32>) {
-  ^bb0(%out: f32):
-    linalg.yield %cst : f32
-  } -> tensor<6x6x1x1x1x1xf32>
-
-  %if = scf.if %cond -> tensor<6x2xf32> {
-    %extract0 = tensor.extract_slice %init1[0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
-
-    %init2 = tensor.empty() : tensor<6x2xf32>
-    %add1 = linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%extract0 : tensor<6x2xf32>) outs(%init2 : tensor<6x2xf32>) {
-    ^bb0(%in: f32, %out: f32):
-      %add = arith.addf %in, %cst1 : f32
-      linalg.yield %add : f32
-    } -> tensor<6x2xf32>
-    scf.yield %add1 : tensor<6x2xf32>
-  } else {
-    %extract2 = tensor.extract_slice %init1[0, 2, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2xf32>
-    scf.yield %extract2 : tensor<6x2xf32>
+  %init1 = tensor.empty() : tensor<1x6x6xf32>
+  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
+  %0 = linalg.generic
+    {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+    ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%fill1 : tensor<1x6x6xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<1x6x6xf32>
+  %init2 = tensor.empty() : tensor<4x6xf32>
+  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
+    %2 = tensor.extract_slice %0[0, 0, %arg4] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+    %init3 = tensor.empty() : tensor<4x2xf32>
+    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+    %3 = linalg.generic
+      {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
+      ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
   }
-
-  return %if : tensor<6x2xf32>
+  return %1 : tensor<4x6xf32>
 }
 
 // CHECK: func @rank_reduced_extract_slice(
-// CHECK-SAME: %[[COND:[0-9a-z]*]]: i1
+// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
 
-// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<6x6x1x1x1x1xf32>
-// CHECK: %[[FILL_PROD:.*]] = linalg.generic
-// CHECK-SAME: outs(%[[EMPTY_PROD]] : tensor<6x6x1x1x1x1xf32>)
-
-// CHECK: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<6x2xf32>
-// CHECK: %[[EXTRACT_SLICE_CONS:.*]] = tensor.extract_slice %[[EMPTY_PROD]][0, 0, 0, 0, 0, 0] [6, 2, 1, 1, 1, 1] [1, 1, 1, 1, 1, 1] : tensor<6x6x1x1x1x1xf32> to tensor<6x2x1x1x1x1xf32>
-
-// CHECK: %[[FILL_CONS:.*]] = linalg.generic
-// CHECK-SAME: outs(%[[EXTRACT_SLICE_CONS]] : tensor<6x2x1x1x1x1xf32>)
-// CHECK: %[[CONS_COLLAPSE:.*]] = tensor.collapse_shape %[[FILL_CONS]] {{\[\[0\], \[1, 2, 3, 4, 5\]\]}} : tensor<6x2x1x1x1x1xf32> into tensor<6x2xf32>
-// CHECK: %[[ADD1_CONS:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[CONS_COLLAPSE]] : tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[EMPTY_CONS]] : tensor<6x2xf32>)
+// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
+// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
+// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<1x6x6xf32>
+// CHECK-NEXT: %[[FILL_PROD:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_PROD]] : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
+// CHECK-NEXT: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
+// CHECK-NEXT: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
+// CHECK-NEXT: %[[FILL_CONS:.*]] = linalg.fill ins(%[[CST]] : f32)
+
+// For loop right after tensor alloc & fill, no linalg.generic.
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+
+// Producer linalg.generic now inside the loop, with tiled args sliced before
+// it.
+// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK: %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
+//
+// Consumer uses a rank-reduced version of producer result so a collapse_shape
+// is generated.
+// CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
+// CHECK: %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+// CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK: return %[[FOR]] : tensor<4x6xf32>
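
Note: the assert added in this revision encodes the invariant that the
dimensions dropped by a rank-reducing extract_slice are always static unit
dimensions. In a hypothetical slice like the following (not from the test),
getDroppedDims() would report dimensions 0 and 2, both of size 1 in the
slice:

  %s = tensor.extract_slice %t[0, 0, 0] [1, 6, 1] [1, 1, 1]
      : tensor<1x6x4xf32> to tensor<6xf32>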
From 49deedf97d448f7bb71d47da0b0a1bbe7b53db96 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme
Date: Mon, 19 May 2025 23:05:52 +0100
Subject: [PATCH 5/8] Clean up code

dropGivenUnitDims():
- move assert out of loop
- rework algorithm to make grouping more explicit and avoid complex
  nested ifs
- fix occured typo

Test: remove all tensor.empty and linalg.fill
---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 42 +++++++++----------
 .../Dialect/Linalg/tile-and-fuse-tensors.mlir | 37 +++++++---------
 2 files changed, 36 insertions(+), 43 deletions(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index c52f347135a9e..d69c85984aa4e 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -242,32 +242,30 @@ static tensor::CollapseShapeOp
 dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
                   const llvm::SmallBitVector &dropDims) {
   auto fromType = cast<RankedTensorType>(from.getType());
-  assert(fromType.getRank() == static_cast<int64_t>(dropDims.size()) &&
+  int64_t rank = fromType.getRank();
+  assert(rank == static_cast<int64_t>(dropDims.size()) &&
          "dropDims dimension does not match from tensor rank");
+  assert(llvm::all_of(
+             dropDims.set_bits(),
+             [&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
+         "Dropping non unit dimension");
   // Computed reassociation map for the corresponding tensor.collapse_shape.
   SmallVector<ReassociationIndices> reassocMaps;
   // Current reassociation group to add dropped dimension to.
-  ReassociationIndices reassocGroup;
-
-  bool foundKeptDim = false;
-  // Dropped dimensions might be at the beginning or end of the shape so
-  // combine all contiguous dimensions before and after a given non dropped
-  // dimension in reassocGroup until another non dropped dimension is found.
-  // When that happens, add the reassociation indices to the map.
-  for (int dim = 0; dim < fromType.getRank(); dim++) {
-    if (dropDims.test(dim))
-      assert(fromType.getShape()[dim] == 1 && "Dropping non unit dimension");
-    else {
-      if (foundKeptDim) {
-        reassocMaps.push_back(reassocGroup);
-        reassocGroup.clear();
-      }
-      foundKeptDim = true;
-    }
-    reassocGroup.push_back(dim);
+
+  int64_t nextDimToGroup = 0;
+  llvm::SmallBitVector keptDims(dropDims);
+  keptDims.flip();
+  int64_t lastSetBit = keptDims.find_last();
+  for(int64_t setBit : keptDims.set_bits()) {
+    // Group consecutive dropped dimension with the next non-dropped dimension.
+    // If this is the last set dimension, also group all subsequent dropped
+    // dimension, if any.
+    int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
+    auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
+    reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
+    nextDimToGroup = setBit + 1;
   }
-  if (!reassocGroup.empty())
-    reassocMaps.push_back(reassocGroup);
   return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
 }
 
@@ -311,7 +309,7 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   // Replace use.
   Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
   Type consumerType = consumerOpOperand.get().getType();
-  // Rank-reduction occured as part of the extract_slice.
+  // Rank-reduction occurred as part of the extract_slice.
   if (cast<RankedTensorType>(consumerType).getRank() !=
       cast<RankedTensorType>(def.getType()).getRank())
     def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 693a2bb29f76e..9340e70b4d507 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -327,35 +327,32 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
 #map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
 #map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
-func.func @rank_reduced_extract_slice(%arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>) -> tensor<4x6xf32> {
+func.func @rank_reduced_extract_slice(
+  %arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
+  %arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
+) -> tensor<4x6xf32> {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
   %c6 = arith.constant 6 : index
-  %cst = arith.constant 0.0 : f32
-  %init1 = tensor.empty() : tensor<1x6x6xf32>
-  %fill1 = linalg.fill ins(%cst : f32) outs(%init1 : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
   %0 = linalg.generic
     {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-    ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%fill1 : tensor<1x6x6xf32>) {
+    ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
   ^bb0(%in: f32, %in_1: f32, %out: f32):
     %10 = arith.mulf %in, %in_1 : f32
    %11 = arith.addf %out, %10 : f32
     linalg.yield %11 : f32
   } -> tensor<1x6x6xf32>
-  %init2 = tensor.empty() : tensor<4x6xf32>
-  %1 = scf.for %arg4 = %c0 to %c6 step %c2 iter_args(%arg3 = %init2) -> (tensor<4x6xf32>) {
-    %2 = tensor.extract_slice %0[0, 0, %arg4] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
-    %init3 = tensor.empty() : tensor<4x2xf32>
-    %fill3 = linalg.fill ins(%cst : f32) outs(%init3 : tensor<4x2xf32>) -> tensor<4x2xf32>
+  %1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
+    %2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
     %3 = linalg.generic
       {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
-      ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%fill3 : tensor<4x2xf32>) {
+      ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%arg5 : tensor<4x2xf32>) {
     ^bb0(%in: f32, %in_1: f32, %out: f32):
       %20 = arith.mulf %in, %in_1 : f32
       %21 = arith.addf %out, %20 : f32
       linalg.yield %21 : f32
     } -> tensor<4x2xf32>
-    %4 = tensor.insert_slice %3 into %arg3[0, %arg4] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    %4 = tensor.insert_slice %3 into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
     scf.yield %4 : tensor<4x6xf32>
   }
   return %1 : tensor<4x6xf32>
 }
 
 // CHECK: func @rank_reduced_extract_slice(
 // CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
 // CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
 // CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[ARG3:[0-9a-z]*]]: tensor<1x6x6xf32>
+// CHECK-SAME: %[[ARG4:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[ARG5:[0-9a-z]*]]: tensor<4x2xf32>
 
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
-// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32
-// CHECK: %[[EMPTY_PROD:.*]] = tensor.empty() : tensor<1x6x6xf32>
-// CHECK-NEXT: %[[FILL_PROD:.*]] = linalg.fill ins(%[[CST]] : f32) outs(%[[EMPTY_PROD]] : tensor<1x6x6xf32>) -> tensor<1x6x6xf32>
-// CHECK-NEXT: %[[EMPTY_FOR:.*]] = tensor.empty() : tensor<4x6xf32>
-// CHECK-NEXT: %[[EMPTY_CONS:.*]] = tensor.empty() : tensor<4x2xf32>
-// CHECK-NEXT: %[[FILL_CONS:.*]] = linalg.fill ins(%[[CST]] : f32)
 
 // For loop right after tensor alloc & fill, no linalg.generic.
-// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[EMPTY_FOR]])
+// CHECK-NOT: linalg.generic
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[ARG4]])
 
 // Producer linalg.generic now inside the loop, with tiled args sliced before
 // it.
 // CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
-// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[FILL_PROD]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[ARG3]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
 // CHECK: %[[MMUL_PROD:.*]] = linalg.generic
 // CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
 // CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
 //
 // Consumer uses a rank-reduced version of producer result so a collapse_shape
 // is generated.
 // CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
 // CHECK: %[[MMUL_CONS:.*]] = linalg.generic
 // CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[FILL_CONS]] : tensor<4x2xf32>)
+// CHECK-SAME: outs(%[[ARG5]] : tensor<4x2xf32>)
 // CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
 // CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
 // CHECK: return %[[FOR]] : tensor<4x6xf32>
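
Note: the reworked loop walks the set bits of keptDims instead of testing
every dimension: each kept dimension closes one reassociation group, and
the last kept dimension also absorbs any trailing dropped dimensions. As a
hypothetical example, with dropDims = {1, 2} the single kept dimension 0
produces the one group [0, 1, 2]:

  %collapsed = tensor.collapse_shape %t [[0, 1, 2]]
      : tensor<6x1x1xf32> into tensor<6xf32>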
From cf20e80c175bbfd6990aa3d58f0d972e796f7510 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme
Date: Tue, 20 May 2025 11:33:20 +0100
Subject: [PATCH 6/8] Fix codestyle

---
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index d69c85984aa4e..f983fb5e40fa7 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -257,7 +257,7 @@ dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
   llvm::SmallBitVector keptDims(dropDims);
   keptDims.flip();
   int64_t lastSetBit = keptDims.find_last();
-  for(int64_t setBit : keptDims.set_bits()) {
+  for (int64_t setBit : keptDims.set_bits()) {
     // Group consecutive dropped dimension with the next non-dropped dimension.
     // If this is the last set dimension, also group all subsequent dropped
     // dimension, if any.

From 42d8959bc2594be9218c1e709320e048cde8cc84 Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme
Date: Tue, 20 May 2025 13:26:57 +0100
Subject: [PATCH 7/8] Move dropGivenUnitDims to Tensor Utils

---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |  5 +++
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp | 36 ++-----------
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 33 +++++++++++++
 3 files changed, 40 insertions(+), 34 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 22ca8a99dd7db..6c2a55f67db87 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -43,6 +43,11 @@
 FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
+/// `from`.
+CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+                                  const llvm::SmallBitVector &dropDims);
+
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index f983fb5e40fa7..e3673e2a385c0 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -236,39 +236,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpOperand &consumerOpOperand) {
   return fuseProducerOfTensor(b, producerOpResult, consumerOpOperand);
 }
 
-/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
-/// `from`.
-static tensor::CollapseShapeOp
-dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
-                  const llvm::SmallBitVector &dropDims) {
-  auto fromType = cast<RankedTensorType>(from.getType());
-  int64_t rank = fromType.getRank();
-  assert(rank == static_cast<int64_t>(dropDims.size()) &&
-         "dropDims dimension does not match from tensor rank");
-  assert(llvm::all_of(
-             dropDims.set_bits(),
-             [&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
-         "Dropping non unit dimension");
-  // Computed reassociation map for the corresponding tensor.collapse_shape.
-  SmallVector<ReassociationIndices> reassocMaps;
-  // Current reassociation group to add dropped dimension to.
-
-  int64_t nextDimToGroup = 0;
-  llvm::SmallBitVector keptDims(dropDims);
-  keptDims.flip();
-  int64_t lastSetBit = keptDims.find_last();
-  for (int64_t setBit : keptDims.set_bits()) {
-    // Group consecutive dropped dimension with the next non-dropped dimension.
-    // If this is the last set dimension, also group all subsequent dropped
-    // dimension, if any.
-    int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
-    auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
-    reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
-    nextDimToGroup = setBit + 1;
-  }
-  return b.create<tensor::CollapseShapeOp>(loc, from, reassocMaps);
-}
-
 FailureOr<FusionInfo>
 mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    OpOperand &consumerOpOperand) {
@@ -312,7 +279,8 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   // Rank-reduction occurred as part of the extract_slice.
   if (cast<RankedTensorType>(consumerType).getRank() !=
       cast<RankedTensorType>(def.getType()).getRank())
-    def = dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
+    def =
+        tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index c3d56759a896a..53a219dff48c5 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -94,6 +94,39 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
   return transposedTensorType;
 }
 
+/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
+/// `from`.
+CollapseShapeOp
+mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+                                const llvm::SmallBitVector &dropDims) {
+  auto fromType = cast<RankedTensorType>(from.getType());
+  int64_t rank = fromType.getRank();
+  assert(rank == static_cast<int64_t>(dropDims.size()) &&
+         "dropDims dimension does not match from tensor rank");
+  assert(llvm::all_of(
+             dropDims.set_bits(),
+             [&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
+         "Dropping non unit dimension");
+  // Computed reassociation map for the corresponding tensor.collapse_shape.
+  SmallVector<ReassociationIndices> reassocMaps;
+  // Current reassociation group to add dropped dimension to.
+
+  int64_t nextDimToGroup = 0;
+  llvm::SmallBitVector keptDims(dropDims);
+  keptDims.flip();
+  int64_t lastSetBit = keptDims.find_last();
+  for (int64_t setBit : keptDims.set_bits()) {
+    // Group consecutive dropped dimension with the next non-dropped dimension.
+    // If this is the last set dimension, also group all subsequent dropped
+    // dimension, if any.
+    int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
+    auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
+    reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
+    nextDimToGroup = setBit + 1;
+  }
+  return b.create<CollapseShapeOp>(loc, from, reassocMaps);
+}
+
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;
From 6fc320ddd5696da921939eae2bea403c37c8568b Mon Sep 17 00:00:00 2001
From: Thomas Preud'homme
Date: Tue, 20 May 2025 22:43:32 +0100
Subject: [PATCH 8/8] Address review comments

Utils:
- drop comments on implementation
- rename from into src

Fusion:
- restrict live range of droppedDims
- clarify comment for rank-reduction check

Test:
- Use more descriptive SSA and FileCheck variables
- Emphasize the rank-reducing extract_slice in the input IR as the key
  aspect of the test.
---
 .../include/mlir/Dialect/Tensor/Utils/Utils.h |  4 +-
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp |  8 +--
 mlir/lib/Dialect/Tensor/Utils/Utils.cpp       | 14 +++--
 .../Dialect/Linalg/tile-and-fuse-tensors.mlir | 51 ++++++++++---------
 4 files changed, 41 insertions(+), 36 deletions(-)

diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 6c2a55f67db87..1a4733df3f187 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -44,8 +44,8 @@
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
 /// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
-/// `from`.
-CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+/// `src`.
+CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
                                   const llvm::SmallBitVector &dropDims);
 
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the

diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index e3673e2a385c0..4fc8a17554435 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -256,7 +256,6 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                << "\nNot fusable, not an extract_slice op: " << inputTensor);
     return failure();
   }
-  llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
 
   // If producer is already in the same block as consumer, we are done.
   if (consumerOpOperand.get().getParentBlock() ==
@@ -276,11 +275,14 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
   // Replace use.
   Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
   Type consumerType = consumerOpOperand.get().getType();
-  // Rank-reduction occurred as part of the extract_slice.
+  // Check if rank-reduction occurred as part of the extract_slice. If yes,
+  // collapse the dropped dimensions.
   if (cast<RankedTensorType>(consumerType).getRank() !=
-      cast<RankedTensorType>(def.getType()).getRank())
+      cast<RankedTensorType>(def.getType()).getRank()) {
+    llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
     def =
         tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
+  }
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation

diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index 53a219dff48c5..11ae0108594dd 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -94,18 +94,16 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
   return transposedTensorType;
 }
 
-/// Create tensor.collapse_shape to drop unit dimensions in `dropDims` in tensor
-/// `from`.
 CollapseShapeOp
-mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
+mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
                                 const llvm::SmallBitVector &dropDims) {
-  auto fromType = cast<RankedTensorType>(from.getType());
-  int64_t rank = fromType.getRank();
+  auto srcType = cast<RankedTensorType>(src.getType());
+  int64_t rank = srcType.getRank();
   assert(rank == static_cast<int64_t>(dropDims.size()) &&
-         "dropDims dimension does not match from tensor rank");
+         "dropDims dimension does not match src tensor rank");
   assert(llvm::all_of(
              dropDims.set_bits(),
-             [&](unsigned dim) { return fromType.getShape()[dim] == 1; }) &&
+             [&](unsigned dim) { return srcType.getShape()[dim] == 1; }) &&
          "Dropping non unit dimension");
   // Computed reassociation map for the corresponding tensor.collapse_shape.
   SmallVector<ReassociationIndices> reassocMaps;
@@ -124,7 +122,7 @@ mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value from,
     reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
     nextDimToGroup = setBit + 1;
   }
-  return b.create<CollapseShapeOp>(loc, from, reassocMaps);
+  return b.create<CollapseShapeOp>(loc, src, reassocMaps);
 }
 
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 9340e70b4d507..fd755a208b2c9 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -328,43 +328,48 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
 #map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
 #map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
 func.func @rank_reduced_extract_slice(
-  %arg0: tensor<1x6x5xf32>, %arg1: tensor<1x5x6xf32>, %arg2: tensor<4x6xf32>,
-  %arg3: tensor<1x6x6xf32>, %arg4: tensor<4x6xf32>, %arg5: tensor<4x2xf32>
+  %prod_in: tensor<1x6x5xf32>, %prod_weight: tensor<1x5x6xf32>,
+  %cons_in: tensor<4x6xf32>, %prod_init: tensor<1x6x6xf32>,
+  %for_iv_init: tensor<4x6xf32>, %cons_init: tensor<4x2xf32>
 ) -> tensor<4x6xf32> {
   %c0 = arith.constant 0 : index
   %c2 = arith.constant 2 : index
   %c6 = arith.constant 6 : index
-  %0 = linalg.generic
+  %mmul_prod = linalg.generic
     {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
-    ins(%arg0, %arg1 : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%arg3 : tensor<1x6x6xf32>) {
+    ins(%prod_in, %prod_weight : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%prod_init : tensor<1x6x6xf32>) {
   ^bb0(%in: f32, %in_1: f32, %out: f32):
     %10 = arith.mulf %in, %in_1 : f32
     %11 = arith.addf %out, %10 : f32
     linalg.yield %11 : f32
   } -> tensor<1x6x6xf32>
-  %1 = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %arg4) -> (tensor<4x6xf32>) {
-    %2 = tensor.extract_slice %0[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
-    %3 = linalg.generic
+  %for = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %for_iv_init) -> (tensor<4x6xf32>) {
+
+    // Extract slice with rank-reduced result type. When fused in the loop
+    // with sliced operands, the producer linalg must have its now sliced
+    // result be rank-reduced as well to match consumer's use type.
+    %prod_slice = tensor.extract_slice %mmul_prod[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+    %mmul_cons = linalg.generic
      {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
-      ins(%arg2, %2 : tensor<4x6xf32>, tensor<6x2xf32>) outs(%arg5 : tensor<4x2xf32>) {
+      ins(%cons_in, %prod_slice : tensor<4x6xf32>, tensor<6x2xf32>) outs(%cons_init : tensor<4x2xf32>) {
    ^bb0(%in: f32, %in_1: f32, %out: f32):
      %20 = arith.mulf %in, %in_1 : f32
      %21 = arith.addf %out, %20 : f32
      linalg.yield %21 : f32
    } -> tensor<4x2xf32>
-    %4 = tensor.insert_slice %3 into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    %4 = tensor.insert_slice %mmul_cons into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
    scf.yield %4 : tensor<4x6xf32>
  }
-  return %1 : tensor<4x6xf32>
+  return %for : tensor<4x6xf32>
 }
 
 // CHECK: func @rank_reduced_extract_slice(
-// CHECK-SAME: %[[ARG0:[0-9a-z]*]]: tensor<1x6x5xf32>
-// CHECK-SAME: %[[ARG1:[0-9a-z]*]]: tensor<1x5x6xf32>
-// CHECK-SAME: %[[ARG2:[0-9a-z]*]]: tensor<4x6xf32>
-// CHECK-SAME: %[[ARG3:[0-9a-z]*]]: tensor<1x6x6xf32>
-// CHECK-SAME: %[[ARG4:[0-9a-z]*]]: tensor<4x6xf32>
-// CHECK-SAME: %[[ARG5:[0-9a-z]*]]: tensor<4x2xf32>
+// CHECK-SAME: %[[PROD_IN:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME: %[[PROD_WEIGHT:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME: %[[CONS_IN:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[PROD_INIT:[0-9a-z]*]]: tensor<1x6x6xf32>
+// CHECK-SAME: %[[FOR_IV_INIT:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME: %[[CONS_INIT:[0-9a-z]*]]: tensor<4x2xf32>
 
 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
 // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
 // CHECK-DAG: %[[C6:.*]] = arith.constant 6 : index
 
 // For loop right after tensor alloc & fill, no linalg.generic.
 // CHECK-NOT: linalg.generic
-// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[ARG4]])
+// CHECK-NEXT: %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[FOR_IV_INIT]])
 
 // Producer linalg.generic now inside the loop, with tiled args sliced before
 // it.
-// CHECK-DAG: %[[ARG1_SLICE:.*]] = tensor.extract_slice %[[ARG1]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
-// CHECK-DAG: %[[PROD_SLICE:.*]] = tensor.extract_slice %[[ARG3]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK-DAG: %[[PROD_WEIGHT_SLICE:.*]] = tensor.extract_slice %[[PROD_WEIGHT]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG: %[[PROD_INIT_SLICE:.*]] = tensor.extract_slice %[[PROD_INIT]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
 // CHECK: %[[MMUL_PROD:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG0]], %[[ARG1_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
-// CHECK-SAME: outs(%[[PROD_SLICE]] : tensor<1x6x2xf32>)
+// CHECK-SAME: ins(%[[PROD_IN]], %[[PROD_WEIGHT_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME: outs(%[[PROD_INIT_SLICE]] : tensor<1x6x2xf32>)
 //
 // Consumer uses a rank-reduced version of producer result so a collapse_shape
 // is generated.
 // CHECK: %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
 // CHECK: %[[MMUL_CONS:.*]] = linalg.generic
-// CHECK-SAME: ins(%[[ARG2]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
-// CHECK-SAME: outs(%[[ARG5]] : tensor<4x2xf32>)
+// CHECK-SAME: ins(%[[CONS_IN]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME: outs(%[[CONS_INIT]] : tensor<4x2xf32>)
 // CHECK: %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
 // CHECK: scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
 // CHECK: return %[[FOR]] : tensor<4x6xf32>