Skip to content

Commit 688adf6

Browse files
Revert "[PROTON] Intra kernel profiling (#7258)"
This reverts commit bf0f56d.
1 parent 33d511d commit 688adf6

File tree

219 files changed

+731
-9933
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

219 files changed

+731
-9933
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
228228
endif()
229229
# We always build proton dialect
230230
list(APPEND TRITON_PLUGIN_NAMES "proton")
231-
add_subdirectory(third_party/proton/Dialect)
231+
add_subdirectory(third_party/proton/dialect)
232232

233233
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
234234
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
@@ -360,7 +360,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
360360
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
361361
add_subdirectory(third_party/${CODEGEN_BACKEND})
362362
endforeach()
363-
add_subdirectory(third_party/proton/Dialect)
363+
add_subdirectory(third_party/proton/dialect)
364364
endif()
365365

366366
find_package(Threads REQUIRED)

Makefile

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -67,8 +67,7 @@ test-interpret: all
6767

6868
.PHONY: test-proton
6969
test-proton: all
70-
$(PYTEST) -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py
71-
$(PYTEST) -s third_party/proton/test/test_override.py
70+
$(PYTEST) -s -n 8 third_party/proton/test
7271

7372
.PHONY: test-python
7473
test-python: test-unit test-regression test-interpret test-proton

