Skip to content

Commit dcad5ac

Browse files
authored
[Proton][Dialect] Add Initial Frontend and Target Backend Infrastructure For Proton Dialect (#5506)
Implement initial basic infrastructure for the Proton Dialect added in triton-lang/triton#5119 This PR extends the initial boilerplate MLIR Dialect code to the Triton frontend and target backends - currently just lowered to a no-op.
1 parent 2b06b2c commit dcad5ac

File tree

13 files changed

+157
-2
lines changed

13 files changed

+157
-2
lines changed

lib/Conversion/TritonToTritonGPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ add_triton_library(TritonToTritonGPU
1010
MLIRPass
1111
MLIRTransforms
1212
TritonIR
13+
ProtonIR
1314
TritonGPUIR
1415
TritonGPUTransforms
1516
)

lib/Conversion/TritonToTritonGPU/TritonToTritonGPUPass.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#include "triton/Conversion/TritonToTritonGPU/Passes.h.inc"
1717
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
1818

19+
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
20+
1921
namespace {
2022

2123
using namespace mlir;
@@ -555,7 +557,17 @@ void populateTritonPatterns(TritonGPUTypeConverter &typeConverter,
555557
GenericOpPattern<triton::DotScaledOp>, GenericOpPattern<triton::CallOp>,
556558
TritonFuncOpPattern>(typeConverter, context);
557559
}
558-
560+
// Proton patterns
561+
// NOTE: Because Proton's inputs are scalars and not tensors this conversion
562+
// isn't strictly nessessary however you could envision a case where we pass in
563+
// tensors in for Triton object specific tracing operations in which case we
564+
// would need to fill in the OpConversionPattern
565+
void populateProtonPatterns(TritonGPUTypeConverter &typeConverter,
566+
RewritePatternSet &patterns) {
567+
MLIRContext *context = patterns.getContext();
568+
patterns.add<GenericOpPattern<triton::proton::RecordOp>>(typeConverter,
569+
context);
570+
}
559571
//
560572
// SCF patterns
561573
//
@@ -770,6 +782,7 @@ class ConvertTritonToTritonGPU
770782
populateArithPatternsAndLegality(typeConverter, patterns, target);
771783
populateMathPatternsAndLegality(typeConverter, patterns, target);
772784
populateTritonPatterns(typeConverter, patterns, numCTAs);
785+
populateProtonPatterns(typeConverter, patterns);
773786
// TODO: can we use
774787
// mlir::scf::populateSCFStructurealTypeConversionsAndLegality(...) here?
775788
populateSCFPatterns(typeConverter, patterns);

python/src/ir.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
#include "triton/Tools/Sys/GetEnv.hpp"
3232
#include "llvm/Support/SourceMgr.h"
3333

34+
#include "third_party/proton/dialect/include/Dialect/Proton/IR/Dialect.h"
35+
3436
namespace {
3537

3638
namespace py = pybind11;
@@ -235,7 +237,8 @@ void init_triton_ir(py::module &&m) {
235237
registry.insert<TritonDialect, ::mlir::triton::gpu::TritonGPUDialect,
236238
math::MathDialect, arith::ArithDialect, scf::SCFDialect,
237239
::mlir::gpu::GPUDialect, cf::ControlFlowDialect,
238-
LLVM::LLVMDialect, mlir::ub::UBDialect>();
240+
::mlir::triton::proton::ProtonDialect, LLVM::LLVMDialect,
241+
mlir::ub::UBDialect>();
239242
mlir::LLVM::registerInlinerInterface(registry);
240243
registerBuiltinDialectTranslation(registry);
241244
registerLLVMDialectTranslation(registry);
@@ -1654,6 +1657,11 @@ void init_triton_ir(py::module &&m) {
16541657
std::vector<int32_t> &tensorShape) -> Value {
16551658
return self.create<MakeTensorDescOp>(base, shape, strides,
16561659
tensorShape);
1660+
})
1661+
// Proton Ops
1662+
.def("create_proton_record",
1663+
[](TritonOpBuilder &self, bool isStart, int32_t regionId) -> void {
1664+
self.create<mlir::triton::proton::RecordOp>(isStart, regionId);
16571665
});
16581666

16591667
py::class_<PassManager>(m, "pass_manager", py::module_local())

third_party/amd/lib/TritonAMDGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,5 @@ add_triton_library(TritonAMDGPUToLLVM
2929
LINK_LIBS PUBLIC
3030
TritonGPUToLLVM
3131
TritonAMDGPUIR
32+
TritonProtonToLLVM
3233
)

third_party/amd/lib/TritonAMDGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
#include "triton/Dialect/TritonGPU/IR/Dialect.h"
2525
#include "triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
2626

27+
#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"
28+
2729
namespace mlir::triton {
2830
#define GEN_PASS_DEF_CONVERTTRITONAMDGPUTOLLVM
2931
#include "TritonAMDGPUToLLVM/Passes.h.inc"
@@ -228,6 +230,10 @@ struct ConvertTritonAMDGPUToLLVM
228230
patterns);
229231
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
230232
targetInfo, commonBenefit);
233+
234+
mlir::triton::proton::populateRecordOpToLLVMPattern(
235+
typeConverter, patterns, targetInfo, commonBenefit);
236+
231237
mlir::ub::populateUBToLLVMConversionPatterns(typeConverter, patterns);
232238

233239
if (failed(applyPartialConversion(mod, convTarget, std::move(patterns)))) {

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,4 +25,5 @@ add_triton_library(TritonNVIDIAGPUToLLVM
2525

2626
LINK_LIBS PUBLIC
2727
TritonGPUToLLVM
28+
TritonProtonToLLVM
2829
)

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
2323
#include "triton/Conversion/TritonGPUToLLVM/TypeConverter.h"
2424

25+
#include "third_party/proton/dialect/include/TritonProtonToLLVM/PatternTritonProtonOpToLLVM.h"
26+
2527
namespace mlir {
2628
namespace triton {
2729
#define GEN_PASS_DEF_CONVERTTRITONGPUTOLLVM
@@ -149,6 +151,8 @@ struct ConvertTritonGPUToLLVM
149151
targetInfo, benefit);
150152
mlir::triton::populatePrintOpToLLVMPattern(typeConverter, patterns,
151153
targetInfo, benefit);
154+
mlir::triton::proton::populateRecordOpToLLVMPattern(typeConverter, patterns,
155+
targetInfo, benefit);
152156
mlir::triton::populateControlFlowOpToLLVMPattern(typeConverter, patterns,
153157
targetInfo, benefit);
154158
mlir::triton::NVIDIA::populateSPMDOpToLLVMPattern(typeConverter, patterns,
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
#ifndef TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H
2+
#define TRITON_CONVERSION_TRITONPROTON_TO_LLVM_PATTERNS_TRITON_PROTON_OP_TO_LLVM_H
3+
4+
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
5+
6+
using namespace mlir;
7+
using namespace mlir::triton;
8+
9+
namespace mlir {
10+
namespace triton {
11+
namespace proton {
12+
void populateRecordOpToLLVMPattern(LLVMTypeConverter &typeConverter,
13+
RewritePatternSet &patterns,
14+
const TargetInfoBase &targetInfo,
15+
PatternBenefit benefit);
16+
} // namespace proton
17+
} // namespace triton
18+
} // namespace mlir
19+
20+
#endif
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1 +1,2 @@
11
add_subdirectory(Dialect)
2+
add_subdirectory(TritonProtonToLLVM)
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
add_triton_library(TritonProtonToLLVM
2+
RecordOpToLLVM.cpp
3+
4+
LINK_LIBS PUBLIC
5+
ProtonIR
6+
)

0 commit comments

Comments
 (0)