Commit 116343e

newling authored and keshavvinayak01 committed
[Codegen] Select for pad value just before yielding (iree-org#21581)
BEFORE: all ops are padded with zero in `gpu-apply-padding` unless special-cased, as for online attention.

AFTER: a general solution for linalg operations where padding with 0 is not appropriate. Uses arith.select to ensure the correct reduction identity is used. See the PR description for more information.

---------

Signed-off-by: James Newling <[email protected]>
Signed-off-by: keshavvinayak01 <[email protected]>
1 parent 8fc362d commit 116343e
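The gist of the change: rather than inferring a safe padding value per operand, the pass pads with an arbitrary value (currently 0) and rewrites the linalg op's body so that padded elements are replaced by the combiner's neutral element before they reach the reduction. A minimal sketch of the resulting IR, assuming a 1-d f16 max-reduction padded from `%sz` up to 64 elements (hypothetical names and shapes, not taken from the PR's tests):

```mlir
// Hypothetical IR sketch. %pad is the padded input; %sz is the
// pre-padding size of the reduced dimension (a tensor.dim on the
// original operand).
%r = linalg.generic {
    indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> ()>],
    iterator_types = ["reduction"]}
    ins(%pad : tensor<64xf16>) outs(%init : tensor<f16>) {
^bb0(%in: f16, %acc: f16):
  // Before the rewrite, the body was just: %m = arith.maximumf %in, %acc.
  %idx   = linalg.index 0 : index
  %valid = arith.cmpi ult, %idx, %sz : index
  %ninf  = arith.constant 0xFC00 : f16  // -inf: the maximumf identity
  %sel   = arith.select %valid, %in, %ninf : f16
  %m     = arith.maximumf %sel, %acc : f16
  linalg.yield %m : f16
} -> tensor<f16>
```

The neutral element comes from arith::getNeutralElement, so the rewrite bails out with a warning on combiners it cannot classify.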

File tree

6 files changed: +455 -16 lines changed


compiler/src/iree/compiler/Codegen/Common/GPU/GPUApplyPaddingLevel.cpp

Lines changed: 173 additions & 13 deletions
```diff
@@ -7,24 +7,23 @@
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenInterfaces.h"
-#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.h"
 #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtInterfaces.h"
 #include "iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.h"
 #include "llvm/ADT/DenseSet.h"
 #include "llvm/Support/InterleavedRange.h"
+#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h"
 #include "mlir/Dialect/Linalg/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
-#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
 #include "mlir/Dialect/Utils/StaticValueUtils.h"
+#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
 #include "mlir/IR/Dominance.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "mlir/IR/Value.h"
 #include "mlir/Interfaces/FunctionInterfaces.h"
-#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

 #define DEBUG_TYPE "iree-codegen-gpu-apply-padding-level"

@@ -61,20 +60,95 @@ getTiledOps(Operation *funcOp, IREE::GPU::TilingLevel tilingLevel) {
   return targets;
 }

+/// For a reduction dimension of a linalg operation `linalgOp`,
+///
+/// 1) `loopIndex` is the index of the loop in `linalgOp` that is reduced,
+/// 2) `operand` is the first operand indexed by the reduction dimension,
+/// 3) `operandDim` is the dimension of the operand that's reduced.
+///
+/// Example. Consider the generic below with 2 reduction dimensions.
+///
+/// ```mlir
+/// #map = affine_map<(d0, d1) -> (d1, d0)>
+/// #map1 = affine_map<(d0, d1) -> (d0, d1)>
+/// #map2 = affine_map<(d0, d1) -> ()>
+/// [...]
+/// %0 = linalg.generic {indexing_maps = [#map, #map1, #map2],
+///                      iterator_types = ["reduction", "reduction"]}
+///      ins(%arg0, %arg1 : tensor<?x?xf16>, tensor<?x?xf16>)
+///      outs(%arg2 : tensor<f16>)
+/// [...]
+/// ```
+///
+/// d0 has reduction info (loopIndex = 0, operand = %arg0, operandDim = 1),
+/// d1 has reduction info (loopIndex = 1, operand = %arg0, operandDim = 0).
+namespace {
+struct ReductionDimInfo {
+public:
+  unsigned loopIndex;
+  unsigned operandDim;
+  Value operand;
+};
+} // namespace
+
+static FailureOr<SmallVector<ReductionDimInfo>>
+getReductionInfo(linalg::LinalgOp linalgOp) {
+  SmallVector<AffineMap> indexingMaps = linalgOp.getIndexingMapsArray();
+  SmallVector<utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+
+  unsigned numReductionDims =
+      llvm::count(iteratorTypes, utils::IteratorType::reduction);
+  SmallVector<ReductionDimInfo> reductionDimInfos;
+  reductionDimInfos.reserve(numReductionDims);
+
+  for (unsigned loopIdx = 0, endIdx = iteratorTypes.size(); loopIdx < endIdx;
+       ++loopIdx) {
+    if (!linalg::isReductionIterator(iteratorTypes[loopIdx]))
+      continue;
+
+    AffineExpr loopIdxExpr = getAffineDimExpr(loopIdx, linalgOp.getContext());
+    for (auto [operandIndex, indexingMap] : llvm::enumerate(indexingMaps)) {
+      if (std::optional<unsigned> index =
+              indexingMap.getResultPosition(loopIdxExpr)) {
+        Value operand = linalgOp->getOperand(operandIndex);
+        reductionDimInfos.push_back({loopIdx, index.value(), operand});
+        break;
+      }
+    }
+  }
+
+  if (reductionDimInfos.size() != numReductionDims)
+    return failure();
+
+  return reductionDimInfos;
+}
+
 static LogicalResult applyPaddingLevel(RewriterBase &rewriter,
                                        TilingInterface tilingInterfaceOp,
                                        IREE::GPU::TilingLevel tilingLevel) {
+  // 1.a. Get padding values. The default should be poison, instead of 0.
+  //
+  // TODO(newling) pad with poison. Requires
+  //
+  // https://github.com/iree-org/iree/pull/21573
+  // https://github.com/llvm/llvm-project/pull/152003
+  // https://github.com/iree-org/iree/issues/21575
+  //
+  // The linalg operations can be padded with any value because we rewrite the
+  // basic block to select the reduction identity for the yielded value if the
+  // index corresponds to the padded part of the tensor.
+  //
+  // Non-linalg operations require special handling.
+  // TODO: Extract the special handling into an upstream PaddingOpInterface.

-  // 1.a. Get padding values.
   SmallVector<Attribute> paddingValues;
   for (Value operand : tilingInterfaceOp.getOperation()->getOperands()) {
     paddingValues.push_back(
         rewriter.getZeroAttr(getElementTypeOrSelf(operand.getType())));
   }

-  // 1.b. Special adjustment for OnlineAttention mask padding that needs to be
-  // mindful of softmax and pad to -inf.
-  // TODO: Extract into an upstream PaddingOpInterface.
+  // 1.b. Special adjustment for OnlineAttention mask padding.
   if (auto onlineAttentionOp = dyn_cast<IREE::LinalgExt::OnlineAttentionOp>(
           tilingInterfaceOp.getOperation())) {
     TypedValue<ShapedType> mask = onlineAttentionOp.getMask();
@@ -108,10 +182,11 @@ static LogicalResult applyPaddingLevel(RewriterBase &rewriter,
       getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);

   // 3. Set options.
-  auto options = linalg::PadTilingInterfaceOptions()
-                     .setPaddingSizes(padSizes)
-                     .setPaddingValues(paddingValues)
-                     .setPadToMultipleOf(true);
+  linalg::PadTilingInterfaceOptions options =
+      linalg::PadTilingInterfaceOptions()
+          .setPaddingSizes(padSizes)
+          .setPaddingValues(paddingValues)
+          .setPadToMultipleOf(true);

   LLVM_DEBUG(DBGS() << "Start padding " << *tilingInterfaceOp << "\n";
              DBGS() << "--with tile sizes: "
@@ -121,6 +196,19 @@ static LogicalResult applyPaddingLevel(RewriterBase &rewriter,
              DBGS() << "--with padToMultipleOf: " << options.padToMultipleOf
                     << "\n");

+  // For linalg ops, we will rewrite the basic block in a way that means padded
+  // parts of tensors are never read. This is useful to avoid inferring what
+  // padding values should be for non-trivial basic blocks.
+  FailureOr<SmallVector<ReductionDimInfo>> reductionDimInfo;
+  if (auto linalgOp =
+          dyn_cast<linalg::LinalgOp>(tilingInterfaceOp.getOperation())) {
+    reductionDimInfo = getReductionInfo(linalgOp);
+    if (failed(reductionDimInfo)) {
+      tilingInterfaceOp.emitWarning("failed to map reduction dimensions");
+      return failure();
+    }
+  }
+
   // 4. Pad.
   SmallVector<tensor::PadOp> padOps;
   FailureOr<TilingInterface> maybePaddedOp =
@@ -130,9 +218,81 @@ static LogicalResult applyPaddingLevel(RewriterBase &rewriter,
     return failure();
   }

-  // 5. For each PadOp, create a linalg::CopyOp to allow dim propagations.
   TilingInterface paddedOp = *maybePaddedOp;
-  for (auto padOp : padOps) {
+
+  if (auto paddedLinalgOp =
+          dyn_cast<linalg::LinalgOp>(paddedOp.getOperation())) {
+    Block *block = paddedLinalgOp.getBlock();
+
+    SmallVector<Operation *> reductions;
+    for (auto [index, initOpOperand] :
+         llvm::enumerate(paddedLinalgOp.getDpsInitsMutable())) {
+      SmallVector<Operation *> combinerOps;
+      matchReduction(paddedLinalgOp.getRegionOutputArgs(), index, combinerOps);
+      reductions.insert(reductions.begin(), combinerOps.begin(),
+                        combinerOps.end());
+    }
+
+    for (Operation *reduction : reductions) {
+      std::optional<TypedAttr> reductionIdentity =
+          arith::getNeutralElement(reduction);
+      if (!reductionIdentity.has_value()) {
+        paddedOp.emitWarning("failed to get neutral element for reduction");
+        return failure();
+      }
+
+      // Get the sizes of the reduction dimensions before padding:
+      rewriter.setInsertionPoint(paddedOp.getOperation());
+      SmallVector<std::pair<unsigned, Value>> reductionDimSizes;
+      assert(succeeded(reductionDimInfo) &&
+             "obtained with confirmation earlier");
+      for (auto &&dimInfo : reductionDimInfo.value()) {
+        Value redDimSize = rewriter.create<tensor::DimOp>(
+            paddedOp.getLoc(), dimInfo.operand, dimInfo.operandDim);
+        reductionDimSizes.push_back({dimInfo.loopIndex, redDimSize});
+      }
+
+      // Add a check within the block to see if the current iteration over the
+      // loops is inside or outside the padded part of the iteration space.
+      rewriter.setInsertionPoint(reduction);
+      SmallVector<Value> conds;
+      for (auto &&[redDim, redDimSize] : reductionDimSizes) {
+        Value redDimIndex =
+            linalg::IndexOp::create(rewriter, paddedOp.getLoc(), redDim);
+        Value cond = arith::CmpIOp::create(rewriter, paddedOp.getLoc(),
+                                           arith::CmpIPredicate::ult,
+                                           redDimIndex, redDimSize);
+        conds.push_back(cond);
+      }
+      Value reductionIdentityValue = rewriter.create<arith::ConstantOp>(
+          paddedOp.getLoc(), reductionIdentity.value());
+      assert(conds.size() > 0);
+      Value cond = conds[0];
+      for (Value nxtCond : llvm::drop_begin(conds, 1)) {
+        cond = rewriter.create<arith::AndIOp>(paddedOp.getLoc(), cond, nxtCond);
+      }
+
+      // Find the reduction op operand that is reduced with the carried output.
+      if (reduction->getNumOperands() != 2) {
+        paddedOp.emitWarning("expected a reduction operation with 2 operands");
+        return failure();
+      }
+      Value carry = block->getArguments().back();
+      unsigned uncarryIndex = reduction->getOperand(0) == carry ? 1 : 0;
+      Value uncarried = reduction->getOperand(uncarryIndex);
+
+      // Select the reduction identity value if in the padding region.
+      Value selected = arith::SelectOp::create(
+          rewriter, paddedOp.getLoc(), cond, uncarried, reductionIdentityValue);
+      IRMapping mapping;
+      mapping.map(reduction->getOperand(uncarryIndex), selected);
+      Operation *redClone = rewriter.clone(*reduction, mapping);
+      rewriter.replaceOp(reduction, redClone);
+    }
+  }
+
+  // 5. For each PadOp, create a linalg::CopyOp to allow dim propagations.
+  for (tensor::PadOp padOp : padOps) {
     OpBuilder::InsertionGuard g(rewriter);
     rewriter.setInsertionPointAfter(padOp);

```
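When an op has more than one reduction dimension, the loop over `reductionDimSizes` above emits one `ult` bound check per reduction loop and folds them together with `arith.andi` before the single `arith.select`. A hypothetical body sketch for a 2-d sum-reduction (`%in`/`%acc` are the block arguments; `%sz0`/`%sz1` are `tensor.dim` values taken from the operand before padding; all names assumed):

```mlir
%i0  = linalg.index 0 : index
%i1  = linalg.index 1 : index
%c0  = arith.cmpi ult, %i0, %sz0 : index
%c1  = arith.cmpi ult, %i1, %sz1 : index
%ok  = arith.andi %c0, %c1 : i1   // inside the unpadded region on both loops?
%id  = arith.constant 0.0 : f16   // the addf neutral element
%sel = arith.select %ok, %in, %id : f16
%sum = arith.addf %sel, %acc : f16
linalg.yield %sum : f16
```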
compiler/src/iree/compiler/Codegen/Common/GPU/test/BUILD.bazel

Lines changed: 2 additions & 1 deletion
```diff
@@ -21,7 +21,8 @@ iree_lit_test_suite(
             "decompose_horizontally_fused_gemms.mlir",
             "gpu_alloc_private_memory_for_dps_ops.mlir",
             "gpu_apply_derived_thread_config.mlir",
-            "gpu_apply_padding.mlir",
+            "gpu_apply_padding_online_attention.mlir",
+            "gpu_apply_padding_partial_reduction.mlir",
             "gpu_apply_tiling_level.mlir",
             "gpu_bubble_resource_casts.mlir",
             "gpu_check_resource_usage.mlir",
```

compiler/src/iree/compiler/Codegen/Common/GPU/test/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
```diff
@@ -17,7 +17,8 @@ iree_lit_test_suite(
     "decompose_horizontally_fused_gemms.mlir"
     "gpu_alloc_private_memory_for_dps_ops.mlir"
     "gpu_apply_derived_thread_config.mlir"
-    "gpu_apply_padding.mlir"
+    "gpu_apply_padding_online_attention.mlir"
+    "gpu_apply_padding_partial_reduction.mlir"
    "gpu_apply_tiling_level.mlir"
     "gpu_bubble_resource_casts.mlir"
     "gpu_check_resource_usage.mlir"
```
