Skip to content

Commit 93de426

Browse files
authored
[AMD] revert optimizations (#4919)
These optimizations break internal workloads.
1 parent d997364 commit 93de426

File tree

23 files changed

+94
-844
lines changed

23 files changed

+94
-844
lines changed

bin/RegisterTritonDialects.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
6363
mlir::registerTritonAMDGPUStreamPipeline();
6464
mlir::registerTritonAMDGPUStreamPipelineV2();
6565
mlir::registerTritonAMDGPUCanonicalizePointers();
66-
mlir::triton::registerTritonAMDGPUInsertInstructionSchedHints();
67-
mlir::triton::registerTritonAMDGPULowerInstructionSchedHints();
6866

6967
// TODO: register Triton & TritonGPU passes
7068
registry.insert<mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,

include/triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 4 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -27,33 +27,15 @@ constexpr int patternBenefitPrioritizeOverLLVMConversions = 10;
2727
constexpr int patternBenefitClampOptimizedPattern = 20;
2828
constexpr int patternBenefitConvertLayoutOptimizedPattern = 20;
2929

30-
struct BackendCallbacks {
31-
/**
32-
* A backend-specific callback for appending auxiliary data during
33-
* `LocalStoreOp` conversion.
34-
*
35-
* @param[in] op The reference to the re-written `LocalStoreOp`.
36-
* @param[in] count The number of issued LLVM instructions.
37-
* @param[in] type The input type of issued LLVM instructions.
38-
*/
39-
std::function<void(triton::gpu::LocalStoreOp op, size_t llvmOpCount,
40-
Type llvmOpType)>
41-
localStoreOpConversion = nullptr;
42-
};
43-
4430
void populateElementwiseOpToLLVMPatterns(
4531
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
4632
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
4733
PatternBenefit benefit);
4834

49-
// The given callback is invoked at the end of a successful rewrite. The
50-
// callback receives 1) the current source op, 2) the number of issued LLVM
51-
// instructions and 3) their input types. Each MLIR backend can provide a
52-
// callback and, thus, handle backend-specific behaviors.
53-
void populateMemoryOpToLLVMPattern(
54-
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
55-
RewritePatternSet &patterns, PatternBenefit benefit,
56-
std::optional<BackendCallbacks> backendCallbacks = std::nullopt);
35+
void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
36+
const TargetInfoBase &targetInfo,
37+
RewritePatternSet &patterns,
38+
PatternBenefit benefit);
5739

5840
void populateAssertOpToLLVMPattern(LLVMTypeConverter &typeConverter,
5941
RewritePatternSet &patterns,

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1366,11 +1366,11 @@ SmallVector<Value> loadSharedToDistributed(RankedTensorType dstTy,
13661366
Location loc, RewriterBase &rewriter,
13671367
const TargetInfoBase &target);
13681368

1369-
void storeDistributedToShared(
1370-
MemDescType dstTy, RankedTensorType srcTy, Type elemLlvmTy,
1371-
ArrayRef<Value> srcVals, Value smemBase, ArrayRef<Value> dstStrides,
1372-
Location loc, RewriterBase &rewriter, const TargetInfoBase &target,
1373-
std::pair<size_t, Type> *const llvmOpCount = nullptr);
1369+
void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
1370+
Type elemLlvmTy, ArrayRef<Value> srcVals,
1371+
Value smemBase, ArrayRef<Value> dstStrides,
1372+
Location loc, RewriterBase &rewriter,
1373+
const TargetInfoBase &target);
13741374

