
Commit 216385e

Revert "[AMD] Add basics to allow bypass LDS for dot RHS (#5350)" (#5708)
Reverting, as I have to revert [cec1db5](triton-lang/triton@cec1db5) (which this change relies on) due to a regression in internal tests.
1 parent 53e6e9e commit 216385e

14 files changed, +57 -473 lines changed


bin/RegisterTritonDialects.h

Lines changed: 0 additions & 1 deletion
@@ -60,7 +60,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
   // TritonAMDGPUTransforms passes
   mlir::registerTritonAMDGPUAccelerateMatmul();
   mlir::registerTritonAMDGPUOptimizeEpilogue();
-  mlir::registerTritonAMDGPUBypassLDSForDotOperand();
   mlir::registerTritonAMDGPUReorderInstructions();
   mlir::registerTritonAMDGPUBlockPingpong();
   mlir::registerTritonAMDGPUStreamPipeline();

include/triton/Dialect/TritonGPU/Transforms/Utility.h

Lines changed: 0 additions & 2 deletions
@@ -205,8 +205,6 @@ enum class MMALoadType {
 };
 MMALoadType getMMALoadType(Operation *loadOp);
 
-// Convert \param op operands and results to layout \param encoding.
-void convertOpEncoding(Attribute encoding, Operation *op);
 } // namespace mlir
 
 #endif // TRITON_DIALECT_TRITONGPU_TRANSFORMS_UTILITY_H_

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 0 additions & 1 deletion
@@ -31,7 +31,6 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
     "TRITON_ENABLE_LLVM_DEBUG",
     "TRITON_HIP_STREAM_PREFETCH",
     "TRITON_HIP_USE_BLOCK_PINGPONG",
-    "TRITON_HIP_BYPASS_LDS_FOR_DOT",
     "TRITON_LLVM_DEBUG_ONLY",
     "TRITON_ENABLE_ASAN",
     "TRITON_OVERRIDE_ARCH",

lib/Dialect/TritonGPU/Transforms/Coalesce.cpp

Lines changed: 50 additions & 1 deletion
@@ -104,6 +104,55 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
                                     threadsPerWarp, CTALayout);
   }
 
+  static Type getNewType(Type type, Attribute encoding) {
+    RankedTensorType tensorType = cast<RankedTensorType>(type);
+    return RankedTensorType::get(tensorType.getShape(),
+                                 tensorType.getElementType(), encoding);
+  }
+
+  void coalesceOp(Attribute encoding, Operation *op) {
+    OpBuilder builder(op);
+    // Convert operands
+    // For load/store with tensor pointers, we don't have to change the
+    // operands' type, we do this by changing the outputs' type of
+    // `make_tensor_ptr`
+    SmallVector<Value, 4> newArgs;
+    for (auto operand : op->getOperands()) {
+      auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
+      if (tensorType &&
+          !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
+        Type newType = getNewType(tensorType, encoding);
+        newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
+            op->getLoc(), newType, operand));
+      } else {
+        newArgs.push_back(operand);
+      }
+    }
+
+    // Convert output types
+    SmallVector<Type, 4> newTypes;
+    for (auto t : op->getResultTypes()) {
+      bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
+      newTypes.push_back(isAsync ? t : getNewType(t, encoding));
+    }
+
+    // Construct new op with the new encoding
+    Operation *newOp =
+        builder.create(op->getLoc(), op->getName().getIdentifier(), newArgs,
+                       newTypes, op->getAttrs());
+
+    // Cast the results back to the original layout
+    for (size_t i = 0; i < op->getNumResults(); i++) {
+      Value newResult = newOp->getResult(i);
+      if (newTypes[i] != op->getResultTypes()[i]) {
+        newResult = builder.create<triton::gpu::ConvertLayoutOp>(
+            op->getLoc(), op->getResult(i).getType(), newResult);
+      }
+      op->getResult(i).replaceAllUsesWith(newResult);
+    }
+    op->erase();
+  }
+
   void runOnOperation() override {
     // Run axis info analysis
     ModuleOp moduleOp = getOperation();
@@ -138,7 +187,7 @@ struct CoalescePass : public impl::TritonGPUCoalesceBase<CoalescePass> {
     // 4. Convert the output of this new memory op back to L1
     // 5. Replace all the uses of the original memory op by the new one
     for (auto &kv : layoutMap) {
-      convertOpEncoding(kv.second, kv.first);
+      coalesceOp(kv.second, kv.first);
    }
   }
 };
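
Aside (illustrative sketch, not from this commit): the one non-obvious MLIR call in coalesceOp above is the generic OpBuilder::create overload, which rebuilds an operation by name with new operand and result types. Distilled into a standalone helper with a hypothetical name, assuming only the standard MLIR builder API:

#include "mlir/IR/Builders.h"
#include "mlir/IR/Operation.h"

// Re-create `op` at the builder's insertion point with the same name and
// attributes but different operands/result types; the caller is responsible
// for inserting layout conversions around it (as coalesceOp does) so that
// users of the results stay type-correct.
mlir::Operation *rebuildWithNewTypes(mlir::OpBuilder &builder,
                                     mlir::Operation *op,
                                     mlir::ValueRange newOperands,
                                     mlir::TypeRange newResultTypes) {
  return builder.create(op->getLoc(), op->getName().getIdentifier(),
                        newOperands, newResultTypes, op->getAttrs());
}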

lib/Dialect/TritonGPU/Transforms/RemoveLayoutConversions.cpp

Lines changed: 6 additions & 42 deletions
@@ -1022,43 +1022,6 @@ void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast() {
   }
 }
 
-bool shouldPropagateConversion(ConvertLayoutOp convertOp) {
-  RankedTensorType targetType = convertOp.getType();
-  auto dotEnc = dyn_cast<DotOperandEncodingAttr>(targetType.getEncoding());
-  // If the target encoding is not DotOperandEncodingAttr, allow propagation.
-  if (!dotEnc) {
-    return true;
-  }
-  // Skip conversions to DotOperandEncodingAttr when the operand index is 0.
-  // This heuristic is applied to prevent moving the blocked->dot conversion of
-  // the Q tensor (a loop invariant in Flash Attention) outside the loop. Doing
-  // so can increase register pressure and cause spilling in some cases.
-  if (dotEnc.getOpIdx() == 0) {
-    return false;
-  }
-  // Skip conversions to DotOperandEncodingAttr when the operand index is 1 if
-  // it's not intentionally placed above a load as we have to be a bit more
-  // careful with the heuristics for both correctness and performance.
-  // TODO: Fix this logic to avoid propagating conversions backward unless
-  // it reduces the total number of conversions.
-  assert(dotEnc.getOpIdx() == 1);
-  SetVector<Operation *> slice;
-  BackwardSliceOptions opt;
-  opt.omitBlockArguments = true;
-  opt.filter = [&](Operation *op) {
-    return op->getParentRegion() == convertOp->getParentRegion();
-  };
-  getBackwardSlice(convertOp.getOperation(), &slice, opt);
-
-  for (Operation *currOp : slice) {
-    if (isa<LoadOp>(currOp)) {
-      return false;
-    }
-  }
-  // Allow propagation if no LoadOp is found.
-  return true;
-}
-
 void LayoutRematerialization::hoistConvertIntoConditionals() {
   // Go through each ConvertLayoutOp.
   SmallVector<ConvertLayoutOp> convertOps;
@@ -1077,11 +1040,11 @@ void LayoutRematerialization::hoistConvertIntoConditionals() {
 
 void LayoutRematerialization::backwardRematerialization(
     ConvertLayoutOp convertOp) {
+  // we don't handle conversions to DotOperandEncodingAttr
+  // this is a heuristic to accommodate fused attention
   RankedTensorType targetType = convertOp.getType();
-  if (!shouldPropagateConversion(convertOp)) {
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
-  }
-
   Value oldV = convertOp.getSrc();
   LDBG("check backward remat with source " << oldV << " encoding "
                                            << targetType.getEncoding());
@@ -1120,10 +1083,11 @@ void LayoutRematerialization::backwardRematerialization(
 // of the convert.
 void LayoutRematerialization::hoistConvertOnTopOfExtOrBroadcast(
     ConvertLayoutOp convertOp) {
+  // we don't handle conversions to DotOperandEncodingAttr
+  // this is a heuristics to accommodate fused attention
 RankedTensorType targetType = convertOp.getType();
-  if (!shouldPropagateConversion(convertOp)) {
+  if (isa<DotOperandEncodingAttr>(targetType.getEncoding()))
     return;
-  }
 
   auto isExtOrBroadcastOp = [](Operation *op) {
     if (isa<arith::ExtSIOp, arith::ExtUIOp, arith::ExtFOp, BroadcastOp,

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 0 additions & 48 deletions
@@ -1057,54 +1057,6 @@ MMALoadType getMMALoadType(Operation *loadOp) {
   }
 }
 
-static Type getNewType(Type type, Attribute encoding) {
-  RankedTensorType tensorType = cast<RankedTensorType>(type);
-  return RankedTensorType::get(tensorType.getShape(),
-                               tensorType.getElementType(), encoding);
-}
-
-void convertOpEncoding(Attribute encoding, Operation *op) {
-  OpBuilder builder(op);
-  // Convert operands
-  // For load/store with tensor pointers, we don't have to change the
-  // operands' type, we do this by changing the outputs' type of
-  // `make_tensor_ptr`
-  SmallVector<Value, 4> newArgs;
-  for (auto operand : op->getOperands()) {
-    auto tensorType = dyn_cast<RankedTensorType>(operand.getType());
-    if (tensorType &&
-        !isa<triton::gpu::SharedEncodingAttr>(tensorType.getEncoding())) {
-      Type newType = getNewType(tensorType, encoding);
-      newArgs.push_back(builder.create<triton::gpu::ConvertLayoutOp>(
-          op->getLoc(), newType, operand));
-    } else {
-      newArgs.push_back(operand);
-    }
-  }
-
-  // Convert output types
-  SmallVector<Type, 4> newTypes;
-  for (auto t : op->getResultTypes()) {
-    bool isAsync = isa<triton::gpu::AsyncCopyGlobalToLocalOp>(op);
-    newTypes.push_back(isAsync ? t : getNewType(t, encoding));
-  }
-
-  // Construct new op with the new encoding
-  Operation *newOp = builder.create(op->getLoc(), op->getName().getIdentifier(),
-                                    newArgs, newTypes, op->getAttrs());
-
-  // Cast the results back to the original layout
-  for (size_t i = 0; i < op->getNumResults(); i++) {
-    Value newResult = newOp->getResult(i);
-    if (newTypes[i] != op->getResultTypes()[i]) {
-      newResult = builder.create<triton::gpu::ConvertLayoutOp>(
-          op->getLoc(), op->getResult(i).getType(), newResult);
-    }
-    op->getResult(i).replaceAllUsesWith(newResult);
-  }
-  op->erase();
-}
-
 namespace {
 
 /// Detect dead arguments in scf.for op by assuming all the values are dead and
