Skip to content

Commit bf0f56d

Browse files
Jokeren, fywkevin, Yuanwei Fang, CRobeck, peterbell10
authored
[PROTON] Intra kernel profiling (#7258)
### Instrumentation & Runtime

- Introduce a dedicated instrumentation mode
  - `proton.start(..., mode="instrumentation", ...)`
- Introduce both high- and low-level scope APIs
  - For Gluon DSL: `pl.scope`, `pl.enter_scope`, and `pl.exit_scope`. Profiling API for Triton DSL is disabled by default.
  - For TTGIR: `proton.record start` and `proton.record end`
- Inject profiling buffers for each Triton kernel at codegen time and pass them to the proton runtime so kernels can push data directly from the device to the host

### Proton Dialect & Lowering

- Add Proton → ProtonGPU → LLVM pipelines, including passes for shared-memory allocation, profile scratch allocation, and a few optimizations for reduced overhead or improved accuracy.

### Tracing

- `proton.start(..., data="trace", ...)` is supported for both fine- and coarse-grained events.

---------

Co-authored-by: Yuanwei Fang <[email protected]>
Co-authored-by: Yuanwei Fang <[email protected]>
Co-authored-by: Corbin Robeck <[email protected]>
Co-authored-by: peterbell10 <[email protected]>
Co-authored-by: Corbin Robeck <[email protected]>
Co-authored-by: Corbin Robeck <[email protected]>
Co-authored-by: robeck <[email protected]>
Co-authored-by: Srivatsan Ramesh <[email protected]>
Co-authored-by: Shawn Zhong <[email protected]>
Co-authored-by: Shawn Zhong <[email protected]>
Co-authored-by: 鐘天楽 <[email protected]>
1 parent 96e53bb commit bf0f56d

File tree

221 files changed

+9936
-824
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

221 files changed

+9936
-824
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
209209
endif()
210210
# We always build proton dialect
211211
list(APPEND TRITON_PLUGIN_NAMES "proton")
212-
add_subdirectory(third_party/proton/dialect)
212+
add_subdirectory(third_party/proton/Dialect)
213213

214214
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
215215
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
@@ -335,7 +335,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
335335
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
336336
add_subdirectory(third_party/${CODEGEN_BACKEND})
337337
endforeach()
338-
add_subdirectory(third_party/proton/dialect)
338+
add_subdirectory(third_party/proton/Dialect)
339339
endif()
340340

341341
find_package(Threads REQUIRED)

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ test-interpret: all
6767

6868
.PHONY: test-proton
6969
test-proton: all
70-
$(PYTEST) -s -n 8 third_party/proton/test
70+
$(PYTEST) -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py
71+
$(PYTEST) -s third_party/proton/test/test_override.py
7172

7273
.PHONY: test-python
7374
test-python: test-unit test-regression test-interpret test-proton

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ target_link_libraries(triton-opt PRIVATE
1414
TritonTestAnalysis
1515
TritonTestDialect
1616
TritonAMDGPUTestAnalysis
17+
TritonTestProton
1718
# MLIR core
1819
MLIROptLib
1920
MLIRPass
@@ -34,6 +35,7 @@ target_link_libraries(triton-reduce PRIVATE
3435
TritonTestAnalysis
3536
TritonTestDialect
3637
TritonAMDGPUTestAnalysis
38+
TritonTestProton
3739
# MLIR core
3840
MLIRReduceLib
3941
MLIRPass
@@ -53,6 +55,7 @@ target_link_libraries(triton-lsp PRIVATE
5355
TritonTestAnalysis
5456
TritonTestDialect
5557
TritonAMDGPUTestAnalysis
58+
TritonTestProton
5659
# MLIR core
5760
MLIRLspServerLib
5861
MLIRPass
@@ -89,5 +92,6 @@ target_link_libraries(triton-tensor-layout PRIVATE
8992
${dialect_libs}
9093
TritonTestAnalysis
9194
TritonTestDialect
95+
TritonTestProton
9296
TritonAMDGPUTestAnalysis
9397
)

bin/RegisterTritonDialects.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,15 @@
11
#pragma once
22
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
33
#include "amd/include/TritonAMDGPUTransforms/Passes.h"
4-
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
5-
#include "third_party/nvidia/include/Dialect/NVWS/IR/Dialect.h"
6-
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
4+
#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h"
5+
#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
6+
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h"
7+
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h"
8+
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h"
9+
#include "proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h"
10+
#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h"
11+
#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h"
12+
#include "proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h"
713
#include "triton/Dialect/Gluon/Transforms/Passes.h"
814
#include "triton/Dialect/Triton/IR/Dialect.h"
915
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -42,6 +48,9 @@ void registerTestMembarPass();
4248
void registerTestAMDGPUMembarPass();
4349
void registerTestTritonAMDGPURangeAnalysis();
4450
void registerTestLoopPeelingPass();
51+
namespace proton {
52+
void registerTestScopeIdAllocationPass();
53+
} // namespace proton
4554
} // namespace test
4655
} // namespace mlir
4756

@@ -99,6 +108,16 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
99108
// NVGPU transform passes
100109
mlir::registerNVHopperTransformsPasses();
101110

111+
// Proton passes
112+
mlir::test::proton::registerTestScopeIdAllocationPass();
113+
mlir::triton::proton::registerConvertProtonToProtonGPU();
114+
mlir::triton::proton::gpu::registerConvertProtonNvidiaGPUToLLVM();
115+
mlir::triton::proton::gpu::registerConvertProtonAMDGPUToLLVM();
116+
mlir::triton::proton::gpu::registerAllocateProtonSharedMemoryPass();
117+
mlir::triton::proton::gpu::registerAllocateProtonGlobalScratchBufferPass();
118+
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
119+
mlir::triton::proton::gpu::registerAddSchedBarriersPass();
120+
102121
registry.insert<
103122
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
104123
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
@@ -108,6 +127,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
108127
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
109128
mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect,
110129
mlir::triton::amdgpu::TritonAMDGPUDialect,
111-
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect,
130+
mlir::triton::proton::ProtonDialect,
131+
mlir::triton::proton::gpu::ProtonGPUDialect, mlir::ROCDL::ROCDLDialect,
112132
mlir::triton::gluon::GluonDialect>();
113133
}

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,12 @@ class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder {
319319
#define str_attr(str) ::mlir::StringAttr::get(ctx, (str))
320320

321321
namespace mlir {
322+
323+
// See FuncOpToLLVM.cpp for details about Triton's function calling conventions
324+
constexpr int kProfileScratchBufferOffset = -1;
325+
constexpr int kGlobalScratchBufferOffset = -2;
326+
constexpr int kSharedMemoryOffset = -3;
327+
322328
namespace triton {
323329

324330
namespace gpu {
@@ -439,6 +445,9 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
439445
const TargetInfoBase &targetInfo,
440446
FunctionOpInterface funcOp, Value allocOffset);
441447

448+
Value getProfileScratchPtr(Location loc, RewriterBase &rewriter,
449+
FunctionOpInterface funcOp);
450+
442451
Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
443452
const TargetInfoBase &target, Operation *op);
444453

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
5050
inline const std::set<std::string> CACHE_NEUTRAL_ENV_VARS = {
5151
// clang-format off
5252
"TRITON_REPRODUCER_PATH",
53-
"TRITON_ENABLE_PYTHON_STACKTRACE"
53+
"TRITON_ENABLE_PYTHON_STACKTRACE",
5454
// clang-format on
5555
};
5656

lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
103103

104104
promotedOperands.push_back(LLVM::getGlobalScratchPtr(
105105
loc, rewriter, targetInfo, caller, opOffsetVal));
106+
promotedOperands.push_back(
107+
LLVM::getProfileScratchPtr(loc, rewriter, caller));
106108
return promotedOperands;
107109
}
108110

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,20 @@ using namespace mlir;
1010
using namespace mlir::triton;
1111

1212
// NOTE: [Additional Function Arguments]
13+
// Triton patches additional arguments to the function signature to support
14+
// (1) shared memory, (2) global scratch memory, and (3) profile scratch memory.
1315
// To support use of shared memory and global scratch memory inside of a
1416
// function, the caller allocates a single large block of the relevant memory
1517
// and calls the function with these extra arguments at the end.
16-
// Specifically, the last argument is the global scratch memory allocation and
17-
// the second to last is the shared memory allocation.
18+
// Profile scratch memory is only used when the function is instrumented for
19+
// profiling.
1820
//
1921
// For the kernel function itself, the shared memory base is a global symbol
2022
// so no additional function argument is required but global scratch memory
2123
// allocation is still passed in as the second-to-last argument (the profile scratch pointer is last). Though here the scratch
2224
// memory is shared between all programs, so a linear offset based on the
2325
// program id is required to get the local scratch base.
2426

25-
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
26-
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
27-
/// information.
2827
struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
2928
FuncOpConversion(LLVMTypeConverter &converter,
3029
const TargetInfoBase &targetInfo, PatternBenefit benefit)
@@ -56,6 +55,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
5655
auto sharedPtrTy =
5756
LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
5857
auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1);
58+
auto profilePtrTy = LLVM::LLVMPointerType::get(ctx, 1);
5959

6060
// 1. Modify the function type to add the new arguments.
6161
auto funcTy = funcOp.getFunctionType();
@@ -73,6 +73,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
7373
amendedInputTy.push_back(sharedPtrTy);
7474
}
7575
amendedInputTy.push_back(globalPtrTy);
76+
amendedInputTy.push_back(profilePtrTy);
7677
auto amendedFuncTy =
7778
FunctionType::get(ctx, amendedInputTy, funcTy.getResults());
7879
// 2. Modify the argument attributes to add the new argument.
@@ -97,6 +98,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
9798
region.addArgument(sharedPtrTy, loc);
9899
}
99100
region.addArgument(globalPtrTy, loc);
101+
region.addArgument(profilePtrTy, loc);
100102
rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
101103
amendedFuncOp.end());
102104
return amendedFuncOp;

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,7 +1198,7 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
11981198
Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) {
11991199
// See NOTE: [Additional Function Arguments]
12001200
if (!isKernel(funcOp)) {
1201-
return funcOp.getArgument(funcOp.getNumArguments() - 2);
1201+
return funcOp.getArgument(funcOp.getNumArguments() + kSharedMemoryOffset);
12021202
}
12031203

12041204
auto mod = funcOp->getParentOfType<ModuleOp>();
@@ -1213,7 +1213,8 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
12131213
// See NOTE: [Additional Function Arguments]
12141214
if (!isKernel(funcOp)) {
12151215
// Base for this function
1216-
auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
1216+
auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() +
1217+
kGlobalScratchBufferOffset);
12171218
if (!allocOffset) {
12181219
return gmemBase;
12191220
}
@@ -1224,7 +1225,8 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
12241225
}
12251226

