Skip to content

Commit 9426d5d

Browse files
Reland "[PROTON] Intra kernel profiling (#7258)" (#4953)
This reverts commit 688adf6. PyTorch core CI: https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/17244586569 (passed)
2 parents a4e9b4f + 92c45c6 commit 9426d5d

File tree

234 files changed

+9781
-762
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

234 files changed

+9781
-762
lines changed

.github/workflows/build-test-reusable.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -285,7 +285,8 @@ jobs:
285285
if: matrix.suite == 'rest' && inputs.driver_version == 'rolling' && inputs.device == 'max1100'
286286
run: |
287287
cd third_party/proton/test
288-
pytest test_api.py test_lib.py test_profile.py test_viewer.py test_record.py -s -v
288+
# FIXME: enable 'test_record.py' back
289+
pytest test_api.py test_lib.py test_profile.py test_viewer.py -s -v
289290
cd ..
290291
291292
- name: Run minicore tests

.github/workflows/pip-test.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ jobs:
5050
gh_token: ${{ secrets.GITHUB_TOKEN }}
5151
python_version: ${{ env.PYTHON_VERSION }}
5252
# transformers package is required for the inductor (e2e) test
53-
wheels_pattern: '{torch,transformers}-*.whl'
53+
wheels_pattern: 'torch-*.whl'
5454

5555
- name: Install Triton
5656
uses: ./.github/actions/setup-triton
@@ -61,6 +61,7 @@ jobs:
6161
sed -i '/^validate_nccl_dep_consistency.*/d' generate_binary_build_matrix.py
6262
python -c "from generate_binary_build_matrix import PYTORCH_EXTRA_INSTALL_REQUIREMENTS; print('\n'.join(PYTORCH_EXTRA_INSTALL_REQUIREMENTS['xpu'].split(' | ')))" | tee /tmp/requirements.txt
6363
pip install -r /tmp/requirements.txt
64+
pip install transformers==4.54.0
6465
6566
- name: Run core tests
6667
run: |

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
228228
endif()
229229
# We always build proton dialect
230230
list(APPEND TRITON_PLUGIN_NAMES "proton")
231-
add_subdirectory(third_party/proton/dialect)
231+
add_subdirectory(third_party/proton/Dialect)
232232

233233
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
234234
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
@@ -360,7 +360,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
360360
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
361361
add_subdirectory(third_party/${CODEGEN_BACKEND})
362362
endforeach()
363-
add_subdirectory(third_party/proton/dialect)
363+
add_subdirectory(third_party/proton/Dialect)
364364
endif()
365365

366366
find_package(Threads REQUIRED)

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,8 @@ test-interpret: all
6969

7070
.PHONY: test-proton
7171
test-proton: all
72-
$(PYTEST) -s -n 8 third_party/proton/test
72+
$(PYTEST) -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py
73+
$(PYTEST) -s third_party/proton/test/test_override.py
7374

7475
.PHONY: test-python
7576
test-python: test-unit test-regression test-interpret test-proton

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ target_link_libraries(triton-opt PRIVATE
1515
TritonTestAnalysis
1616
TritonTestDialect
1717
TritonAMDGPUTestAnalysis
18+
TritonTestProton
1819
# MLIR core
1920
MLIROptLib
2021
MLIRPass
@@ -35,6 +36,7 @@ target_link_libraries(triton-reduce PRIVATE
3536
TritonTestAnalysis
3637
TritonTestDialect
3738
TritonAMDGPUTestAnalysis
39+
TritonTestProton
3840
# MLIR core
3941
MLIRReduceLib
4042
MLIRPass
@@ -54,6 +56,7 @@ target_link_libraries(triton-lsp PRIVATE
5456
TritonTestAnalysis
5557
TritonTestDialect
5658
TritonAMDGPUTestAnalysis
59+
TritonTestProton
5760
# MLIR core
5861
MLIRLspServerLib
5962
MLIRPass
@@ -92,5 +95,6 @@ target_link_libraries(triton-tensor-layout PRIVATE
9295
${dialect_libs}
9396
TritonTestAnalysis
9497
TritonTestDialect
98+
TritonTestProton
9599
TritonAMDGPUTestAnalysis
96100
)

bin/RegisterTritonDialects.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,15 @@
1212

1313
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
1414
#include "amd/include/TritonAMDGPUTransforms/Passes.h"
15-
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
16-
#include "third_party/nvidia/include/Dialect/NVWS/IR/Dialect.h"
17-
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
15+
#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h"
16+
#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
17+
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h"
18+
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h"
19+
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h"
20+
#include "proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h"
21+
#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h"
22+
#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h"
23+
#include "proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h"
1824
#include "triton/Dialect/Gluon/Transforms/Passes.h"
1925
#include "triton/Dialect/Triton/IR/Dialect.h"
2026
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -58,6 +64,9 @@ void registerTestMembarPass();
5864
void registerTestAMDGPUMembarPass();
5965
void registerTestTritonAMDGPURangeAnalysis();
6066
void registerTestLoopPeelingPass();
67+
namespace proton {
68+
void registerTestScopeIdAllocationPass();
69+
} // namespace proton
6170
} // namespace test
6271
} // namespace mlir
6372

@@ -127,6 +136,16 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
127136
// NVGPU transform passes
128137
mlir::registerNVHopperTransformsPasses();
129138

139+
// Proton passes
140+
mlir::test::proton::registerTestScopeIdAllocationPass();
141+
mlir::triton::proton::registerConvertProtonToProtonGPU();
142+
mlir::triton::proton::gpu::registerConvertProtonNvidiaGPUToLLVM();
143+
mlir::triton::proton::gpu::registerConvertProtonAMDGPUToLLVM();
144+
mlir::triton::proton::gpu::registerAllocateProtonSharedMemoryPass();
145+
mlir::triton::proton::gpu::registerAllocateProtonGlobalScratchBufferPass();
146+
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
147+
mlir::triton::proton::gpu::registerAddSchedBarriersPass();
148+
130149
registry.insert<
131150
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
132151
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
@@ -136,7 +155,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
136155
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
137156
mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect,
138157
mlir::triton::amdgpu::TritonAMDGPUDialect,
139-
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect,
158+
mlir::triton::proton::ProtonDialect,
159+
mlir::triton::proton::gpu::ProtonGPUDialect, mlir::ROCDL::ROCDLDialect,
140160
mlir::triton::gpu::intel::TritonIntelGPUDialect,
141161
mlir::triton::TritonGEN::TritonGENDialect,
142162
mlir::triton::gluon::GluonDialect>();

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,12 @@ class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder {
319319
#define str_attr(str) ::mlir::StringAttr::get(ctx, (str))
320320

321321
namespace mlir {
322+
323+
// See FuncOpToLLVM.cpp for details about Triton's function calling conventions
324+
constexpr int kProfileScratchBufferOffset = -1;
325+
constexpr int kGlobalScratchBufferOffset = -2;
326+
constexpr int kSharedMemoryOffset = -3;
327+
322328
namespace triton {
323329

324330
namespace gpu {
@@ -439,6 +445,9 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
439445
const TargetInfoBase &targetInfo,
440446
FunctionOpInterface funcOp, Value allocOffset);
441447

448+
Value getProfileScratchPtr(Location loc, RewriterBase &rewriter,
449+
FunctionOpInterface funcOp);
450+
442451
Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
443452
const TargetInfoBase &target, Operation *op);
444453

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
6060
inline const std::set<std::string> CACHE_NEUTRAL_ENV_VARS = {
6161
// clang-format off
6262
"TRITON_REPRODUCER_PATH",
63-
"TRITON_ENABLE_PYTHON_STACKTRACE"
63+
"TRITON_ENABLE_PYTHON_STACKTRACE",
6464
// clang-format on
6565
};
6666

lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
103103

104104
promotedOperands.push_back(LLVM::getGlobalScratchPtr(
105105
loc, rewriter, targetInfo, caller, opOffsetVal));
106+
promotedOperands.push_back(
107+
LLVM::getProfileScratchPtr(loc, rewriter, caller));
106108
return promotedOperands;
107109
}
108110

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,20 @@ using namespace mlir;
1010
using namespace mlir::triton;
1111

1212
// NOTE: [Additional Function Arguments]
13+
// Triton patches additional arguments to the function signature to support
14+
// (1) shared memory, (2) global scratch memory, and (3) profile scratch memory.
1315
// To support use of shared memory and global scratch memory inside of a
1416
// function, the caller allocates a single large block of the relevant memory
1517
// and calls the function with these extra arguments at the end.
16-
// Specifically, the last argument is the global scratch memory allocation and
17-
// the second to last is the shared memory allocation.
18+
// Profile scratch memory is only used when the function is instrumented for
19+
// profiling.
1820
//
1921
// For the kernel function itself, the shared memory base is a global symbol
2022
// so no additional function argument is required but global scratch memory
2123
// allocation is still passed in as the last argument. Though here the scratch
2224
// memory is shared between all programs, so a linear offset based on the
2325
// program id is required to get the local scratch base.
2426

25-
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
26-
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
27-
/// information.
2827
struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
2928
FuncOpConversion(LLVMTypeConverter &converter,
3029
const TargetInfoBase &targetInfo, PatternBenefit benefit)
@@ -56,6 +55,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
5655
auto sharedPtrTy =
5756
LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
5857
auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1);
58+
auto profilePtrTy = LLVM::LLVMPointerType::get(ctx, 1);
5959

6060
// 1. Modify the function type to add the new arguments.
6161
auto funcTy = funcOp.getFunctionType();
@@ -73,6 +73,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
7373
amendedInputTy.push_back(sharedPtrTy);
7474
}
7575
amendedInputTy.push_back(globalPtrTy);
76+
amendedInputTy.push_back(profilePtrTy);
7677
auto amendedFuncTy =
7778
FunctionType::get(ctx, amendedInputTy, funcTy.getResults());
7879
// 2. Modify the argument attributes to add the new argument.
@@ -97,6 +98,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
9798
region.addArgument(sharedPtrTy, loc);
9899
}
99100
region.addArgument(globalPtrTy, loc);
101+
region.addArgument(profilePtrTy, loc);
100102
rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
101103
amendedFuncOp.end());
102104
return amendedFuncOp;

0 commit comments

Comments
 (0)