 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/InterleavedRange.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
 #define DEBUG_TYPE "iree-codegen-gpu-apply-padding-level"
 
@@ -61,20 +60,95 @@ getTiledOps(Operation *funcOp, IREE::GPU::TilingLevel tilingLevel) {
   return targets;
 }
 
+/// For a reduction dimension of a linalg operation `linalgOp`,
+///
+/// 1) `loopIndex` is the index of the loop in `linalgOp` that is reduced,
+/// 2) `operand` is the first operand indexed by the reduction dimension,
+/// 3) `operandDim` is the dimension of the operand that is reduced.
+///
+/// Example. Consider the generic below with 2 reduction dimensions.
+///
+/// ```mlir
+/// #map = affine_map<(d0, d1) -> (d1, d0)>
+/// #map1 = affine_map<(d0, d1) -> (d0, d1)>
+/// #map2 = affine_map<(d0, d1) -> ()>
+/// [...]
+/// %0 = linalg.generic {indexing_maps = [#map, #map1, #map2],
+///                      iterator_types = ["reduction", "reduction"]}
+///        ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>)
+///        outs(%arg2 : tensor<f16>)
+/// [...]
+/// ```
+///
+/// d0 has reduction info (loopIndex = 0, operand = %arg0, operandDim = 1),
+/// d1 has reduction info (loopIndex = 1, operand = %arg0, operandDim = 0).
+namespace {
+struct ReductionDimInfo {
+  unsigned loopIndex;
+  unsigned operandDim;
+  Value operand;
+};
+} // namespace
+
+static FailureOr<SmallVector<ReductionDimInfo>>
+getReductionInfo(linalg::LinalgOp linalgOp) {
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+
+  unsigned numReductionDims =
+      llvm::count(iteratorTypes, utils::IteratorType::reduction);
+  SmallVector<ReductionDimInfo> reductionDimInfos;
+  reductionDimInfos.reserve(numReductionDims);
+
+  for (unsigned loopIdx = 0, endIdx = iteratorTypes.size(); loopIdx < endIdx;
+       ++loopIdx) {
+    if (!linalg::isReductionIterator(iteratorTypes[loopIdx]))
+      continue;
+
+    AffineExpr loopIdxExpr = getAffineDimExpr(loopIdx, linalgOp.getContext());
+    for (auto [operandIndex, indexingMap] : llvm::enumerate(indexingMaps)) {
+      if (std::optional<unsigned> index =
+              indexingMap.getResultPosition(loopIdxExpr)) {
+        Value operand = linalgOp->getOperand(operandIndex);
+        reductionDimInfos.push_back({loopIdx, index.value(), operand});
+        break;
+      }
+    }
+  }
+
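+  // Fail if some reduction loop could not be matched to an operand dimension,
+  // i.e. it appears in no indexing map as a plain AffineDimExpr.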
+  if (reductionDimInfos.size() != numReductionDims)
+    return failure();
+
+  return reductionDimInfos;
+}
+
 static LogicalResult applyPaddingLevel(RewriterBase &rewriter,
                                        TilingInterface tilingInterfaceOp,
                                        IREE::GPU::TilingLevel tilingLevel) {
+  // 1.a. Get padding values. The default should be poison, instead of 0.
+  //
+  // TODO(newling): pad with poison. Requires
+  //
+  //   https://github.com/iree-org/iree/pull/21573
+  //   https://github.com/llvm/llvm-project/pull/152003
+  //   https://github.com/iree-org/iree/issues/21575
+  //
+  // Linalg operations can be padded with any value, because we rewrite the
+  // basic block to select the reduction identity for the yielded value
+  // whenever the index corresponds to the padded part of the tensor.
+  //
+  // Non-linalg operations require special handling.
+  // TODO: Extract the special handling into an upstream PaddingOpInterface.
 
-  // 1.a. Get padding values.
   SmallVector<Attribute> paddingValues;
   for (Value operand : tilingInterfaceOp.getOperation()->getOperands()) {
     paddingValues.push_back(
         rewriter.getZeroAttr(getElementTypeOrSelf(operand.getType())));
   }
 
-  // 1.b. Special adjustment for OnlineAttention mask padding that needs to be
-  // mindful of softmax and pad to -inf.
-  // TODO: Extract into an upstream PaddingOpInterface.
+  // 1.b. Special adjustment for OnlineAttention mask padding: softmax
+  // requires the mask to be padded with -inf.
   if (auto onlineAttentionOp = dyn_cast<IREE::LinalgExt::OnlineAttentionOp>(
           tilingInterfaceOp.getOperation())) {
     TypedValue<ShapedType> mask = onlineAttentionOp.getMask();
@@ -108,10 +182,11 @@ static LogicalResult applyPaddingLevel(RewriterBase &rewriter,
       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
 
   // 3. Set options.
-  auto options = linalg::PadTilingInterfaceOptions()
-                     .setPaddingSizes(padSizes)
-                     .setPaddingValues(paddingValues)
-                     .setPadToMultipleOf(true);
+  linalg::PadTilingInterfaceOptions options =
+      linalg::PadTilingInterfaceOptions()
+          .setPaddingSizes(padSizes)
+          .setPaddingValues(paddingValues)
+          .setPadToMultipleOf(true);
 
   LLVM_DEBUG(DBGS() << "Start padding " << *tilingInterfaceOp << "\n";
              DBGS() << "--with tile sizes: "
@@ -121,6 +196,19 @@ static LogicalResult applyPaddingLevel(RewriterBase &rewriter,
              DBGS() << "--with padToMultipleOf: " << options.padToMultipleOf
                     << "\n");
 
+  // For linalg ops, we will rewrite the basic block so that the padded parts
+  // of tensors are never read. This avoids having to infer the correct
+  // padding value for non-trivial basic blocks.
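+  //
+  // Illustrative sketch (a hypothetical f32 sum reduction): a combiner
+  //
+  //   %r = arith.addf %in, %acc : f32
+  //
+  // becomes, with %inb true iff every reduction index is within the
+  // unpadded bounds,
+  //
+  //   %id  = arith.constant 0.0 : f32  // neutral element of addf
+  //   %sel = arith.select %inb, %in, %id : f32
+  //   %r   = arith.addf %sel, %acc : f32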
+  FailureOr<SmallVector<ReductionDimInfo>> reductionDimInfo;
+  if (auto linalgOp =
+          dyn_cast<linalg::LinalgOp>(tilingInterfaceOp.getOperation())) {
+    reductionDimInfo = getReductionInfo(linalgOp);
+    if (failed(reductionDimInfo)) {
+      tilingInterfaceOp.emitWarning("failed to map reduction dimensions");
+      return failure();
+    }
+  }
+
   // 4. Pad.
   SmallVector<tensor::PadOp> padOps;
   FailureOr<TilingInterface> maybePaddedOp =
@@ -130,9 +218,81 @@ static LogicalResult applyPaddingLevel(RewriterBase &rewriter,
     return failure();
   }
 
-  // 5. For each PadOp, create a linalg::CopyOp to allow dim propagations.
   TilingInterface paddedOp = *maybePaddedOp;
-  for (auto padOp : padOps) {
+
+  if (auto paddedLinalgOp =
+          dyn_cast<linalg::LinalgOp>(paddedOp.getOperation())) {
+    Block *block = paddedLinalgOp.getBlock();
+
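+    // Collect the combiner operation(s) of each reduction, e.g. the
+    // arith.addf that folds an input element into the accumulator.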
+    SmallVector<Operation *> reductions;
+    for (auto [index, initOpOperand] :
+         llvm::enumerate(paddedLinalgOp.getDpsInitsMutable())) {
+      SmallVector<Operation *> combinerOps;
+      matchReduction(paddedLinalgOp.getRegionOutputArgs(), index, combinerOps);
+      reductions.insert(reductions.begin(), combinerOps.begin(),
+                        combinerOps.end());
+    }
+
+    for (Operation *reduction : reductions) {
+      std::optional<TypedAttr> reductionIdentity =
+          arith::getNeutralElement(reduction);
+      if (!reductionIdentity.has_value()) {
+        paddedOp.emitWarning("failed to get neutral element for reduction");
+        return failure();
+      }
+
+      // Get the sizes of the reduction dimensions before padding.
+      rewriter.setInsertionPoint(paddedOp.getOperation());
+      SmallVector<std::pair<unsigned, Value>> reductionDimSizes;
+      assert(succeeded(reductionDimInfo) &&
+             "reduction info was computed above for all linalg ops");
+      for (auto &&dimInfo : reductionDimInfo.value()) {
+        Value redDimSize = rewriter.create<tensor::DimOp>(
+            paddedOp.getLoc(), dimInfo.operand, dimInfo.operandDim);
+        reductionDimSizes.push_back({dimInfo.loopIndex, redDimSize});
+      }
+
+      // Add a check within the block to see if the current iteration over the
+      // loops is inside or outside the padded part of the iteration space.
+      rewriter.setInsertionPoint(reduction);
+      SmallVector<Value> conds;
+      for (auto &&[redDim, redDimSize] : reductionDimSizes) {
+        Value redDimIndex =
+            linalg::IndexOp::create(rewriter, paddedOp.getLoc(), redDim);
+        Value cond = arith::CmpIOp::create(rewriter, paddedOp.getLoc(),
+                                           arith::CmpIPredicate::ult,
+                                           redDimIndex, redDimSize);
+        conds.push_back(cond);
+      }
+      Value reductionIdentityValue = rewriter.create<arith::ConstantOp>(
+          paddedOp.getLoc(), reductionIdentity.value());
+      assert(!conds.empty());
+      Value cond = conds[0];
+      for (Value nxtCond : llvm::drop_begin(conds, 1)) {
+        cond = rewriter.create<arith::AndIOp>(paddedOp.getLoc(), cond, nxtCond);
+      }
+
+      // Find the reduction op operand that is reduced with the carried output.
+      if (reduction->getNumOperands() != 2) {
+        paddedOp.emitWarning("expected a reduction operation with 2 operands");
+        return failure();
+      }
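+      // Note: this assumes the carried accumulator is the last block
+      // argument (the output's region argument); the other combiner operand
+      // is then the element that may lie in the padded region.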
+      Value carry = block->getArguments().back();
+      unsigned uncarryIndex = reduction->getOperand(0) == carry ? 1 : 0;
+      Value uncarried = reduction->getOperand(uncarryIndex);
+
+      // Select the reduction identity value if in the padding region.
+      Value selected = arith::SelectOp::create(
+          rewriter, paddedOp.getLoc(), cond, uncarried, reductionIdentityValue);
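+      // Re-create the combiner with the selected value in place of the
+      // original input operand, and replace the old combiner with the clone.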
+      IRMapping mapping;
+      mapping.map(reduction->getOperand(uncarryIndex), selected);
+      Operation *redClone = rewriter.clone(*reduction, mapping);
+      rewriter.replaceOp(reduction, redClone);
+    }
+  }
+
+  // 5. For each PadOp, create a linalg::CopyOp to allow dim propagations.
+  for (tensor::PadOp padOp : padOps) {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(padOp);
 