12261227
// Base for entire kernel
1227-
auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
1228+
auto gmemBase =
1229+
funcOp.getArgument(funcOp.getNumArguments() + kGlobalScratchBufferOffset);
12281230

12291231
ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
12301232
auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
@@ -1266,6 +1268,15 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
12661268
return res;
12671269
}
12681270

1271+
Value getProfileScratchPtr(Location loc, RewriterBase &rewriter,
1272+
FunctionOpInterface funcOp) {
1273+
// See NOTE: [Additional Function Arguments]
1274+
// FIXME(Keren): This is broken when we have device functions, we
1275+
// need to implement proper calling convention
1276+
return funcOp.getArgument(funcOp.getNumArguments() +
1277+
kProfileScratchBufferOffset);
1278+
}
1279+
12691280
Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
12701281
const TargetInfoBase &target, Operation *op) {
12711282
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "mlir/Dialect/UB/IR/UBOps.h"
44
#include "mlir/Pass/Pass.h"
55
#include "mlir/Transforms/DialectConversion.h"
6-
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
76
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
87
#include "triton/Dialect/Triton/IR/Dialect.h"
98
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -597,17 +596,6 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
597596
// clang-format on
598597
>(typeConverter, context);
599598
}
600-
// Proton patterns
601-
// NOTE: Because Proton's inputs are scalars and not tensors this conversion
602-
// isn't strictly necessary however you could envision a case where we pass in
603-
// tensors in for Triton object specific tracing operations in which case we
604-
// would need to fill in the OpConversionPattern
605-
void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,
606-
RewritePatternSet &patterns) {
607-
MLIRContext *context = patterns.getContext();
608-
patterns.add<GenericOpPattern<triton::proton::RecordOp>>(typeConverter,
609-
context);
610-
}
611599
//
612600
// SCF patterns
613601
//
@@ -821,7 +809,6 @@ class ConvertTritonToTritonGPU
821809
populateArithPatternsAndLegality(typeConverter, patterns, target);
822810
populateMathPatternsAndLegality(typeConverter, patterns, target);
823811
populateTritonPatterns(typeConverter, patterns, numCTAs);
824-
populateProtonPatterns(typeConverter, patterns);
825812
// TODO: can we use
826813
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
827814
populateSCFPatterns(typeConverter, patterns);

0 commit comments

Comments
 (0)