Skip to content

Commit 33d511d

Browse files
Merge commit 'bf0f56dc318cca3fa568e49e85136cbce09f12f3'
2 parents afe1e9d + bf0f56d commit 33d511d

File tree

219 files changed

+9933
-731
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

219 files changed

+9933
-731
lines changed

CMakeLists.txt

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ if(TRITON_BUILD_PYTHON_MODULE)
228228
endif()
229229
# We always build proton dialect
230230
list(APPEND TRITON_PLUGIN_NAMES "proton")
231-
add_subdirectory(third_party/proton/dialect)
231+
add_subdirectory(third_party/proton/Dialect)
232232

233233
get_property(triton_libs GLOBAL PROPERTY TRITON_LIBS)
234234
get_property(triton_plugins GLOBAL PROPERTY TRITON_PLUGINS)
@@ -360,7 +360,7 @@ if(NOT TRITON_BUILD_PYTHON_MODULE)
360360
foreach(CODEGEN_BACKEND ${TRITON_CODEGEN_BACKENDS})
361361
add_subdirectory(third_party/${CODEGEN_BACKEND})
362362
endforeach()
363-
add_subdirectory(third_party/proton/dialect)
363+
add_subdirectory(third_party/proton/Dialect)
364364
endif()
365365

366366
find_package(Threads REQUIRED)

Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,8 @@ test-interpret: all
6767

6868
.PHONY: test-proton
6969
test-proton: all
70-
$(PYTEST) -s -n 8 third_party/proton/test
70+
$(PYTEST) -s -n 8 third_party/proton/test --ignore=third_party/proton/test/test_override.py
71+
$(PYTEST) -s third_party/proton/test/test_override.py
7172

7273
.PHONY: test-python
7374
test-python: test-unit test-regression test-interpret test-proton