bin/CMakeLists.txt

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ target_link_libraries(triton-opt PRIVATE
1515
TritonTestAnalysis
1616
TritonTestDialect
1717
TritonAMDGPUTestAnalysis
18-
TritonTestProton
1918
# MLIR core
2019
MLIROptLib
2120
MLIRPass
@@ -36,7 +35,6 @@ target_link_libraries(triton-reduce PRIVATE
3635
TritonTestAnalysis
3736
TritonTestDialect
3837
TritonAMDGPUTestAnalysis
39-
TritonTestProton
4038
# MLIR core
4139
MLIRReduceLib
4240
MLIRPass
@@ -56,7 +54,6 @@ target_link_libraries(triton-lsp PRIVATE
5654
TritonTestAnalysis
5755
TritonTestDialect
5856
TritonAMDGPUTestAnalysis
59-
TritonTestProton
6057
# MLIR core
6158
MLIRLspServerLib
6259
MLIRPass
@@ -95,6 +92,5 @@ target_link_libraries(triton-tensor-layout PRIVATE
9592
${dialect_libs}
9693
TritonTestAnalysis
9794
TritonTestDialect
98-
TritonTestProton
9995
TritonAMDGPUTestAnalysis
10096
)

bin/RegisterTritonDialects.h

Lines changed: 4 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,9 @@
1212

1313
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
1414
#include "amd/include/TritonAMDGPUTransforms/Passes.h"
15-
#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h"
16-
#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
17-
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h"
18-
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h"
19-
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h"
20-
#include "proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h"
21-
#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h"
22-
#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h"
23-
#include "proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h"
15+
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
16+
#include "third_party/nvidia/include/Dialect/NVWS/IR/Dialect.h"
17+
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
2418
#include "triton/Dialect/Gluon/Transforms/Passes.h"
2519
#include "triton/Dialect/Triton/IR/Dialect.h"
2620
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -64,9 +58,6 @@ void registerTestMembarPass();
6458
void registerTestAMDGPUMembarPass();
6559
void registerTestTritonAMDGPURangeAnalysis();
6660
void registerTestLoopPeelingPass();
67-
namespace proton {
68-
void registerTestScopeIdAllocationPass();
69-
} // namespace proton
7061
} // namespace test
7162
} // namespace mlir
7263

@@ -136,16 +127,6 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
136127
// NVGPU transform passes
137128
mlir::registerNVHopperTransformsPasses();
138129

139-
// Proton passes
140-
mlir::test::proton::registerTestScopeIdAllocationPass();
141-
mlir::triton::proton::registerConvertProtonToProtonGPU();
142-
mlir::triton::proton::gpu::registerConvertProtonNvidiaGPUToLLVM();
143-
mlir::triton::proton::gpu::registerConvertProtonAMDGPUToLLVM();
144-
mlir::triton::proton::gpu::registerAllocateProtonSharedMemoryPass();
145-
mlir::triton::proton::gpu::registerAllocateProtonGlobalScratchBufferPass();
146-
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
147-
mlir::triton::proton::gpu::registerAddSchedBarriersPass();
148-
149130
registry.insert<
150131
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
151132
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
@@ -155,8 +136,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
155136
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
156137
mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect,
157138
mlir::triton::amdgpu::TritonAMDGPUDialect,
158-
mlir::triton::proton::ProtonDialect,
159-
mlir::triton::proton::gpu::ProtonGPUDialect, mlir::ROCDL::ROCDLDialect,
139+
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect,
160140
mlir::triton::gpu::intel::TritonIntelGPUDialect,
161141
mlir::triton::TritonGEN::TritonGENDialect,
162142
mlir::triton::gluon::GluonDialect>();

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -319,12 +319,6 @@ class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder {
319319
#define str_attr(str) ::mlir::StringAttr::get(ctx, (str))
320320

321321
namespace mlir {
322-
323-
// See FuncOpToLLVM.cpp for details about Triton's function calling conventions
324-
constexpr int kProfileScratchBufferOffset = -1;
325-
constexpr int kGlobalScratchBufferOffset = -2;
326-
constexpr int kSharedMemoryOffset = -3;
327-
328322
namespace triton {
329323

330324
namespace gpu {
@@ -445,9 +439,6 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
445439
const TargetInfoBase &targetInfo,
446440
FunctionOpInterface funcOp, Value allocOffset);
447441

448-
Value getProfileScratchPtr(Location loc, RewriterBase &rewriter,
449-
FunctionOpInterface funcOp);
450-
451442
Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
452443
const TargetInfoBase &target, Operation *op);
453444

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
5959
inline const std::set<std::string> CACHE_NEUTRAL_ENV_VARS = {
6060
// clang-format off
6161
"TRITON_REPRODUCER_PATH",
62-
"TRITON_ENABLE_PYTHON_STACKTRACE",
62+
"TRITON_ENABLE_PYTHON_STACKTRACE"
6363
// clang-format on
6464
};
6565

lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,6 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
103103

104104
promotedOperands.push_back(LLVM::getGlobalScratchPtr(
105105
loc, rewriter, targetInfo, caller, opOffsetVal));
106-
promotedOperands.push_back(
107-
LLVM::getProfileScratchPtr(loc, rewriter, caller));
108106
return promotedOperands;
109107
}
110108

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -10,20 +10,21 @@ using namespace mlir;
1010
using namespace mlir::triton;
1111

1212
// NOTE: [Additional Function Arguments]
13-
// Triton patches additional arguments to the function signature to support
14-
// (1) shared memory, (2) global scratch memory, and (3) profile scratch memory.
1513
// To support use of shared memory and global scratch memory inside of a
1614
// function, the caller allocates a single large block of the relevant memory
1715
// and calls the function with these extra arguments at the end.
18-
// Profile scratch memory is only used when the function is instrumented for
19-
// profiling.
16+
// Specifically, the last argument is the global scratch memory allocation and
17+
// the second-to-last is the shared memory allocation.
2018
//
2119
// For the kernel function itself, the shared memory base is a global symbol
2220
// so no additional function argument is required but global scratch memory
2321
// allocation is still passed in as the last argument. Though here the scratch
2422
// memory is shared between all programs, so a linear offset based on the
2523
// program id is required to get the local scratch base.
2624

25+
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
26+
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
27+
/// information.
2728
struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
2829
FuncOpConversion(LLVMTypeConverter &converter,
2930
const TargetInfoBase &targetInfo, PatternBenefit benefit)
@@ -55,7 +56,6 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
5556
auto sharedPtrTy =
5657
LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
5758
auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1);
58-
auto profilePtrTy = LLVM::LLVMPointerType::get(ctx, 1);
5959

6060
// 1. Modify the function type to add the new arguments.
6161
auto funcTy = funcOp.getFunctionType();
@@ -73,7 +73,6 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
7373
amendedInputTy.push_back(sharedPtrTy);
7474
}
7575
amendedInputTy.push_back(globalPtrTy);
76-
amendedInputTy.push_back(profilePtrTy);
7776
auto amendedFuncTy =
7877
FunctionType::get(ctx, amendedInputTy, funcTy.getResults());
7978
// 2. Modify the argument attributes to add the new argument.
@@ -98,7 +97,6 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
9897
region.addArgument(sharedPtrTy, loc);
9998
}
10099
region.addArgument(globalPtrTy, loc);
101-
region.addArgument(profilePtrTy, loc);
102100
rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
103101
amendedFuncOp.end());
104102
return amendedFuncOp;

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,7 +1198,7 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
11981198
Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) {
11991199
// See NOTE: [Additional Function Arguments]
12001200
if (!isKernel(funcOp)) {
1201-
return funcOp.getArgument(funcOp.getNumArguments() + kSharedMemoryOffset);
1201+
return funcOp.getArgument(funcOp.getNumArguments() - 2);
12021202
}
12031203

12041204
auto mod = funcOp->getParentOfType<ModuleOp>();
@@ -1213,8 +1213,7 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
12131213
// See NOTE: [Additional Function Arguments]
12141214
if (!isKernel(funcOp)) {
12151215
// Base for this function
1216-
auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() +
1217-
kGlobalScratchBufferOffset);
1216+
auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
12181217
if (!allocOffset) {
12191218
return gmemBase;
12201219
}
@@ -1225,8 +1224,7 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
12251224
}
12261225

