Skip to content

Commit 92d3629

Browse files
Merge commit '0ecb17225182d0f8b7176e1a5b0ccda94885af60'
2 parents d849cbc + 0ecb172 commit 92d3629

File tree

23 files changed

+122
-862
lines changed

23 files changed

+122
-862
lines changed

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,7 +85,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
8585
// TritonAMDGPUTransforms passes
8686
mlir::registerTritonAMDGPUAccelerateMatmul();
8787
mlir::registerTritonAMDGPUOptimizeEpilogue();
88-
mlir::registerTritonAMDGPUBypassLDSForDotOperand();
8988
mlir::registerTritonAMDGPUReorderInstructions();
9089
mlir::registerTritonAMDGPUBlockPingpong();
9190
mlir::registerTritonAMDGPUStreamPipeline();

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -205,8 +205,6 @@ enum class MMALoadType {
205205
};
206206
MMALoadType getMMALoadType(Operation *loadOp);
207207

208-
// Convert \param op operands and results to layout \param encoding.
209-
void convertOpEncoding(Attribute encoding, Operation *op);
210208
} // namespace mlir
211209

212210
#endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
3131
"TRITON_ENABLE_LLVM_DEBUG",
3232
"TRITON_HIP_STREAM_PREFETCH",
3333
"TRITON_HIP_USE_BLOCK_PINGPONG",
34-
"TRITON_HIP_BYPASS_LDS_FOR_DOT",
3534
"TRITON_LLVM_DEBUG_ONLY",
3635
"TRITON_ENABLE_ASAN",
3736
"TRITON_OVERRIDE_ARCH",

lib/Dialect/Triton/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -700,7 +700,7 @@ LogicalResult ReshapeOp::canonicalize(ReshapeOp op, PatternRewriter &rewriter) {
700700
}
701701

