diff --git a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
index 22ca8a99dd7db..1a4733df3f187 100644
--- a/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Tensor/Utils/Utils.h
@@ -43,6 +43,11 @@ FailureOr<RankedTensorType>
 computeTransposedType(RankedTensorType rankedTensorType,
                       ArrayRef<int64_t> transposeVector);
 
+/// Create a tensor.collapse_shape op that drops the unit dimensions listed in
+/// `dropDims` from tensor `src`.
+CollapseShapeOp dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
+                                  const llvm::SmallBitVector &dropDims);
+
 /// A tensor.insert_slice is a cast-like operation if it merely rank-extends the
 /// source tensor or inserts the source tensor into a destination tensor with
 /// the same shape.
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
index fcfb499bb1332..4fc8a17554435 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp
@@ -18,6 +18,7 @@
 #include "mlir/Dialect/Linalg/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
+#include "mlir/Dialect/Tensor/Utils/Utils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Dominance.h"
@@ -26,6 +27,7 @@
 #include "mlir/Transforms/RegionUtils.h"
 #include "llvm/ADT/MapVector.h"
 #include "llvm/ADT/ScopeExit.h"
+#include "llvm/ADT/SmallBitVector.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/Debug.h"
 
@@ -271,12 +273,20 @@ mlir::linalg::fuseProducerOfTensor(OpBuilder &b, OpResult producerOpResult,
                                    consumerOpOperand);
 
   // Replace use.
+  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
+  Type consumerType = consumerOpOperand.get().getType();
+  // Check if rank-reduction occurred as part of the extract_slice. If yes,
+  // collapse the dropped dimensions.
+  if (cast<ShapedType>(consumerType).getRank() !=
+      cast<ShapedType>(def.getType()).getRank()) {
+    llvm::SmallBitVector droppedDims = sliceOp.getDroppedDims();
+    def =
+        tensor::dropGivenUnitDims(b, fusedProducer.getLoc(), def, droppedDims);
+  }
   // Canonicalizations are not guaranteed to have happened before constructing
   // `fusedProducer`. In the tensor case this can result in temporary type
   // mismatches. Insert a `tensor.cast` op to propagate the transformation
   // invariant that types are compatible.
-  Value def = fusedProducer->getResult(producerOpResult.getResultNumber());
-  Type consumerType = consumerOpOperand.get().getType();
   if (consumerType != def.getType())
     def = b.create<tensor::CastOp>(fusedProducer.getLoc(), consumerType, def);
   consumerOpOperand.set(def);
diff --git a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
index c3d56759a896a..11ae0108594dd 100644
--- a/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/Tensor/Utils/Utils.cpp
@@ -94,6 +94,37 @@ mlir::tensor::computeTransposedType(RankedTensorType rankedTensorType,
   return transposedTensorType;
 }
 
+CollapseShapeOp
+mlir::tensor::dropGivenUnitDims(OpBuilder &b, Location loc, Value src,
+                                const llvm::SmallBitVector &dropDims) {
+  auto srcType = cast<RankedTensorType>(src.getType());
+  int64_t rank = srcType.getRank();
+  assert(rank == static_cast<int64_t>(dropDims.size()) &&
+         "dropDims size does not match src tensor rank");
+  assert(llvm::all_of(
+             dropDims.set_bits(),
+             [&](unsigned dim) { return srcType.getShape()[dim] == 1; }) &&
+         "Dropping a non-unit dimension");
+  // Compute the reassociation maps for the resulting tensor.collapse_shape.
+  SmallVector<ReassociationIndices> reassocMaps;
+
+  // First dimension of the next reassociation group.
+  int64_t nextDimToGroup = 0;
+  llvm::SmallBitVector keptDims(dropDims);
+  keptDims.flip();
+  int64_t lastSetBit = keptDims.find_last();
+  for (int64_t setBit : keptDims.set_bits()) {
+    // Group consecutive dropped dimensions with the next non-dropped dimension.
+    // If this is the last kept dimension, also fold in all subsequent dropped
+    // dimensions, if any.
+    int64_t upTo = setBit == lastSetBit ? rank - 1 : setBit;
+    auto seq = llvm::seq_inclusive(nextDimToGroup, upTo);
+    reassocMaps.emplace_back(llvm::make_range(seq.begin(), seq.end()));
+    nextDimToGroup = setBit + 1;
+  }
+  return b.create<tensor::CollapseShapeOp>(loc, src, reassocMaps);
+}
+
 bool mlir::tensor::isCastLikeInsertSliceOp(InsertSliceOp op) {
   llvm::SmallBitVector droppedDims = op.getDroppedDims();
   int64_t srcDim = 0;
diff --git a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
index 0f27a92c119cf..fd755a208b2c9 100644
--- a/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
+++ b/mlir/test/Dialect/Linalg/tile-and-fuse-tensors.mlir
@@ -318,3 +318,81 @@ func.func @pad_generic_static(%small_input: tensor<58x1xf32>, %large_input: tens
   }
   return %for0 : tensor<64x128xf32>
 }
+
+// -----
+
+#map0 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>
+#map1 = affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>
+#map2 = affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>
+#map3 = affine_map<(d0, d1, d2) -> (d0, d2)>
+#map4 = affine_map<(d0, d1, d2) -> (d2, d1)>
+#map5 = affine_map<(d0, d1, d2) -> (d0, d1)>
+func.func @rank_reduced_extract_slice(
+    %prod_in: tensor<1x6x5xf32>, %prod_weight: tensor<1x5x6xf32>,
+    %cons_in: tensor<4x6xf32>, %prod_init: tensor<1x6x6xf32>,
+    %for_iv_init: tensor<4x6xf32>, %cons_init: tensor<4x2xf32>
+) -> tensor<4x6xf32> {
+  %c0 = arith.constant 0 : index
+  %c2 = arith.constant 2 : index
+  %c6 = arith.constant 6 : index
+  %mmul_prod = linalg.generic
+      {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "reduction"]}
+      ins(%prod_in, %prod_weight : tensor<1x6x5xf32>, tensor<1x5x6xf32>) outs(%prod_init : tensor<1x6x6xf32>) {
+  ^bb0(%in: f32, %in_1: f32, %out: f32):
+    %10 = arith.mulf %in, %in_1 : f32
+    %11 = arith.addf %out, %10 : f32
+    linalg.yield %11 : f32
+  } -> tensor<1x6x6xf32>
+  %for = scf.for %arg7 = %c0 to %c6 step %c2 iter_args(%arg6 = %for_iv_init) -> (tensor<4x6xf32>) {
+
+    // Extract a slice with a rank-reduced result type. When the producer is
+    // fused into the loop with sliced operands, its now-sliced result must be
+    // rank-reduced as well to match the consumer's use type.
+    %prod_slice = tensor.extract_slice %mmul_prod[0, 0, %arg7] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<6x2xf32>
+    %mmul_cons = linalg.generic
+        {indexing_maps = [#map3, #map4, #map5], iterator_types = ["parallel", "parallel", "reduction"]}
+        ins(%cons_in, %prod_slice : tensor<4x6xf32>, tensor<6x2xf32>) outs(%cons_init : tensor<4x2xf32>) {
+    ^bb0(%in: f32, %in_1: f32, %out: f32):
+      %20 = arith.mulf %in, %in_1 : f32
+      %21 = arith.addf %out, %20 : f32
+      linalg.yield %21 : f32
+    } -> tensor<4x2xf32>
+    %4 = tensor.insert_slice %mmul_cons into %arg6[0, %arg7] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+    scf.yield %4 : tensor<4x6xf32>
+  }
+  return %for : tensor<4x6xf32>
+}
+
+// CHECK: func @rank_reduced_extract_slice(
+// CHECK-SAME:  %[[PROD_IN:[0-9a-z]*]]: tensor<1x6x5xf32>
+// CHECK-SAME:  %[[PROD_WEIGHT:[0-9a-z]*]]: tensor<1x5x6xf32>
+// CHECK-SAME:  %[[CONS_IN:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME:  %[[PROD_INIT:[0-9a-z]*]]: tensor<1x6x6xf32>
+// CHECK-SAME:  %[[FOR_IV_INIT:[0-9a-z]*]]: tensor<4x6xf32>
+// CHECK-SAME:  %[[CONS_INIT:[0-9a-z]*]]: tensor<4x2xf32>
+
+// CHECK-DAG:   %[[C0:.*]] = arith.constant 0 : index
+// CHECK-DAG:   %[[C2:.*]] = arith.constant 2 : index
+// CHECK-DAG:   %[[C6:.*]] = arith.constant 6 : index
+
+// The scf.for comes first in the function body; no linalg.generic before it.
+// CHECK-NOT:   linalg.generic
+// CHECK-NEXT:  %[[FOR:.*]] = scf.for %[[I:[0-9a-z]*]] = %[[C0]] to %[[C6]] step %[[C2]] iter_args(%[[ARG_ITER:.*]] = %[[FOR_IV_INIT]])
+
+// Producer linalg.generic is now inside the loop, with its tiled arguments
+// sliced before it.
+// CHECK-DAG:   %[[PROD_WEIGHT_SLICE:.*]] = tensor.extract_slice %[[PROD_WEIGHT]][0, 0, %[[I]]] [1, 5, 2] [1, 1, 1] : tensor<1x5x6xf32> to tensor<1x5x2xf32>
+// CHECK-DAG:   %[[PROD_INIT_SLICE:.*]] = tensor.extract_slice %[[PROD_INIT]][0, 0, %[[I]]] [1, 6, 2] [1, 1, 1] : tensor<1x6x6xf32> to tensor<1x6x2xf32>
+// CHECK:       %[[MMUL_PROD:.*]] = linalg.generic
+// CHECK-SAME:    ins(%[[PROD_IN]], %[[PROD_WEIGHT_SLICE]] : tensor<1x6x5xf32>, tensor<1x5x2xf32>)
+// CHECK-SAME:    outs(%[[PROD_INIT_SLICE]] : tensor<1x6x2xf32>)
+//
+// The consumer uses a rank-reduced version of the producer result, so a
+// collapse_shape is generated.
+// CHECK:       %[[PROD_COLLAPSE:.*]] = tensor.collapse_shape %[[MMUL_PROD]] {{\[\[0, 1\], \[2\]\]}} : tensor<1x6x2xf32> into tensor<6x2xf32>
+// CHECK:       %[[MMUL_CONS:.*]] = linalg.generic
+// CHECK-SAME:    ins(%[[CONS_IN]], %[[PROD_COLLAPSE]] : tensor<4x6xf32>, tensor<6x2xf32>)
+// CHECK-SAME:    outs(%[[CONS_INIT]] : tensor<4x2xf32>)
+// CHECK:       %[[CONS_SLICE:.*]] = tensor.insert_slice %[[MMUL_CONS]] into %[[ARG_ITER]][0, %[[I]]] [4, 2] [1, 1] : tensor<4x2xf32> into tensor<4x6xf32>
+// CHECK:       scf.yield %[[CONS_SLICE]] : tensor<4x6xf32>
+// CHECK:       return %[[FOR]] : tensor<4x6xf32>
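
For reference, the reassociation built by dropGivenUnitDims groups every dropped unit dimension with the next kept dimension, and any dropped dimensions after the last kept one fold into the final group. A small illustration of the resulting IR, not part of the patch (%src and %collapsed are hypothetical values): dropping unit dimensions {0, 2, 3} of a rank-5 tensor gives the reassociation [[0, 1], [2, 3, 4]].

  // Hypothetical input: tensor<1x4x1x1x8xf32> with dropDims = {0, 2, 3}.
  // Kept dim 1 groups with the preceding unit dim 0; kept dim 4 groups with
  // the preceding unit dims 2 and 3.
  %collapsed = tensor.collapse_shape %src [[0, 1], [2, 3, 4]]
      : tensor<1x4x1x1x8xf32> into tensor<4x8xf32>

The new test above exercises the simpler case of a single leading unit dimension, which produces the [[0, 1], [2]] reassociation matched by the CHECK lines.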