12271226
// Base for entire kernel
1228-
auto gmemBase =
1229-
funcOp.getArgument(funcOp.getNumArguments() + kGlobalScratchBufferOffset);
1227+
auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
12301228

12311229
ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
12321230
auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
@@ -1268,15 +1266,6 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
12681266
return res;
12691267
}
12701268

1271-
Value getProfileScratchPtr(Location loc, RewriterBase &rewriter,
1272-
FunctionOpInterface funcOp) {
1273-
// See NOTE: [Additional Function Arguments]
1274-
// FIXME(Keren): This is broken when we have device functions, we
1275-
// need to implement proper calling convention
1276-
return funcOp.getArgument(funcOp.getNumArguments() +
1277-
kProfileScratchBufferOffset);
1278-
}
1279-
12801269
Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
12811270
const TargetInfoBase &target, Operation *op) {
12821271
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "mlir/Dialect/UB/IR/UBOps.h"
44
#include "mlir/Pass/Pass.h"
55
#include "mlir/Transforms/DialectConversion.h"
6+
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
67
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
78
#include "triton/Dialect/Triton/IR/Dialect.h"
89
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -596,6 +597,17 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
596597
// clang-format on
597598
>(typeConverter, context);
598599
}
600+
// Proton patterns
601+
// NOTE: Because Proton's inputs are scalars and not tensors, this conversion
602+
// isn't strictly necessary; however, you could envision a case where we pass
603+
// tensors in for Triton object-specific tracing operations, in which case we
604+
// would need to fill in the OpConversionPattern.
605+
void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,
606+
RewritePatternSet &patterns) {
607+
MLIRContext *context = patterns.getContext();
608+
patterns.add<GenericOpPattern<triton::proton::RecordOp>>(typeConverter,
609+
context);
610+
}
599611
//
600612
// SCF patterns
601613
//
@@ -809,6 +821,7 @@ class ConvertTritonToTritonGPU
809821
populateArithPatternsAndLegality(typeConverter, patterns, target);
810822
populateMathPatternsAndLegality(typeConverter, patterns, target);
811823
populateTritonPatterns(typeConverter, patterns, numCTAs);
824+
populateProtonPatterns(typeConverter, patterns);
812825
// TODO: can we use
813826
// mlir::scf::populateSCFStructuralTypeConversionsAndLegality(...) here?
814827
populateSCFPatterns(typeConverter, patterns);

0 commit comments

Comments
 (0)