#include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
#include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
#include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
#include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -36,12 +35,6 @@ using namespace mlir::triton::NVIDIA;

namespace {

-// pass ws related named attrs.
-static void addAttrs(Operation *op, ArrayRef<mlir::NamedAttribute> attrs) {
-  for (const NamedAttribute attr : attrs)
-    op->setAttr(attr.getName(), attr.getValue());
-}
-
class TritonLLVMFunctionConversionTarget : public ConversionTarget {
public:
  explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx)
@@ -71,55 +64,41 @@ struct ConvertTritonGPUToLLVM
    : public triton::impl::ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
  using ConvertTritonGPUToLLVMBase::ConvertTritonGPUToLLVMBase;

-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<triton::nvgpu::NVGPUDialect, LLVM::LLVMDialect,
-                    NVVM::NVVMDialect>();
-  }
-
  ConvertTritonGPUToLLVM(int32_t computeCapability)
      : ConvertTritonGPUToLLVMBase({computeCapability}) {}
-
  ConvertTritonGPUToLLVM(int32_t computeCapability, int32_t ptxVersion)
      : ConvertTritonGPUToLLVMBase({computeCapability, ptxVersion}) {}

  void runOnOperation() override {
    MLIRContext *context = &getContext();
    ModuleOp mod = getOperation();
-
-    mlir::LowerToLLVMOptions option(context);
-    option.overrideIndexBitwidth(32);
    TargetInfo targetInfo(computeCapability, ptxVersion);
-    TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo);
-    TritonLLVMConversionTarget convTarget(*context);
-    int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
-    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);

    // Allocate shared memory and set barrier
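    // (Annotation: ModuleAllocation assigns each shared-memory scratch buffer
    // an offset in the module's shared-memory arena; ModuleMembarAnalysis then
    // inserts the barriers needed to synchronize accesses to those buffers.)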
    ModuleAllocation allocation(mod);
    ModuleMembarAnalysis membarPass(&allocation);
    membarPass.run();

    // Lower functions
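    // (Annotation: func ops get their own partial-conversion step so that the
    // body patterns applied below already see rewritten llvm.func signatures.)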
-    {
-      mlir::LowerToLLVMOptions option(context);
-      TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo);
-      TritonLLVMFunctionConversionTarget funcTarget(*context);
-      RewritePatternSet funcPatterns(context);
-      mlir::triton::populateFuncOpConversionPattern(
-          typeConverter, funcPatterns, targetInfo, patternBenefitDefault);
-      mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
-                                                            funcPatterns);
-      if (failed(
-              applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
-        return signalPassFailure();
-    }
+    TritonLLVMFunctionConversionTarget funcTarget(*context);
+    RewritePatternSet funcPatterns(context);
+    TritonGPUToLLVMTypeConverter funcTypeConverter(context, targetInfo);
+    mlir::triton::populateFuncOpConversionPattern(
+        funcTypeConverter, funcPatterns, targetInfo, patternBenefitDefault);
+    mlir::cf::populateControlFlowToLLVMConversionPatterns(funcTypeConverter,
+                                                          funcPatterns);
+    if (failed(
+            applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
+      return signalPassFailure();

    // initSharedMemory is run before the conversion of call and ret ops,
    // because the call op has to know the shared memory base address of each
    // function
+    mlir::LowerToLLVMOptions option(context);
+    option.overrideIndexBitwidth(32);
+    TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo);
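+    // (Annotation: overrideIndexBitwidth(32) keeps index arithmetic in
+    // 32 bits, the cheaper native width on NVIDIA GPUs.)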
    initSharedMemory(typeConverter);
    ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
-    OpBuilder::InsertPoint indexInsertPoint;

    RewritePatternSet patterns(context);
    int benefit = patternBenefitPrioritizeOverLLVMConversions;
@@ -178,13 +157,16 @@ struct ConvertTritonGPUToLLVM
        typeConverter, patterns, benefit);
    mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo,
                                                   patterns, benefit);
-    populateTCGen5MMAOpToLLVMPattern(typeConverter, patterns, benefit);
+    mlir::triton::NVIDIA::populateTCGen5MMAOpToLLVMPattern(typeConverter,
+                                                           patterns, benefit);
    mlir::triton::NVIDIA::populateFp4ToFpToLLVMPatterns(typeConverter, patterns,
                                                        benefit);
+    TritonLLVMConversionTarget convTarget(*context);
    if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
      return signalPassFailure();

    // Fold CTAId when there is only 1 CTA.
+    int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
    if (numCTAs == 1) {
      mod.walk([](triton::nvgpu::ClusterCTAIdOp id) {
        OpBuilder b(id);
@@ -198,7 +180,6 @@ struct ConvertTritonGPUToLLVM
  void initSharedMemory(LLVMTypeConverter &typeConverter) {
    ModuleOp mod = getOperation();
    OpBuilder b(mod.getBodyRegion());
-    auto ctx = mod.getContext();
    auto loc = mod.getLoc();
    auto elemTy = typeConverter.convertType(b.getIntegerType(8));
    // Set array size 0 and external linkage indicates that we use dynamic
@@ -207,19 +188,12 @@ struct ConvertTritonGPUToLLVM
    // Ask for 16B alignment on global_smem because that's the largest we should
    // ever need (4xi32).
    auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
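    // (Annotation: the emitted global is roughly, as a sketch,
    //    llvm.mlir.global external @global_smem() {addr_space = 3 : i32,
    //        alignment = 16 : i64} : !llvm.array<0 x i8>
    //  where address space 3 is NVVM's shared-memory space.)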
-    auto global = b.create<LLVM::GlobalOp>(
+    b.create<LLVM::GlobalOp>(
        loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
        "global_smem", /*value=*/Attribute(), /*alignment=*/16,
        // Add ROCm support.
        static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace));
  }
-
-  static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
-                              Type promotedType) {
-    Type tensorPromotedType = cast<RankedTensorType>(operand.getType())
-                                  .cloneWith(std::nullopt, promotedType);
-    return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
-  }
};

} // anonymous namespace