702702
OpFoldResult ReshapeOp::fold(FoldAdaptor adaptor) {
703-
if (getType() == getSrc().getType() && !getAllowReorder()) {
703+
if (getType() == getSrc().getType()) {
704704
// no-op
705705
return getSrc();
706706
}

lib/Dialect/TritonGPU/IR/LinearLayoutConversions.cpp

Lines changed: 25 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1105,6 +1105,7 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
11051105
auto rank = shape.size();
11061106
auto opIdx = dot.getOpIdx();
11071107
int kDim = (opIdx == 0) ? rank - 1 : rank - 2;
1108+
int nonKDim = (opIdx == 0) ? rank - 2 : rank - 1;
11081109

11091110
StringAttr kReg = S("register");
11101111
StringAttr kLane = S("lane");
@@ -1121,8 +1122,11 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
11211122
auto reg = 1 << logReg;
11221123
basesReg.push_back({0, reg});
11231124
}
1124-
std::vector<std::vector<int>> basesLane = {{1, 0}, {2, 0}, {4, 0}};
1125-
int numTileCols;
1125+
std::vector<std::vector<int>> basesLane = {
1126+
{1, 0}, {2, 0}, {4, 0}, {0, 0}, {0, 0}};
1127+
bool kX2 = shape[kDim] > 8 * 16 / elemBitWidth;
1128+
bool kX4 = shape[kDim] > 16 * 16 / elemBitWidth;
1129+
bool nonKX2 = shape[nonKDim] > 8;
11261130
// Construct a tile consisting of 4 8x8x16bits sub-tiles to use ldmatrix
11271131
// efficiently. opIdx=0 and opIdx=1 are handled differently.
11281132
if (opIdx == 0) {
@@ -1135,13 +1139,16 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
11351139
if (needTrans) {
11361140
assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
11371141
"supported in the transposed mode");
1138-
basesLane.push_back({0, 8});
1139-
basesLane.push_back({8, 0});
1142+
if (nonKX2)
1143+
basesLane[3] = {0, 8};
1144+
if (kX2)
1145+
basesLane[4] = {8 * 16 / elemBitWidth, 0};
11401146
} else {
1141-
basesLane.push_back({8, 0});
1142-
basesLane.push_back({0, 8 * 16 / elemBitWidth});
1147+
if (nonKX2)
1148+
basesLane[3] = {8, 0};
1149+
if (kX2)
1150+
basesLane[4] = {0, 8 * 16 / elemBitWidth};
11431151
}
1144-
numTileCols = 16 * 16 / elemBitWidth;
11451152
} else {
11461153
// The matrix elements of thread 0 are distributed in the following pattern
11471154
// (fp16):
@@ -1151,14 +1158,20 @@ LinearLayout chooseDotLdMatrixLayout(DotOperandEncodingAttr dot,
11511158
if (needTrans) {
11521159
assert(elemBitWidth <= 16 && "Only elements smaller than 16 bits are "
11531160
"supported in the transposed mode");
1154-
basesLane.push_back({8, 0});
1155-
basesLane.push_back({16, 0});
1161+
if (kX2)
1162+
basesLane[3] = {8, 0};
1163+
if (kX4)
1164+
basesLane[4] = {16, 0};
11561165
} else {
1157-
basesLane.push_back({0, 8 * 16 / elemBitWidth});
1158-
basesLane.push_back({0, 16 * 16 / elemBitWidth});
1166+
if (kX2)
1167+
basesLane[3] = {0, 8 * 16 / elemBitWidth};
1168+
if (kX4)
1169+
basesLane[4] = {0, 16 * 16 / elemBitWidth};
11591170
}
1160-
numTileCols = 32 * 16 / elemBitWidth;
11611171
}
1172+
int numTileCols =
1173+
(8 * 16 / elemBitWidth)
1174+
<< (static_cast<int>(kX2) + static_cast<int>(kX4 && opIdx == 1));
11621175
// Expand the `register` dimension so the size of columns matches `K`.
11631176
auto layout =
11641177
LinearLayout({{kReg, basesReg}, {kLane, basesLane}, {kWarp, {}}},

lib/Dialect/TritonGPU/IR/Ops.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -63,7 +63,7 @@ struct CanonicalizeConvertFromReshape
6363

6464
if (isExpensiveView(convert.getSrc().getType(), op.getType()))
6565
return failure();
66-
if (!op.getAllowReorder())
66+
if (!op.getAllowReorder() || op.getEfficientLayout())
6767
return failure();
6868

6969
rewriter.replaceOpWithNewOp<triton::ReshapeOp>(

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

Lines changed: 50 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -104,6 +104,55 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
104104
threadsPerWarp, CTALayout);
105105
}
106106

107+
static Type getNewType(Type type, Attribute encoding) {
108+
RankedTensorType tensorType = cast<RankedTensorType>(type);
109+
return RankedTensorType::get(tensorType.getShape(),
110+
tensorType.getElementType(), encoding);
111+
}
112+
113+
void coalesceOp(Attribute encoding, Operation *op) {
114+
OpBuilder builder(op);
115+
// Convert operands
116+
// For load/store with tensor pointers, we don't have to change the
117+
// operands' type, we do this by changing the outputs' type of
118+
// `make_tensor_ptr`
119+
SmallVector<Value, 4> newArgs;
120+
for (auto operand : op->getOperands()) {
121+
auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
122+
if (tensorType &&
123+
!isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
124+
Type newType = getNewType(tensorType, encoding);
125+
newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
126+
op->getLoc(), newType, operand));
127+
} else {
128+
newArgs.push_back(operand);
129+
}
130+
}
131+
132+
// Convert output types
133+
SmallVector<Type, 4> newTypes;
134+
for (auto t : op->getResultTypes()) {
135+
bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
136+
newTypes.push_back(isAsync ? t : getNewType(t, encoding));
137+
}
138+
139+
// Construct new op with the new encoding
140+
Operation *newOp =
141+
builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
142+
newTypes, op->getAttrs());
143+
144+
// Cast the results back to the original layout
145+
for (size_t i = 0; i < op->getNumResults(); i++) {
146+
Value newResult = newOp->getResult(i);
147+
if (newTypes[i] != op->getResultTypes()[i]) {
148+
newResult = builder.create<triton::gpu::ConvertLayoutOp>(
149+
op->getLoc(), op->getResult(i).getType(), newResult);
150+
}
151+
op->getResult(i).replaceAllUsesWith(newResult);
152+
}
153+
op->erase();
154+
}
155+
107156
void runOnOperation() override {
108157
// Run axis info analysis
109158
ModuleOp moduleOp = getOperation();
@@ -138,7 +187,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
138187
// 4. Convert the output of this new memory op back to L1
139188
// 5. Replace all the uses of the original memory op by the new one
140189
for (auto &kv : layoutMap) {
141-
convertOpEncoding(kv.second, kv.first);
190+
coalesceOp(kv.second, kv.first);
142191
}
143192
}
144193
};

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 6 additions & 158 deletions
Original file line number | Diff line number | Diff line change
@@ -131,8 +131,6 @@ class LayoutRematerialization {
131131
void backwardRematerialization(ConvertLayoutOp convertOp);
132132
void hoistConvertOnTopOfExtOrBroadcast();
133133
void hoistConvertOnTopOfExtOrBroadcast(ConvertLayoutOp convertOp);
134-
void hoistConvertIntoConditionals();
135-
void hoistConvertIntoConditionals(ConvertLayoutOp convertOp);
136134
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
137135
ConvertLayoutOp convertOp, IRMapping &mapping);
138136
void rewriteSlice(SetVector<Value> &slice, DenseMap<Value, Attribute> &layout,
@@ -1022,66 +1020,13 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
10221020
}
10231021
}
10241022

