99#include " iree/compiler/Codegen/Dialect/GPU/IR/GPULoweringConfigUtils.h"
1010#include " iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
1111#include " iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUDialect.h"
12+ #include " iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
13+ #include " iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUInterfaces.h"
1214#include " iree/compiler/Codegen/Utils/LinalgOpInfo.h"
1315#include " iree/compiler/Codegen/Utils/Utils.h"
16+ #include " mlir/Dialect/Bufferization/IR/Bufferization.h"
17+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
1418#include " mlir/Dialect/Linalg/IR/Linalg.h"
1519#include " mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
1620#include " mlir/Dialect/Tensor/IR/Tensor.h"
@@ -25,11 +29,83 @@ namespace mlir::iree_compiler {
2529#include " iree/compiler/Codegen/Common/GPU/Passes.h.inc"
2630
2731namespace {
32+ // / Helper to insert copy with derived thread config.
33+ Value promoteValue (OpBuilder &builder, Location loc, Value v) {
34+ auto tensorType = cast<RankedTensorType>(v.getType ());
35+ SmallVector<OpFoldResult> mixedSizes = tensor::getMixedSizes (builder, loc, v);
36+ Value empty = builder.create <tensor::EmptyOp>(loc, mixedSizes,
37+ tensorType.getElementType ());
38+ auto copy = builder.create <linalg::CopyOp>(loc, v, empty);
39+ setLoweringConfig (
40+ copy, IREE::GPU::DerivedThreadConfigAttr::get (builder.getContext ()));
41+ return copy.getResult (0 );
42+ }
43+
44+ // / Helper to promote results. If the target value is consumed only by a
45+ // / `tensor.extract_slice`, this will promote the result of the slice instead.
46+ void promoteResult (OpBuilder &builder, Operation *op, Value valToMakeShared) {
47+ IRRewriter rewriter (builder);
48+ Location loc = op->getLoc ();
49+ OpBuilder::InsertionGuard g (rewriter);
50+ rewriter.setInsertionPointAfterValue (valToMakeShared);
51+ tensor::ExtractSliceOp extractSliceOp;
52+ SetVector<Operation *> opsToReplaceUseIn;
53+ Value valueToReplace = valToMakeShared;
54+ for (auto user : valToMakeShared.getUsers ()) {
55+ extractSliceOp = dyn_cast<tensor::ExtractSliceOp>(user);
56+ if (extractSliceOp) {
57+ // If the result is consumed by an extract_slice then we expect there to
58+ // be exactly one extract slice that is then consumed.
59+ // TODO (nirvedhmeshram) : This is fairly special case. Instead we should
60+ // just promote results before doing padding which introduces the extract
61+ // slice.
62+ if (!valToMakeShared.hasOneUse ())
63+ return ;
64+ valueToReplace = extractSliceOp.getResult ();
65+ for (auto user : extractSliceOp->getUsers ()) {
66+ opsToReplaceUseIn.insert (user);
67+ }
68+ break ;
69+ }
70+ opsToReplaceUseIn.insert (user);
71+ }
72+ auto tensorType = cast<RankedTensorType>(valToMakeShared.getType ());
73+ if (!tensorType) {
74+ return ;
75+ }
76+ SmallVector<Value> dynamicSizes;
77+ for (auto [idx, size] : llvm::enumerate (tensorType.getShape ())) {
78+ if (ShapedType::isDynamic (size)) {
79+ dynamicSizes.push_back (
80+ rewriter.create <tensor::DimOp>(loc, valToMakeShared, idx));
81+ }
82+ }
83+ Attribute addressSpace = gpu::AddressSpaceAttr::get (
84+ rewriter.getContext (), gpu::GPUDialect::getWorkgroupAddressSpace ());
85+ auto alloc = rewriter.create <bufferization::AllocTensorOp>(loc, tensorType,
86+ dynamicSizes);
87+ alloc.setMemorySpaceAttr (addressSpace);
88+ auto copy =
89+ rewriter.create <linalg::CopyOp>(loc, valToMakeShared, alloc.getResult ());
90+
91+ Value replacement = copy.getResult (0 );
92+ // If in extract slice is present we make it consume the new copy.
93+ if (extractSliceOp) {
94+ extractSliceOp.getSourceMutable ().assign (replacement);
95+ replacement = valueToReplace;
96+ }
97+
98+ rewriter.setInsertionPointAfterValue (replacement);
99+ replacement = promoteValue (rewriter, loc, replacement);
100+ valueToReplace.replaceUsesWithIf (replacement, [&](OpOperand &use) {
101+ return opsToReplaceUseIn.contains (use.getOwner ());
102+ });
103+ }
28104
29105// / Inserts a `linalg.copy` directly before the given operation on the
30106// / specified operand, for example with operand index = 1:
31107// /
32- // / linalg.matmul ins(%0, %1)
108+ // / %2 = linalg.matmul ins(%0, %1)
33109// /
34110// / becomes
35111// /
@@ -41,7 +117,24 @@ namespace {
41117// / If the producer is already a tilable op, the producer is just annotated with
42118// / #iree_gpu.derived_thread_config to indicate that it should be distributed
43119// / to threads independently of the matmul.
120+ // / Additionally we can also promote results so in above example we will
121+ // / generate for index = 2 :
122+ // / %out_buffer = bufferization.alloc_tensor
123+ // / %copy1 = linalg.copy %2 to %out_buffer
124+ // / %copy2 = linalg.copy %copy1 to %empty {
125+ // / lowering_config = #iree_gpu.derived_thread_config}
44126void promoteOperand (OpBuilder &builder, Operation *op, unsigned index) {
127+ auto dpsOp = dyn_cast<DestinationStyleOpInterface>(op);
128+ if (!dpsOp)
129+ return ;
130+ // We use the convention that if we are passing an index beyond the inputs
131+ // then we promote the result of the corresponding dps init.
132+ if (index >= dpsOp.getNumDpsInputs ()) {
133+ index -= dpsOp.getNumDpsInputs ();
134+ assert (index < op->getNumResults () &&
135+ " trying to promote out of bound result index" );
136+ return promoteResult (builder, op, op->getResult (index));
137+ }
45138 Value operand = op->getOperand (index);
46139
47140 if (auto producer = operand.getDefiningOp <TilingInterface>()) {
@@ -70,14 +163,8 @@ void promoteOperand(OpBuilder &builder, Operation *op, unsigned index) {
70163 return ;
71164 }
72165
73- SmallVector<OpFoldResult> mixedSizes =
74- tensor::getMixedSizes (builder, op->getLoc (), operand);
75- Value empty = builder.create <tensor::EmptyOp>(op->getLoc (), mixedSizes,
76- tensorType.getElementType ());
77- auto copy = builder.create <linalg::CopyOp>(op->getLoc (), operand, empty);
78- setLoweringConfig (
79- copy, IREE::GPU::DerivedThreadConfigAttr::get (builder.getContext ()));
80- op->setOperand (index, copy.getResult (0 ));
166+ auto replacement = promoteValue (builder, op->getLoc (), operand);
167+ op->setOperand (index, replacement);
81168}
82169
83170struct GPUPromoteMatmulOperandsPass final
0 commit comments