Commit 3167930

[AMD] Improved CanonicalizePointers for ExtractSlice

1 parent 1b2a86b

7 files changed: +171 -61 lines

third_party/amd/backend/compiler.py

Lines changed: 1 addition & 0 deletions
```diff
@@ -263,6 +263,7 @@ def make_ttgir(mod, metadata, options):
     use_block_pingpong = is_pingpong_schedule_enabled(options.arch)
     if use_block_pingpong and options.num_stages in [2, 4]:
         amd.passes.ttgpuir.add_block_pingpong(pm, options.num_stages, use_async_copy)
+        passes.ttgpuir.add_remove_layout_conversions(pm)
 
     if knobs.amd.use_buffer_ops:
         amd.passes.ttgpuir.add_canonicalize_pointers(pm)
```

third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp

Lines changed: 47 additions & 49 deletions
```diff
@@ -1,3 +1,4 @@
+#include "../TritonAMDGPUToLLVM/Utility.h"
 #include "Dialect/TritonAMDGPU/IR/Dialect.h"
 #include "TritonAMDGPUToLLVM/GCNAsmFormat.h"
 #include "mlir/Conversion/LLVMCommon/Pattern.h"
@@ -49,6 +50,7 @@ using namespace mlir::triton;
 // clang-format on
 
 namespace {
+
 struct ExtractSliceOpConversion
     : public ConvertOpToLLVMPattern<amdgpu::ExtractSliceOp> {
   explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter,
@@ -60,61 +62,61 @@ struct ExtractSliceOpConversion
                              ConversionPatternRewriter &rewriter) const {
     Location loc = op->getLoc();
     auto srcTy = cast<RankedTensorType>(op.getSource().getType());
-    auto srcLayout = srcTy.getEncoding();
+    auto dstTy = cast<RankedTensorType>(op.getType());
     auto srcShape = srcTy.getShape();
-    auto resultTy = cast<RankedTensorType>(op.getType());
-    auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
-    auto elemsPerThread = triton::gpu::getElemsPerThread(srcTy);
-    auto contigPerThread = triton::gpu::getContigPerThread(srcTy);
-    auto totalContigPerThread = product<unsigned>(contigPerThread);
-    auto order = triton::gpu::getOrder(srcTy);
+    auto dstShape = dstTy.getShape();
 
-    // Calculate valid total number of workers in each dimension
+    auto vals = unpackLLElements(loc, adaptor.getSource(), rewriter);
     auto shapePerCTATile = triton::gpu::getShapePerCTATile(srcTy);
-    shapePerCTATile[0] =
-        std::min(static_cast<unsigned>(srcShape[0]), shapePerCTATile[0]);
-    shapePerCTATile[1] =
-        std::min(static_cast<unsigned>(srcShape[1]), shapePerCTATile[1]);
-
-    // Rank == 2 checked in the verifier
-    SmallVector<int64_t, 2> sizes;
-    for (auto i = 0; i < 2; ++i) {
-      sizes.push_back(resultTy.getDimSize(i));
-    }
+    auto srcCTAShape = LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
+        srcShape, shapePerCTATile, std::divides<unsigned>());
+    auto dstCTAShape = LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
+        dstShape, shapePerCTATile, std::divides<unsigned>());
 
+    auto numCTATiles = std::accumulate(dstCTAShape.begin(), dstCTAShape.end(),
+                                       1, std::multiplies<>());
     auto offsets = op.getStaticOffsets();
+    auto firstTileCoordinate =
+        LLVM::AMD::multiDimElementwise<int64_t, unsigned>(
+            offsets, shapePerCTATile, std::divides<unsigned>());
 
-    // Calculate offsets and sizes in terms of CTA units.
-    std::array<int64_t, 2> CTAOffsets{offsets[0] / shapePerCTATile[0],
-                                      offsets[1] / shapePerCTATile[1]};
-    std::array<int64_t, 2> CTASizes{sizes[0] / shapePerCTATile[0],
-                                    sizes[1] / shapePerCTATile[1]};
-    std::array<int64_t, 2> CTAPerShape{srcShape[0] / shapePerCTATile[0],
-                                       srcShape[1] / shapePerCTATile[1]};
-
-    // The diagram above illustrates the graphical representation of the
-    // skipElems, tensorStride, and lastIdx variables.
-    auto skipElems = CTAOffsets[order[1]] * (elemsPerThread[order[0]] *
-                                             contigPerThread[order[1]]) +
-                     CTAOffsets[order[0]] * totalContigPerThread;
-    auto tensorStride =
-        (CTAPerShape[order[0]] - CTASizes[order[0]]) * totalContigPerThread;
-    auto lastIdx =
-        (CTAOffsets[order[1]] + CTASizes[order[1]] - 1) *
-            elemsPerThread[order[0]] * contigPerThread[order[1]] +
-        (CTAOffsets[order[0]] + CTASizes[order[0]]) * totalContigPerThread;
-
-    assert(lastIdx <= vals.size());
+    Attribute srcEncoding = srcTy.getEncoding();
+    Attribute dstEncoding = dstTy.getEncoding();
+    auto linearLayoutSrc = triton::gpu::toLinearLayout(srcShape, srcEncoding);
+    auto linearLayoutDst = triton::gpu::toLinearLayout(dstShape, dstEncoding);
 
+    auto srcCTAOrder =
+        LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutSrc);
+    auto dstCTAOrder =
+        LLVM::AMD::getCTATileOrder(srcTy.getContext(), linearLayoutDst);
+
+    unsigned elemsPerThreadPerCTA =
+        triton::gpu::getTotalElemsPerThread(srcTy) /
+        std::accumulate(srcCTAShape.begin(), srcCTAShape.end(), 1,
+                        std::multiplies<>());
+
+    // 1. Process CTA tiles in the destination tensor according to the
+    //    destination's linear layout order of CTA tiles.
+    // 2. For each tile position in the destination tensor, compute its
+    //    corresponding position in the source tensor.
+    // 3. Copy the values from the source tile to the destination slice.
     SmallVector<Value> resultVals;
-    for (int i = skipElems; i < lastIdx; i += tensorStride) {
-      for (int j = 0; j < totalContigPerThread * CTASizes[order[0]]; ++j, ++i) {
-        assert(i < lastIdx);
-        resultVals.push_back(vals[i]);
+    for (size_t i = 0; i < numCTATiles; i++) {
+      auto coordInDstTensor =
+          mlir::LLVM::delinearize(i, dstCTAShape, dstCTAOrder);
+      auto coordInSrcTensor =
+          LLVM::AMD::multiDimElementwise<unsigned, unsigned>(
+              coordInDstTensor, firstTileCoordinate, std::plus<unsigned>());
+      auto linearIdxInSrcTensor =
+          mlir::LLVM::linearize(coordInSrcTensor, srcCTAShape, srcCTAOrder);
+
+      for (size_t j = 0; j < elemsPerThreadPerCTA; j++) {
+        resultVals.push_back(
+            vals[linearIdxInSrcTensor * elemsPerThreadPerCTA + j]);
       }
     }
     Value ret = packLLElements(loc, this->getTypeConverter(), resultVals,
-                               rewriter, resultTy);
+                               rewriter, dstTy);
 
     rewriter.replaceOp(op, ret);
     return success();
@@ -124,11 +126,7 @@ struct ExtractSliceOpConversion
   matchAndRewrite(amdgpu::ExtractSliceOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
     auto srcTy = op.getSource().getType();
-    if (isa<BlockedEncodingAttr, AMDMfmaEncodingAttr>(
-            op.getSource().getType().getEncoding())) {
-      return processLayout(op, adaptor, rewriter);
-    }
-    return failure();
+    return processLayout(op, adaptor, rewriter);
   }
 };
 } // namespace
```
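
The new lowering drops the hand-derived skipElems/tensorStride arithmetic (and the blocked/MFMA special case): it walks the destination's CTA tiles in their linear-layout order, shifts each tile coordinate by the slice offset expressed in tile units, and copies the per-thread register values of the matching source tile. The sketch below reproduces only that index arithmetic as a standalone program; the delinearize/linearize helpers are simplified stand-ins for mlir::LLVM::delinearize/linearize, and the shapes, tile sizes, and offsets are made-up example values, not taken from the commit.

```cpp
// Standalone sketch of the per-CTA-tile remapping done in processLayout.
#include <iostream>
#include <vector>

// Decompose a linear index into multi-dim coordinates, walking dims in
// `order` from most-minor to most-major (simplified stand-in).
std::vector<unsigned> delinearize(unsigned linear,
                                  const std::vector<unsigned> &shape,
                                  const std::vector<unsigned> &order) {
  std::vector<unsigned> coord(shape.size(), 0);
  for (unsigned dim : order) {
    coord[dim] = linear % shape[dim];
    linear /= shape[dim];
  }
  return coord;
}

// Inverse of delinearize for the same `order`.
unsigned linearize(const std::vector<unsigned> &coord,
                   const std::vector<unsigned> &shape,
                   const std::vector<unsigned> &order) {
  unsigned linear = 0;
  for (auto it = order.rbegin(); it != order.rend(); ++it)
    linear = linear * shape[*it] + coord[*it];
  return linear;
}

int main() {
  // Hypothetical example: 128x128 source, 64x128 slice starting at row 64,
  // 32x32 CTA tiles, 16 register values per thread per tile.
  std::vector<unsigned> srcCTAShape = {4, 4}; // 128/32 x 128/32
  std::vector<unsigned> dstCTAShape = {2, 4}; //  64/32 x 128/32
  std::vector<unsigned> firstTile = {2, 0};   // offset {64, 0} / tile {32, 32}
  std::vector<unsigned> order = {1, 0};       // dim 1 is most minor
  unsigned elemsPerThreadPerCTA = 16;

  unsigned numCTATiles = dstCTAShape[0] * dstCTAShape[1];
  std::vector<unsigned> resultIdx; // source register indices, in dst order
  for (unsigned i = 0; i < numCTATiles; ++i) {
    auto dstCoord = delinearize(i, dstCTAShape, order);
    std::vector<unsigned> srcCoord = {dstCoord[0] + firstTile[0],
                                      dstCoord[1] + firstTile[1]};
    unsigned srcTile = linearize(srcCoord, srcCTAShape, order);
    for (unsigned j = 0; j < elemsPerThreadPerCTA; ++j)
      resultIdx.push_back(srcTile * elemsPerThreadPerCTA + j);
  }
  std::cout << "copied " << resultIdx.size() << " register values\n";
}
```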

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.cpp

Lines changed: 39 additions & 0 deletions
```diff
@@ -755,4 +755,43 @@ void addLocalLoadNoAliasScope(AliasAnalysisOpInterface llLoadOp) {
   llLoadOp.setAliasScopes(aliasScopes);
 }
 
+SmallVector<unsigned> getCTATileOrder(MLIRContext *ctx,
+                                      const LinearLayout &layout) {
+  auto llEnc = triton::gpu::LinearEncodingAttr::get(ctx, layout);
+  auto regDim = StringAttr::get(ctx, "register");
+  auto &bases = layout.getBases().find(regDim)->second;
+
+  // Compute number of CTA tiles in a layout.
+  unsigned totalElems = layout.getTotalOutDimSize();
+  auto ctaShape = llEnc.getShapePerCTATile();
+  unsigned elemsPerCTA =
+      std::accumulate(ctaShape.begin(), ctaShape.end(), 1, std::multiplies<>());
+  assert((totalElems % elemsPerCTA) == 0 &&
+         "Total elements must be divisible by elemsPerCTA");
+  unsigned numCTAs = totalElems / elemsPerCTA;
+
+  // To determine the CTA tile order, start by identifying the register basis
+  // vector that corresponds to the first element of the second CTA tile. The
+  // nonzero index in the logical tensor it maps to indicates the most minor
+  // dimension. Then, for each subsequent basis register (first element of
+  // some CTA tile), extract the next nonzero index to build the full dimension
+  // order.
+  unsigned totalPerThread =
+      product(llEnc.basesPerDim(regDim, /*skipBroadcast=*/false)) / numCTAs;
+  unsigned startIndex = static_cast<unsigned>(std::log2(totalPerThread));
+
+  llvm::SmallSetVector<unsigned, 8> order;
+  for (unsigned i = startIndex; i < bases.size(); ++i) {
+    auto it = std::find_if(bases[i].begin(), bases[i].end(),
+                           [](unsigned v) { return v != 0; });
+    if (it != bases[i].end())
+      order.insert(std::distance(bases[i].begin(), it));
+  }
+
+  // Append any dims missing from our default order.
+  for (unsigned dim : llEnc.getOrder())
+    order.insert(dim);
+
+  return SmallVector<unsigned>(order.begin(), order.end());
+}
 } // namespace mlir::LLVM::AMD
```
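
getCTATileOrder reads the "register" basis vectors of the linear layout: the bases below log2(registers per tile) enumerate elements inside a single CTA tile, and each later basis is the first element of some other tile, so the dimension of its first nonzero coordinate names the next most-minor tile dimension. The standalone sketch below illustrates that derivation on hypothetical basis values, with plain std::vector standing in for the LinearLayout API; it is an analogy, not the real implementation above.

```cpp
#include <algorithm>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
  // Each basis maps one register bit to a coordinate in the logical tensor.
  // Made-up example: 8 registers per CTA tile, tiles repeated along dim 1
  // first (fastest), then dim 0.
  std::vector<std::vector<unsigned>> bases = {
      {0, 1}, {0, 2}, {1, 0}, // bits 0..2: elements within one CTA tile
      {0, 64},                // bit 3: first element of the 2nd tile (dim 1)
      {64, 0},                // bit 4: first tile along dim 0
  };
  unsigned elemsPerThreadPerTile = 8; // 2^3 registers cover one tile

  // Skip the in-tile bits; the remaining bases step from tile to tile.
  unsigned startIndex =
      static_cast<unsigned>(std::log2(elemsPerThreadPerTile));

  std::vector<unsigned> order;
  for (unsigned i = startIndex; i < bases.size(); ++i) {
    auto it = std::find_if(bases[i].begin(), bases[i].end(),
                           [](unsigned v) { return v != 0; });
    if (it == bases[i].end())
      continue;
    unsigned dim = std::distance(bases[i].begin(), it);
    if (std::find(order.begin(), order.end(), dim) == order.end())
      order.push_back(dim); // first nonzero coordinate names the dimension
  }
  // Append any dimension not touched by a tile-stepping basis (default order).
  for (unsigned dim : {1u, 0u})
    if (std::find(order.begin(), order.end(), dim) == order.end())
      order.push_back(dim);

  for (unsigned d : order)
    std::cout << d << " "; // prints "1 0": dim 1 is most minor for tiles
  std::cout << "\n";
}
```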

third_party/amd/lib/TritonAMDGPUToLLVM/Utility.h

Lines changed: 17 additions & 0 deletions
```diff
@@ -137,6 +137,23 @@ void addLocalLoadNoAliasScope(AliasAnalysisOpInterface llLoadOp);
 // Attaches the "AsyncCopies" alias scope to llLoadDirectToLdsOp
 void addAsyncCopyAliasScope(AliasAnalysisOpInterface llLoadDirectToLdsOp);
 
+// Determine the order in which CTA tiles are laid out across the tensor.
+SmallVector<unsigned> getCTATileOrder(MLIRContext *ctx,
+                                      const LinearLayout &layout);
+
+template <typename T, typename U, typename BinaryOp>
+std::vector<unsigned> multiDimElementwise(const ArrayRef<T> &lhs,
+                                          const ArrayRef<U> &rhs, BinaryOp op) {
+  assert(lhs.size() == rhs.size() && "Input dimensions must match");
+  std::vector<unsigned> result;
+  result.reserve(lhs.size());
+  for (size_t i = 0, n = lhs.size(); i < n; ++i) {
+    unsigned a = static_cast<unsigned>(lhs[i]);
+    unsigned b = static_cast<unsigned>(rhs[i]);
+    result.push_back(op(a, b));
+  }
+  return result;
+}
 } // namespace mlir::LLVM::AMD
 
 #endif // TRITON_THIRD_PARTY_AMD_LIB_TRITONAMDGPUTOLLVM_UTILITY_H_
```
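
multiDimElementwise simply zips two equal-length shape/coordinate vectors through a binary op and returns the result as unsigned values; the diffs above use it to divide tensor shapes by shapePerCTATile and to add tile coordinates. Below is a minimal sketch of the same semantics, using std::vector in place of llvm::ArrayRef so it compiles standalone, with made-up shapes.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

// Same contract as the header's multiDimElementwise, minus the ArrayRef types.
template <typename T, typename U, typename BinaryOp>
std::vector<unsigned> multiDimElementwise(const std::vector<T> &lhs,
                                          const std::vector<U> &rhs,
                                          BinaryOp op) {
  assert(lhs.size() == rhs.size() && "Input dimensions must match");
  std::vector<unsigned> result;
  result.reserve(lhs.size());
  for (size_t i = 0; i < lhs.size(); ++i)
    result.push_back(
        op(static_cast<unsigned>(lhs[i]), static_cast<unsigned>(rhs[i])));
  return result;
}

int main() {
  std::vector<int64_t> srcShape = {128, 128};
  std::vector<unsigned> shapePerCTATile = {32, 32};

  // Tensor shape in units of CTA tiles: {4, 4}.
  auto ctaShape =
      multiDimElementwise(srcShape, shapePerCTATile, std::divides<unsigned>());
  // Shift a tile coordinate by the slice's first-tile coordinate: {3, 1}.
  auto shifted = multiDimElementwise(std::vector<unsigned>{1u, 1u},
                                     std::vector<unsigned>{2u, 0u},
                                     std::plus<unsigned>());
  std::cout << ctaShape[0] << "x" << ctaShape[1] << ", " << shifted[0] << ","
            << shifted[1] << "\n"; // 4x4, 3,1
}
```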

third_party/amd/lib/TritonAMDGPUTransforms/BlockPingpong.cpp

Lines changed: 18 additions & 4 deletions
```diff
@@ -1518,6 +1518,12 @@ LogicalResult Pingponger::transformFP4(OpBuilder &builder, Location loc) {
   builder.setInsertionPointAfter(dotSOps[0]);
   if (sliceDotScaled(builder, loc, dotSOps[0], 4).failed())
     return failure();
+
+  if (genAsyncCopySlices(builder).failed()) {
+    LDBG("failed to slice global-to-local async copies");
+    return failure();
+  }
+
   updateOpInsertion(dotSliceOps[0]);
 
   appendOp(builder.create<ROCDL::SchedBarrier>(loc, 0));
@@ -1681,10 +1687,6 @@ void Pingponger::getDotPingponged() {
       return;
     }
 
-    if (llvm::failed(genAsyncCopySlices(builder))) {
-      LDBG("failed to slice global-to-local async copies");
-    }
-
     auto updateSignature = updateForOpSignature(builder);
     if (llvm::failed(updateSignature)) {
       LDBG("failed to update forOp signature");
@@ -1695,6 +1697,18 @@ void Pingponger::getDotPingponged() {
        LDBG("failed to update forOp signature");
      }
    }
+
+    forOp->walk([](ttg::AsyncCommitGroupOp groupOp) {
+      auto users = groupOp.getResult().getUsers();
+      if (users.empty()) {
+        SmallVector<Operation *> toDeleteVec;
+        for (auto token : groupOp.getInputTokens()) {
+          toDeleteVec.push_back(token.getDefiningOp());
+        }
+        groupOp->erase();
+        llvm::for_each(toDeleteVec, [](Operation *op) { op->erase(); });
+      }
+    });
     addAsymmetricSyncToLoop(builder, loc);
     return;
   }
```

third_party/amd/lib/TritonAMDGPUTransforms/CanonicalizePointers.cpp

Lines changed: 45 additions & 8 deletions
```diff
@@ -1176,20 +1176,57 @@ class ConvertExtractSliceOp
     }
 
     Location loc = extractSliceOp->getLoc();
-
+    RankedTensorType resultType = extractSliceOp.getResult().getType();
     const FatPointers::FatPtrAttrs &fatPtrAttrs =
         fatPtrs.at({fatPtrBase, fatPtrOffset});
-    auto newSrc = createTensorPointer(rewriter, fatPtrBase, fatPtrOffset, loc,
-                                      fatPtrAttrs);
 
-    RankedTensorType resType = extractSliceOp.getResult().getType();
-    tt::amdgpu::ExtractSliceOp newExtractSliceOp =
+    Value newFatPtrOffset = nullptr;
+    auto origFatOffsetType = dyn_cast<RankedTensorType>(fatPtrOffset.getType());
+    auto slicedFatOffsetType = RankedTensorType::get(
+        resultType.getShape(), origFatOffsetType.getElementType(),
+        origFatOffsetType.getEncoding());
+
+    tt::amdgpu::ExtractSliceOp slicedFatPtrOffset =
         rewriter.create<tt::amdgpu::ExtractSliceOp>(
-            loc, Type{resType}, Value{newSrc},
+            loc, Type{slicedFatOffsetType}, Value{fatPtrOffset},
             extractSliceOp.getStaticOffsetsAttr());
-    rewriter.replaceOp(extractSliceOp, newExtractSliceOp);
-    fatPtrs[{fatPtrBase, newExtractSliceOp}] =
+
+    auto newResultPtrType =
+        RankedTensorType::get(resultType.getShape(), fatPtrBase.getType(),
+                              origFatOffsetType.getEncoding());
+
+    // Scalar case: we only need to `tt.addptr %basePtr, %offset`
+    if (!origFatOffsetType) {
+      auto addPtrOp = rewriter.create<tt::AddPtrOp>(
+          loc, newResultPtrType, fatPtrBase, slicedFatPtrOffset);
+      for (const auto &attribute : fatPtrAttrs.attributes)
+        addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond());
+      newFatPtrOffset = addPtrOp.getResult();
+    }
+
+    // Tensor case: splat the scalar pointer and add the (tensor) offset:
+    // ```
+    //   %tensorBasePtr = tt.splat %basePtr
+    //   %tensorPtr = tt.addptr %tensorBasePtr, %offset
+    // ```
+    if (fatPtrAttrs.canNarrow)
+      fatPtrOffset = createTruncIOffset(rewriter, loc, fatPtrOffset,
+                                        rewriter.getI32Type());
+
+    tt::SplatOp tensorPtr =
+        rewriter.create<tt::SplatOp>(loc, newResultPtrType, fatPtrBase);
+    tt::AddPtrOp addPtrOp = rewriter.create<tt::AddPtrOp>(
+        loc, newResultPtrType, tensorPtr, slicedFatPtrOffset);
+
+    for (const auto &attribute : fatPtrAttrs.attributes)
+      addPtrOp->setAttr(attribute.getFirst(), attribute.getSecond());
+    newFatPtrOffset = addPtrOp.getResult();
+
+    assert(newFatPtrOffset);
+    rewriter.replaceOp(extractSliceOp, newFatPtrOffset);
+    fatPtrs[{fatPtrBase, newFatPtrOffset}] =
         fatPtrs.at({fatPtrBase, fatPtrOffset});
+
     return success();
   }
 };
```

third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp

Lines changed: 4 additions & 0 deletions
```diff
@@ -200,6 +200,10 @@ bool verifyNonNegativeExpr(
         return verifyNonSmallerByAssumption(op.getLhs(), assumptions,
                                             op.getRhs());
       })
+      .Case<triton::amdgpu::ExtractSliceOp>([&](auto op) {
+        return verifyNonNegativeExpr(op->getOperand(0), assumptions,
+                                     solver);
+      })
       .Default([&](Operation *) {
         // Conservatively assume that the expression is negative
         LDBG(" Unhandled op, cannot assume non-negative");
```