1025-
bool shouldPropagateConversion(ConvertLayoutOp convertOp) {
1026-
RankedTensorType targetType = convertOp.getType();
1027-
auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
1028-
// If the target encoding is not DotOperandEncodingAttr, allow propagation.
1029-
if (!dotEnc) {
1030-
return true;
1031-
}
1032-
// Skip conversions to DotOperandEncodingAttr when the operand index is 0.
1033-
// This heuristic is applied to prevent moving the blocked->dot conversion of
1034-
// the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing
1035-
// so can increase register pressure and cause spilling in some cases.
1036-
if (dotEnc.getOpIdx() == 0) {
1037-
return false;
1038-
}
1039-
// Skip conversions to DotOperandEncodingAttr when the operand index is 1 if
1040-
// it's not intentionally placed above a load as we have to be a bit more
1041-
// careful with the heuristics for both correctness and performance.
1042-
// TODO: Fix this logic to avoid propagating conversions backward unless
1043-
// it reduces the total number of conversions.
1044-
assert(dotEnc.getOpIdx() == 1);
1045-
SetVector<Operation *> slice;
1046-
BackwardSliceOptions opt;
1047-
opt.omitBlockArguments = true;
1048-
opt.filter = [&](Operation *op) {
1049-
return op->getParentRegion() == convertOp->getParentRegion();
1050-
};
1051-
getBackwardSlice(convertOp.getOperation(), &slice, opt);
1052-
1053-
for (Operation *currOp : slice) {
1054-
if (isa<LoadOp>(currOp)) {
1055-
return false;
1056-
}
1057-
}
1058-
// Allow propagation if no LoadOp is found.
1059-
return true;
1060-
}
1061-
1062-
void LayoutRematerialization::hoistConvertIntoConditionals() {
1063-
// Go through each ConvertLayoutOp.
1064-
SmallVector<ConvertLayoutOp> convertOps;
1065-
funcOp.walk(
1066-
[&](ConvertLayoutOp convertOp) { convertOps.push_back(convertOp); });
1067-
for (ConvertLayoutOp convertOp : convertOps) {
1068-
hoistConvertIntoConditionals(convertOp);
1069-
if (!opToDelete.contains(convertOp)) {
1070-
// If the conversion didn't get removed, consider it for reuse in future
1071-
// backward slices.
1072-
addRematValue(convertOp.getSrc(), convertOp.getType().getEncoding(),
1073-
convertOp.getResult());
1074-
}
1075-
}
1076-
}
1077-
10781023
void LayoutRematerialization::backwardRematerialization(
10791024
ConvertLayoutOp convertOp) {
1025+
// we don't handle conversions to DotOperandEncodingAttr
1026+
// this is a heuristic to accommodate fused attention
10801027
RankedTensorType targetType = convertOp.getType();
1081-
if (!shouldPropagateConversion(convertOp)) {
1028+
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
10821029
return;
1083-
}
1084-
10851030
Value oldV = convertOp.getSrc();
10861031
LDBG("check backward remat with source " << oldV << " encoding "
10871032
<< targetType.getEncoding());
@@ -1120,10 +1065,11 @@ void LayoutRematerialization::backwardRematerialization(
11201065
// of the convert.
11211066
void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
11221067
ConvertLayoutOp convertOp) {
1068+
// we don't handle conversions to DotOperandEncodingAttr
1069+
// this is a heuristic to accommodate fused attention
11231070
RankedTensorType targetType = convertOp.getType();
1124-
if (!shouldPropagateConversion(convertOp)) {
1071+
if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
11251072
return;
1126-
}
11271073

11281074
auto isExtOrBroadcastOp = [](Operation *op) {
11291075
if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp, BroadcastOp,
@@ -1205,100 +1151,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
12051151
rewriteSlice(slice, layout, convertOp, mapping);
12061152
}
12071153

1208-
void LayoutRematerialization::hoistConvertIntoConditionals(
1209-
ConvertLayoutOp convertOp) {
1210-
// Take the backward slice of tensor dependencies, stopping at conditionals.
1211-
SetVector<Value> slice;
1212-
DenseMap<Value, Attribute> layout;
1213-
auto isIfOp = [](Operation *op) { return isa<scf::IfOp>(op); };
1214-
if (failed(getRematerializableSlice(convertOp.getSrcMutable(),
1215-
convertOp.getType().getEncoding(), slice,
1216-
layout, isIfOp)))
1217-
return;
1218-
1219-
// Find conditional edges above which the conversion can be hoisted.
1220-
SmallVector<std::pair<Value, OpOperand *>> hoistAbove;
1221-
unsigned sliceSize = slice.size();
1222-
// The routine will recurse through backward slices, e.g. to handle loops and
1223-
// conditional chains. Thus, we re-query the size of `slice`.
1224-
for (unsigned i = 0; i < slice.size(); i++) {
1225-
Value v = slice[i];
1226-
auto ifOp = v.getDefiningOp<scf::IfOp>();
1227-
if (!ifOp)
1228-
continue;
1229-
1230-
Attribute rootLayout = layout.at(v);
1231-
unsigned resIdx = cast<OpResult>(v).getResultNumber();
1232-
1233-
// Take the backward slice along each branch.
1234-
auto thenYield =
1235-
cast<scf::YieldOp>(ifOp.getThenRegion().front().getTerminator());
1236-
auto elseYield =
1237-
cast<scf::YieldOp>(ifOp.getElseRegion().front().getTerminator());
1238-
1239-
OpOperand &thenRes = thenYield.getResultsMutable()[resIdx];
1240-
OpOperand &elseRes = elseYield.getResultsMutable()[resIdx];
1241-
1242-
SetVector<Value> thenSlice, elseSlice;
1243-
DenseMap<Value, Attribute> thenLayout, elseLayout;
1244-
1245-
LogicalResult thenResult = getRematerializableSlice(
1246-
thenRes, rootLayout, thenSlice, thenLayout, isIfOp);
1247-
LogicalResult elseResult = getRematerializableSlice(
1248-
elseRes, rootLayout, elseSlice, elseLayout, isIfOp);
1249-
1250-
// If propagation across both edges of this conditional succeeded, then we
1251-
// don't need to hoist across it.
1252-
if (succeeded(thenResult) && succeeded(elseResult)) {
1253-
slice.insert(thenSlice.begin(), thenSlice.end());
1254-
slice.insert(elseSlice.begin(), elseSlice.end());
1255-
layout.insert(thenLayout.begin(), thenLayout.end());
1256-
layout.insert(elseLayout.begin(), elseLayout.end());
1257-
continue;
1258-
}
1259-
1260-
// If propagation across both edges failed, then there is nothing to do
1261-
// for this one.
1262-
if (failed(thenResult) && failed(elseResult))
1263-
continue;
1264-
1265-
// The layout conversion can be rematerialized along one edge but not the
1266-
// other. We can hoist the conversion into the other branch.
1267-
if (succeeded(elseResult)) {
1268-
std::swap(thenSlice, elseSlice);
1269-
std::swap(thenLayout, elseLayout);
1270-
hoistAbove.push_back({v, &thenRes});
1271-
} else {
1272-
hoistAbove.push_back({v, &elseRes});
1273-
}
1274-
slice.insert(thenSlice.begin(), thenSlice.end());
1275-
layout.insert(thenLayout.begin(), thenLayout.end());
1276-
}
1277-
1278-
// It's hard to know if duplicating the conversion into separate branches is
1279-
// profitable without more analysis. For now, hoist at most one.
1280-
if (hoistAbove.size() != 1)
1281-
return;
1282-
1283-
IRMapping mapping;
1284-
for (auto [result, edge] : hoistAbove) {
1285-
// Hoist the convert into the conditional and rewrite the slice.
1286-
OpBuilder b(edge->getOwner());
1287-
Value v = edge->get();
1288-
Attribute encoding = layout.at(result);
1289-
1290-
auto tensorType = cast<RankedTensorType>(v.getType());
1291-
auto newType = RankedTensorType::get(tensorType.getShape(),
1292-
tensorType.getElementType(), encoding);
1293-
1294-
Value newCvt = b.create<ConvertLayoutOp>(convertOp.getLoc(), newType, v);
1295-
1296-
mapping.map(v, newCvt);
1297-
slice.remove(v);
1298-
}
1299-
rewriteSlice(slice, layout, convertOp, mapping);
1300-
}
1301-
13021154
void backwardRematerialization(ModuleOp module) {
13031155
module.walk([](FuncOp funcOp) {
13041156
LayoutRematerialization layoutRemat(funcOp);
@@ -1313,10 +1165,6 @@ void hoistConvert(ModuleOp module) {
13131165
LayoutRematerialization layoutRemat(funcOp);
13141166
layoutRemat.hoistConvertOnTopOfExtOrBroadcast();
13151167
layoutRemat.cleanup();
1316-
1317-
layoutRemat = LayoutRematerialization(funcOp);
1318-
layoutRemat.hoistConvertIntoConditionals();
1319-
layoutRemat.cleanup();
13201168
});
13211169
}
13221170
} // namespace

0 commit comments

Comments
 (0)