13751375
inline Value getStructFromSharedMemoryObject(Location loc,
13761376
const SharedMemoryObject &smemObj,

lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 11 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,12 @@ using namespace mlir::triton::gpu;
1515
// blocked -> shared.
1616
// Swizzling in shared memory to avoid bank conflict. Normally used for
1717
// A/B operands of dots.
18-
void lowerDistributedToShared(
19-
Location loc, Value src, Value dst, Value adaptorSrc,
20-
const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
21-
ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
22-
std::pair<size_t, Type> *const llvmOpCount = nullptr) {
18+
void lowerDistributedToShared(Location loc, Value src, Value dst,
19+
Value adaptorSrc,
20+
const SharedMemoryObject &smemObj,
21+
const LLVMTypeConverter *typeConverter,
22+
ConversionPatternRewriter &rewriter,
23+
const TargetInfoBase &targetInfo) {
2324
auto srcTy = cast<RankedTensorType>(src.getType());
2425
auto dstTy = cast<MemDescType>(dst.getType());
2526
auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding()).getOrder();
@@ -32,7 +33,7 @@ void lowerDistributedToShared(
3233
auto dstStrides = smemObj.getStrides();
3334
auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
3435
storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
35-
loc, rewriter, targetInfo, llvmOpCount);
36+
loc, rewriter, targetInfo);
3637
}
3738

3839
struct LocalAllocOpConversion
@@ -184,15 +185,12 @@ struct LocalStoreOpConversion
184185
public:
185186
using ConvertOpToLLVMPattern<
186187
triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
187-
using BackendCallbackType =
188-
decltype(BackendCallbacks::localStoreOpConversion);
189188

190189
LocalStoreOpConversion(const LLVMTypeConverter &converter,
191190
const TargetInfoBase &targetInfo,
192-
BackendCallbackType backendCallback,
193191
PatternBenefit benefit = 1)
194192
: ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
195-
targetInfo(targetInfo), backendCallback(backendCallback) {}
193+
targetInfo(targetInfo) {}
196194

197195
LogicalResult
198196
matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
@@ -202,36 +200,24 @@ struct LocalStoreOpConversion
202200
getTypeConverter()->convertType(op.getDst().getType().getElementType());
203201
auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
204202
op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);
205-
206-
std::pair<size_t, Type> llvmOpCount;
207203
lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
208204
adaptor.getSrc(), smemObj, getTypeConverter(),
209-
rewriter, targetInfo, &llvmOpCount);
210-
211-
if (backendCallback)
212-
(backendCallback)(op, llvmOpCount.first, llvmOpCount.second);
213-
205+
rewriter, targetInfo);
214206
rewriter.eraseOp(op);
215207
return success();
216208
}
217209

218210
private:
219211
const TargetInfoBase &targetInfo;
220-
BackendCallbackType backendCallback;
221212
};
222213

223214
} // namespace
224215

225216
void mlir::triton::populateMemoryOpToLLVMPattern(
226217
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
227-
RewritePatternSet &patterns, PatternBenefit benefit,
228-
std::optional<BackendCallbacks> backendCallbacks) {
218+
RewritePatternSet &patterns, PatternBenefit benefit) {
229219
patterns.add<LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
230220
patterns.add<LocalDeallocOpConversion>(typeConverter, benefit);
231221
patterns.add<LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
232-
233-
auto backendCall =
234-
backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr;
235-
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
236-
benefit);
222+
patterns.add<LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
237223
}

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,7 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
404404
Type elemLlvmTy, ArrayRef<Value> srcVals,
405405
Value smemBase, ArrayRef<Value> dstStrides,
406406
Location loc, RewriterBase &rewriter,
407-
const TargetInfoBase &target,
408-
std::pair<size_t, Type> *const llvmOpCount) {
407+
const TargetInfoBase &target) {
409408
bool success = emitTransferBetweenRegistersAndShared(
410409
srcTy, dstTy, elemLlvmTy, /*maxVecElems=*/std::nullopt, smemBase,
411410
dstStrides, loc, rewriter, target, [&](VectorType vecTy, Value vecAddr) {
@@ -419,12 +418,7 @@ void storeDistributedToShared(MemDescType dstTy, RankedTensorType srcTy,
419418
store(vec, vecAddr)
420419
.setAlignment(vecTy.getNumElements() *
421420
elemLlvmTy.getIntOrFloatBitWidth() / 8);
422-
if (llvmOpCount) {
423-
++(llvmOpCount->first);
424-
llvmOpCount->second = vecTy;
425-
}
426421
});
427-
428422
if (!success)
429423
llvm::report_fatal_error("Failed to emit transfer from register to shared");
430424
}

test/TritonGPU/amd/amd-instruction-sched.mlir

Lines changed: 0 additions & 148 deletions
This file was deleted.

0 commit comments

Comments (0)