|
28 | 28 | #include "mlir/IR/OpImplementation.h"
|
29 | 29 | #include "mlir/Pass/Pass.h"
|
30 | 30 | #include "mlir/Transforms/LoopUtils.h"
|
| 31 | +#include "llvm/ADT/TypeSwitch.h" |
31 | 32 | #include "llvm/Support/Debug.h"
|
32 | 33 |
|
33 | 34 | #define DEBUG_TYPE "linalg-utils"
|
@@ -194,6 +195,48 @@ IntegerAttr getSmallestBoundingIndex(Value size) {
|
194 | 195 | return nullptr;
|
195 | 196 | }
|
196 | 197 |
|
/// Creates a tensor::ExtractSliceOp that takes the slice described by
/// `offsets`/`sizes`/`strides` out of `source`, folding through a producing
/// ExtractSliceOp when possible: if `source` is itself the result of a
/// unit-stride, non-rank-reducing ExtractSliceOp, the two slices are composed
/// into a single op that extracts directly from the producer's source tensor
/// by adding the two offset vectors.
///
/// \param b       builder used to create any new operations.
/// \param loc     location attached to the created operations.
/// \param source  tensor to slice; must be non-null.
/// \param offsets per-dimension slice offsets (static or dynamic).
/// \param sizes   per-dimension slice sizes (static or dynamic).
/// \param strides per-dimension slice strides (static or dynamic).
/// \return the (possibly composed) ExtractSliceOp.
tensor::ExtractSliceOp makeComposedExtractSliceOp(
    OpBuilder &b, Location loc, Value source, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, ArrayRef<OpFoldResult> strides) {
  assert(source && "expect source to be nonzero");

  // Do not fold if the producer is not an ExtractSliceOp.
  auto producerOp = source.getDefiningOp<tensor::ExtractSliceOp>();
  if (!producerOp)
    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                            strides);

  // Do not fold if the producer is rank reducing or if there are any non-unit
  // strides. Supporting non-unit strides complicates the offset computation
  // since the consumer offsets need to be multiplied by the producer strides.
  // TODO: support non-unit strides once there are use cases.
  SmallVector<OpFoldResult> allStrides = producerOp.getMixedStrides();
  allStrides.append(strides.begin(), strides.end());
  // A stride counts as "unit" only when it is statically known to be 1;
  // dynamic strides (getConstantIntValue returns None) are conservatively
  // treated as non-unit and block the fold.
  bool hasNonUnitStride = any_of(allStrides, [](OpFoldResult ofr) {
    return getConstantIntValue(ofr) != static_cast<int64_t>(1);
  });
  if (hasNonUnitStride ||
      producerOp.getSourceType().getRank() !=
          producerOp.getResult().getType().cast<ShapedType>().getRank())
    return b.create<tensor::ExtractSliceOp>(loc, source, offsets, sizes,
                                            strides);

  // Fold the producer by adding the offsets and extracting the slice directly
  // from the producer source tensor.
  SmallVector<OpFoldResult> foldedOffsets(offsets.begin(), offsets.end());
  AffineExpr dim1, dim2;
  bindDims(b.getContext(), dim1, dim2);
  // Compose each consumer offset with the matching producer offset via an
  // affine apply of `dim1 + dim2`; makeComposedAffineApply folds this to a
  // constant when both operands are statically known.
  for (auto en : enumerate(producerOp.getMixedOffsets())) {
    SmallVector<Value> offsetValues = {
        getValueOrCreateConstantIndexOp(b, loc, foldedOffsets[en.index()]),
        getValueOrCreateConstantIndexOp(b, loc, en.value())};
    foldedOffsets[en.index()] =
        makeComposedAffineApply(b, loc, dim1 + dim2, offsetValues).getResult();
  }
  return b.create<tensor::ExtractSliceOp>(loc, producerOp.source(),
                                          foldedOffsets, sizes, strides);
}
| 239 | + |
197 | 240 | /// Specialization to build an scf "for" nest.
|
198 | 241 | template <>
|
199 | 242 | void GenerateLoopNest<scf::ForOp>::doit(
|
@@ -603,15 +646,18 @@ Value makeTiledShape(OpBuilder &builder, Location loc, Value valueToTile,
|
603 | 646 | strides.push_back(builder.getIndexAttr(1));
|
604 | 647 | }
|
605 | 648 |
|
606 |
| - Operation *sliceOp = shapedType.isa<MemRefType>() |
607 |
| - ? builder |
608 |
| - .create<memref::SubViewOp>( |
609 |
| - loc, valueToTile, offsets, sizes, strides) |
610 |
| - .getOperation() |
611 |
| - : builder |
612 |
| - .create<tensor::ExtractSliceOp>( |
613 |
| - loc, valueToTile, offsets, sizes, strides) |
614 |
| - .getOperation(); |
| 649 | + auto *sliceOp = TypeSwitch<ShapedType, Operation *>(shapedType) |
| 650 | + .Case([&](MemRefType) { |
| 651 | + return builder.create<memref::SubViewOp>( |
| 652 | + loc, valueToTile, offsets, sizes, strides); |
| 653 | + }) |
| 654 | + .Case([&](RankedTensorType) { |
| 655 | + return makeComposedExtractSliceOp( |
| 656 | + builder, loc, valueToTile, offsets, sizes, strides); |
| 657 | + }) |
| 658 | + .Default([](ShapedType) -> Operation * { |
| 659 | + llvm_unreachable("Unexpected shaped type"); |
| 660 | + }); |
615 | 661 | return sliceOp->getResult(0);
|
616 | 662 | }
|
617 | 663 |
|
|
0 commit comments