bin/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ target_link_libraries(triton-opt PRIVATE
1515
TritonTestAnalysis
1616
TritonTestDialect
1717
TritonAMDGPUTestAnalysis
18+
TritonTestProton
1819
# MLIR core
1920
MLIROptLib
2021
MLIRPass
@@ -35,6 +36,7 @@ target_link_libraries(triton-reduce PRIVATE
3536
TritonTestAnalysis
3637
TritonTestDialect
3738
TritonAMDGPUTestAnalysis
39+
TritonTestProton
3840
# MLIR core
3941
MLIRReduceLib
4042
MLIRPass
@@ -54,6 +56,7 @@ target_link_libraries(triton-lsp PRIVATE
5456
TritonTestAnalysis
5557
TritonTestDialect
5658
TritonAMDGPUTestAnalysis
59+
TritonTestProton
5760
# MLIR core
5861
MLIRLspServerLib
5962
MLIRPass
@@ -92,5 +95,6 @@ target_link_libraries(triton-tensor-layout PRIVATE
9295
${dialect_libs}
9396
TritonTestAnalysis
9497
TritonTestDialect
98+
TritonTestProton
9599
TritonAMDGPUTestAnalysis
96100
)

bin/RegisterTritonDialects.h

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,15 @@
1212

1313
#include "amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
1414
#include "amd/include/TritonAMDGPUTransforms/Passes.h"
15-
#include "third_party/nvidia/include/Dialect/NVGPU/IR/Dialect.h"
16-
#include "third_party/nvidia/include/Dialect/NVWS/IR/Dialect.h"
17-
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
15+
#include "nvidia/include/Dialect/NVGPU/IR/Dialect.h"
16+
#include "nvidia/include/Dialect/NVWS/IR/Dialect.h"
17+
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/Passes.h"
18+
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonAMDGPUToLLVM/Passes.h"
19+
#include "proton/Dialect/include/Conversion/ProtonGPUToLLVM/ProtonNvidiaGPUToLLVM/Passes.h"
20+
#include "proton/Dialect/include/Conversion/ProtonToProtonGPU/Passes.h"
21+
#include "proton/Dialect/include/Dialect/Proton/IR/Dialect.h"
22+
#include "proton/Dialect/include/Dialect/ProtonGPU/IR/Dialect.h"
23+
#include "proton/Dialect/include/Dialect/ProtonGPU/Transforms/Passes.h"
1824
#include "triton/Dialect/Gluon/Transforms/Passes.h"
1925
#include "triton/Dialect/Triton/IR/Dialect.h"
2026
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
@@ -58,6 +64,9 @@ void registerTestMembarPass();
5864
void registerTestAMDGPUMembarPass();
5965
void registerTestTritonAMDGPURangeAnalysis();
6066
void registerTestLoopPeelingPass();
67+
namespace proton {
68+
void registerTestScopeIdAllocationPass();
69+
} // namespace proton
6170
} // namespace test
6271
} // namespace mlir
6372

@@ -127,6 +136,16 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
127136
// NVGPU transform passes
128137
mlir::registerNVHopperTransformsPasses();
129138

139+
// Proton passes
140+
mlir::test::proton::registerTestScopeIdAllocationPass();
141+
mlir::triton::proton::registerConvertProtonToProtonGPU();
142+
mlir::triton::proton::gpu::registerConvertProtonNvidiaGPUToLLVM();
143+
mlir::triton::proton::gpu::registerConvertProtonAMDGPUToLLVM();
144+
mlir::triton::proton::gpu::registerAllocateProtonSharedMemoryPass();
145+
mlir::triton::proton::gpu::registerAllocateProtonGlobalScratchBufferPass();
146+
mlir::triton::proton::gpu::registerScheduleBufferStorePass();
147+
mlir::triton::proton::gpu::registerAddSchedBarriersPass();
148+
130149
registry.insert<
131150
mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
132151
mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect,
@@ -136,7 +155,8 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
136155
mlir::gpu::GPUDialect, mlir::LLVM::LLVMDialect, mlir::NVVM::NVVMDialect,
137156
mlir::triton::nvgpu::NVGPUDialect, mlir::triton::nvws::NVWSDialect,
138157
mlir::triton::amdgpu::TritonAMDGPUDialect,
139-
mlir::triton::proton::ProtonDialect, mlir::ROCDL::ROCDLDialect,
158+
mlir::triton::proton::ProtonDialect,
159+
mlir::triton::proton::gpu::ProtonGPUDialect, mlir::ROCDL::ROCDLDialect,
140160
mlir::triton::gpu::intel::TritonIntelGPUDialect,
141161
mlir::triton::TritonGEN::TritonGENDialect,
142162
mlir::triton::gluon::GluonDialect>();

include/triton/Conversion/TritonGPUToLLVM/Utility.h

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,12 @@ class TritonLLVMIRRewriter : public IRRewriter, public TritonLLVMOpBuilder {
319319
#define str_attr(str) ::mlir::StringAttr::get(ctx, (str))
320320

321321
namespace mlir {
322+
323+
// See FuncOpToLLVM.cpp for details about Triton's function calling conventions
324+
constexpr int kProfileScratchBufferOffset = -1;
325+
constexpr int kGlobalScratchBufferOffset = -2;
326+
constexpr int kSharedMemoryOffset = -3;
327+
322328
namespace triton {
323329

324330
namespace gpu {
@@ -439,6 +445,9 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
439445
const TargetInfoBase &targetInfo,
440446
FunctionOpInterface funcOp, Value allocOffset);
441447

448+
Value getProfileScratchPtr(Location loc, RewriterBase &rewriter,
449+
FunctionOpInterface funcOp);
450+
442451
Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
443452
const TargetInfoBase &target, Operation *op);
444453

include/triton/Tools/Sys/GetEnv.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ inline const std::set<std::string> CACHE_INVALIDATING_ENV_VARS = {
5959
inline const std::set<std::string> CACHE_NEUTRAL_ENV_VARS = {
6060
// clang-format off
6161
"TRITON_REPRODUCER_PATH",
62-
"TRITON_ENABLE_PYTHON_STACKTRACE"
62+
"TRITON_ENABLE_PYTHON_STACKTRACE",
6363
// clang-format on
6464
};
6565

lib/Conversion/TritonGPUToLLVM/ControlFlowOpToLLVM.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,8 @@ struct CallOpConversion : public ConvertOpToLLVMPattern<triton::CallOp> {
103103

104104
promotedOperands.push_back(LLVM::getGlobalScratchPtr(
105105
loc, rewriter, targetInfo, caller, opOffsetVal));
106+
promotedOperands.push_back(
107+
LLVM::getProfileScratchPtr(loc, rewriter, caller));
106108
return promotedOperands;
107109
}
108110

lib/Conversion/TritonGPUToLLVM/FuncOpToLLVM.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,21 +10,20 @@ using namespace mlir;
1010
using namespace mlir::triton;
1111

1212
// NOTE: [Additional Function Arguments]
13+
// Triton patches additional arguments to the function signature to support
14+
// (1) shared memory, (2) global scratch memory, and (3) profile scratch memory.
1315
// To support use of shared memory and global scratch memory inside of a
1416
// function, the caller allocates a single large block of the relevant memory
1517
// and calls the function with these extra arguments at the end.
16-
// Specifically, the last argument is the global scratch memory allocation and
17-
// the second to last is the shared memory allocation.
18+
// Profile scratch memory is only used when the function is instrumented for
19+
// profiling.
1820
//
1921
// For the kernel function itself, the shared memory base is a global symbol
2022
// so no additional function argument is required but global scratch memory
2123
// allocation is still passed in as the second-to-last argument (the profile scratch pointer is last). Though here the scratch
2224
// memory is shared between all programs, so a linear offset based on the
2325
// program id is required to get the local scratch base.
2426

25-
/// FuncOp legalization pattern that converts MemRef arguments to pointers to
26-
/// MemRef descriptors (LLVM struct data types) containing all the MemRef type
27-
/// information.
2827
struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
2928
FuncOpConversion(LLVMTypeConverter &converter,
3029
const TargetInfoBase &targetInfo, PatternBenefit benefit)
@@ -56,6 +55,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
5655
auto sharedPtrTy =
5756
LLVM::LLVMPointerType::get(ctx, targetInfo.getSharedAddressSpace());
5857
auto globalPtrTy = LLVM::LLVMPointerType::get(ctx, 1);
58+
auto profilePtrTy = LLVM::LLVMPointerType::get(ctx, 1);
5959

6060
// 1. Modify the function type to add the new arguments.
6161
auto funcTy = funcOp.getFunctionType();
@@ -73,6 +73,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
7373
amendedInputTy.push_back(sharedPtrTy);
7474
}
7575
amendedInputTy.push_back(globalPtrTy);
76+
amendedInputTy.push_back(profilePtrTy);
7677
auto amendedFuncTy =
7778
FunctionType::get(ctx, amendedInputTy, funcTy.getResults());
7879
// 2. Modify the argument attributes to add the new argument.
@@ -97,6 +98,7 @@ struct FuncOpConversion : public ConvertOpToLLVMPattern<triton::FuncOp> {
9798
region.addArgument(sharedPtrTy, loc);
9899
}
99100
region.addArgument(globalPtrTy, loc);
101+
region.addArgument(profilePtrTy, loc);
100102
rewriter.inlineRegionBefore(region, amendedFuncOp.getBody(),
101103
amendedFuncOp.end());
102104
return amendedFuncOp;

lib/Conversion/TritonGPUToLLVM/Utility.cpp

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1198,7 +1198,7 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
11981198
Value getStackPointer(RewriterBase &rewriter, FunctionOpInterface funcOp) {
11991199
// See NOTE: [Additional Function Arguments]
12001200
if (!isKernel(funcOp)) {
1201-
return funcOp.getArgument(funcOp.getNumArguments() - 2);
1201+
return funcOp.getArgument(funcOp.getNumArguments() + kSharedMemoryOffset);
12021202
}
12031203

12041204
auto mod = funcOp->getParentOfType<ModuleOp>();
@@ -1213,7 +1213,8 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
12131213
// See NOTE: [Additional Function Arguments]
12141214
if (!isKernel(funcOp)) {
12151215
// Base for this function
1216-
auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
1216+
auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() +
1217+
kGlobalScratchBufferOffset);
12171218
if (!allocOffset) {
12181219
return gmemBase;
12191220
}
@@ -1224,7 +1225,8 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
12241225
}
12251226

12261227
// Base for entire kernel
1227-
auto gmemBase = funcOp.getArgument(funcOp.getNumArguments() - 1);
1228+
auto gmemBase =
1229+
funcOp.getArgument(funcOp.getNumArguments() + kGlobalScratchBufferOffset);
12281230

12291231
ModuleOp mod = funcOp.getOperation()->getParentOfType<ModuleOp>();
12301232
auto allocSizeAttr = mod.getOperation()->getAttrOfType<mlir::IntegerAttr>(
@@ -1266,6 +1268,15 @@ Value getGlobalScratchPtr(Location loc, RewriterBase &rewriter,
12661268
return res;
12671269
}
12681270

1271+
Value getProfileScratchPtr(Location loc, RewriterBase &rewriter,
1272+
FunctionOpInterface funcOp) {
1273+
// See NOTE: [Additional Function Arguments]
1274+
// FIXME(Keren): This is broken when we have device functions, we
1275+
// need to implement proper calling convention
1276+
return funcOp.getArgument(funcOp.getNumArguments() +
1277+
kProfileScratchBufferOffset);
1278+
}
1279+
12691280
Value getSharedMemoryBase(Location loc, RewriterBase &rewriter,
12701281
const TargetInfoBase &target, Operation *op) {
12711282
auto ptrTy = LLVM::LLVMPointerType::get(rewriter.getContext(),

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 0 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
#include "mlir/Dialect/UB/IR/UBOps.h"
44
#include "mlir/Pass/Pass.h"
55
#include "mlir/Transforms/DialectConversion.h"
6-
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
76
#include "triton/Conversion/TritonToTritonGPU/Passes.h"
87
#include "triton/Dialect/Triton/IR/Dialect.h"
98
#include "triton/Dialect/Triton/IR/Utility.h"
@@ -597,17 +596,6 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
597596
// clang-format on
598597
>(typeConverter, context);
599598
}
600-
// Proton patterns
601-
// NOTE: Because Proton's inputs are scalars and not tensors this conversion
602-
// isn't strictly necessary however you could envision a case where we pass in
603-
// tensors in for Triton object specific tracing operations in which case we
604-
// would need to fill in the OpConversionPattern
605-
void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,
606-
RewritePatternSet &patterns) {
607-
MLIRContext *context = patterns.getContext();
608-
patterns.add<GenericOpPattern<triton::proton::RecordOp>>(typeConverter,
609-
context);
610-
}
611599
//
612600
// SCF patterns
613601
//
@@ -821,7 +809,6 @@ class ConvertTritonToTritonGPU
821809
populateArithPatternsAndLegality(typeConverter, patterns, target);
822810
populateMathPatternsAndLegality(typeConverter, patterns, target);
823811
populateTritonPatterns(typeConverter, patterns, numCTAs);
824-
populateProtonPatterns(typeConverter, patterns);
825812
// TODO: can we use
826813
// mlir::scf::populateSCFStructuralTypeConversionsAndLegality(...) here?
827814
populateSCFPatterns(typeConverter, patterns);

0 commit comments

Comments
 (0)