
Commit 9e62654

[BACKEND][NFC] Remove unused passes, variables, and functions (#5888)
1 parent 83229ce commit 9e62654

5 files changed, +35 −48 lines changed

include/triton/Conversion/TritonGPUToLLVM/TypeConverter.h

Lines changed: 5 additions & 1 deletion
@@ -14,7 +14,11 @@ class TritonGPUToLLVMTypeConverter : public LLVMTypeConverter {
 public:
   using TypeConverter::convertType;
 
-  TritonGPUToLLVMTypeConverter(MLIRContext *ctx, LowerToLLVMOptions &option,
+  TritonGPUToLLVMTypeConverter(MLIRContext *ctx,
+                               const LowerToLLVMOptions &option,
+                               const TargetInfoBase &targetInfo,
+                               const DataLayoutAnalysis *analysis = nullptr);
+  TritonGPUToLLVMTypeConverter(MLIRContext *ctx,
                                const TargetInfoBase &targetInfo,
                                const DataLayoutAnalysis *analysis = nullptr);
 
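For illustration, a minimal usage sketch of the two overloads after this change, based on the call sites in TritonGPUToLLVM.cpp further below; ctx and targetInfo stand in for an existing MLIRContext pointer and TargetInfoBase reference:

    // New overload: the LowerToLLVMOptions are defaulted internally.
    TritonGPUToLLVMTypeConverter funcTypeConverter(ctx, targetInfo);

    // Existing overload for call sites that need non-default options,
    // e.g. a 32-bit index type:
    mlir::LowerToLLVMOptions option(ctx);
    option.overrideIndexBitwidth(32);
    TritonGPUToLLVMTypeConverter typeConverter(ctx, option, targetInfo);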

lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp

Lines changed: 7 additions & 1 deletion
@@ -12,7 +12,13 @@ using ::mlir::triton::gpu::getTotalElemsPerThread;
 using ::mlir::triton::gpu::MemDescType;
 
 TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
-    MLIRContext *ctx, LowerToLLVMOptions &options,
+    MLIRContext *ctx, const TargetInfoBase &targetInfo,
+    const DataLayoutAnalysis *analysis)
+    : TritonGPUToLLVMTypeConverter(ctx, LowerToLLVMOptions(ctx), targetInfo,
+                                   analysis) {}
+
+TritonGPUToLLVMTypeConverter::TritonGPUToLLVMTypeConverter(
+    MLIRContext *ctx, const LowerToLLVMOptions &options,
     const TargetInfoBase &targetInfo, const DataLayoutAnalysis *analysis)
     : LLVMTypeConverter(ctx, options, analysis) {
   addConversion([ctx](triton::PointerType type) -> std::optional<Type> {

third_party/nvidia/backend/compiler.py

Lines changed: 4 additions & 2 deletions
@@ -307,15 +307,13 @@ def make_llir(self, src, metadata, options, capability):
         nvidia.passes.ttnvgpuir.add_lower_mma(pm)
         passes.ttgpuir.add_combine_tensor_select_and_if(pm)
         passes.convert.add_scf_to_cf(pm)
-        passes.convert.add_index_to_llvmir(pm)
         passes.ttgpuir.add_allocate_shared_memory(pm)
         nvidia.passes.ttnvgpuir.add_allocate_tensor_memory(pm)
         passes.ttgpuir.add_allocate_global_scratch_memory(pm)
         nvidia.passes.ttgpuir.add_to_llvmir(pm, capability, ptx_version)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         nvidia.passes.ttnvgpuir.add_nvgpu_to_llvm(pm)
-        passes.convert.add_arith_to_llvmir(pm)
         passes.common.add_canonicalizer(pm)
         passes.common.add_cse(pm)
         passes.common.add_symbol_dce(pm)
@@ -348,6 +346,10 @@ def make_llir(self, src, metadata, options, capability):
         llvm.optimize_module(llvm_mod, llvm.OPTIMIZE_O3)
 
         # Get some metadata
+        # warp-specialization mutates num_warps
+        num_warp_groups = src.get_int_attr("ttg.num-warp-groups-per-cta")
+        if num_warp_groups is not None:
+            metadata["num_warps"] *= num_warp_groups
         metadata["shared"] = src.get_int_attr("ttg.shared")
         metadata["tmem_size"] = src.get_int_attr("ttg.tensor_memory_size")
         metadata["global_scratch_size"] = src.get_int_attr("ttg.global_scratch_memory_size")

third_party/nvidia/include/TritonNVIDIAGPUToLLVM/Passes.td

Lines changed: 1 addition & 0 deletions
@@ -18,6 +18,7 @@ def ConvertTritonGPUToLLVM : Pass<"convert-triton-gpu-to-llvm", "mlir::ModuleOp"
                              "mlir::triton::TritonDialect",
                              "mlir::triton::gpu::TritonGPUDialect",
                              "mlir::triton::nvidia_gpu::TritonNvidiaGPUDialect",
+                             "mlir::triton::nvgpu::NVGPUDialect",
                              "mlir::NVVM::NVVMDialect"];
 
   let options = [
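Declaring mlir::triton::nvgpu::NVGPUDialect here in TableGen replaces the hand-written getDependentDialects override that this commit removes from TritonGPUToLLVM.cpp below.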

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/TritonGPUToLLVM.cpp

Lines changed: 18 additions & 44 deletions
@@ -5,7 +5,6 @@
 #include "mlir/Conversion/ControlFlowToLLVM/ControlFlowToLLVM.h"
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
-#include "mlir/Conversion/SCFToControlFlow/SCFToControlFlow.h"
 #include "mlir/Conversion/UBToLLVM/UBToLLVM.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
@@ -36,12 +35,6 @@ using namespace mlir::triton::NVIDIA;
 
 namespace {
 
-// pass ws related named attrs.
-static void addAttrs(Operation *op, ArrayRef<mlir::NamedAttribute> attrs) {
-  for (const NamedAttribute attr : attrs)
-    op->setAttr(attr.getName(), attr.getValue());
-}
-
 class TritonLLVMFunctionConversionTarget : public ConversionTarget {
 public:
   explicit TritonLLVMFunctionConversionTarget(MLIRContext &ctx)
@@ -71,55 +64,41 @@ struct ConvertTritonGPUToLLVM
     : public triton::impl::ConvertTritonGPUToLLVMBase<ConvertTritonGPUToLLVM> {
   using ConvertTritonGPUToLLVMBase::ConvertTritonGPUToLLVMBase;
 
-  void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<triton::nvgpu::NVGPUDialect, LLVM::LLVMDialect,
-                    NVVM::NVVMDialect>();
-  }
-
   ConvertTritonGPUToLLVM(int32_t computeCapability)
       : ConvertTritonGPUToLLVMBase({computeCapability}) {}
-
   ConvertTritonGPUToLLVM(int32_t computeCapability, int32_t ptxVersion)
       : ConvertTritonGPUToLLVMBase({computeCapability, ptxVersion}) {}
 
   void runOnOperation() override {
     MLIRContext *context = &getContext();
     ModuleOp mod = getOperation();
-
-    mlir::LowerToLLVMOptions option(context);
-    option.overrideIndexBitwidth(32);
     TargetInfo targetInfo(computeCapability, ptxVersion);
-    TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo);
-    TritonLLVMConversionTarget convTarget(*context);
-    int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
-    int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
 
     // Allocate shared memory and set barrier
     ModuleAllocation allocation(mod);
     ModuleMembarAnalysis membarPass(&allocation);
     membarPass.run();
 
     // Lower functions
-    {
-      mlir::LowerToLLVMOptions option(context);
-      TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo);
-      TritonLLVMFunctionConversionTarget funcTarget(*context);
-      RewritePatternSet funcPatterns(context);
-      mlir::triton::populateFuncOpConversionPattern(
-          typeConverter, funcPatterns, targetInfo, patternBenefitDefault);
-      mlir::cf::populateControlFlowToLLVMConversionPatterns(typeConverter,
-                                                            funcPatterns);
-      if (failed(
-              applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
-        return signalPassFailure();
-    }
+    TritonLLVMFunctionConversionTarget funcTarget(*context);
+    RewritePatternSet funcPatterns(context);
+    TritonGPUToLLVMTypeConverter funcTypeConverter(context, targetInfo);
+    mlir::triton::populateFuncOpConversionPattern(
+        funcTypeConverter, funcPatterns, targetInfo, patternBenefitDefault);
+    mlir::cf::populateControlFlowToLLVMConversionPatterns(funcTypeConverter,
+                                                          funcPatterns);
+    if (failed(
+            applyPartialConversion(mod, funcTarget, std::move(funcPatterns))))
+      return signalPassFailure();
 
     // initSharedMemory is run before the conversion of call and ret ops,
     // because the call op has to know the shared memory base address of each
     // function
+    mlir::LowerToLLVMOptions option(context);
+    option.overrideIndexBitwidth(32);
+    TritonGPUToLLVMTypeConverter typeConverter(context, option, targetInfo);
     initSharedMemory(typeConverter);
     ModuleAxisInfoAnalysis axisInfoAnalysis(mod);
-    OpBuilder::InsertPoint indexInsertPoint;
 
     RewritePatternSet patterns(context);
     int benefit = patternBenefitPrioritizeOverLLVMConversions;
@@ -178,13 +157,16 @@ struct ConvertTritonGPUToLLVM
         typeConverter, patterns, benefit);
     mlir::triton::populateMakeRangeOpToLLVMPattern(typeConverter, targetInfo,
                                                    patterns, benefit);
-    populateTCGen5MMAOpToLLVMPattern(typeConverter, patterns, benefit);
+    mlir::triton::NVIDIA::populateTCGen5MMAOpToLLVMPattern(typeConverter,
+                                                           patterns, benefit);
    mlir::triton::NVIDIA::populateFp4ToFpToLLVMPatterns(typeConverter, patterns,
                                                         benefit);
+    TritonLLVMConversionTarget convTarget(*context);
     if (failed(applyPartialConversion(mod, convTarget, std::move(patterns))))
       return signalPassFailure();
 
     // Fold CTAId when there is only 1 CTA.
+    int numCTAs = triton::gpu::TritonGPUDialect::getNumCTAs(mod);
     if (numCTAs == 1) {
       mod.walk([](triton::nvgpu::ClusterCTAIdOp id) {
         OpBuilder b(id);
@@ -198,7 +180,6 @@ struct ConvertTritonGPUToLLVM
   void initSharedMemory(LLVMTypeConverter &typeConverter) {
     ModuleOp mod = getOperation();
     OpBuilder b(mod.getBodyRegion());
-    auto ctx = mod.getContext();
     auto loc = mod.getLoc();
     auto elemTy = typeConverter.convertType(b.getIntegerType(8));
     // Set array size 0 and external linkage indicates that we use dynamic
@@ -207,19 +188,12 @@ struct ConvertTritonGPUToLLVM
     // Ask for 16B alignment on global_smem because that's the largest we should
     // ever need (4xi32).
     auto arrayTy = LLVM::LLVMArrayType::get(elemTy, 0);
-    auto global = b.create<LLVM::GlobalOp>(
+    b.create<LLVM::GlobalOp>(
        loc, arrayTy, /*isConstant=*/false, LLVM::Linkage::External,
        "global_smem", /*value=*/Attribute(), /*alignment=*/16,
        // Add ROCm support.
        static_cast<unsigned>(NVVM::NVVMMemorySpace::kSharedMemorySpace));
   }
-
-  static Value promoteOperand(OpBuilder &builder, Location loc, Value operand,
-                              Type promotedType) {
-    Type tensorPromotedType = cast<RankedTensorType>(operand.getType())
-                                  .cloneWith(std::nullopt, promotedType);
-    return builder.create<triton::FpToFpOp>(loc, tensorPromotedType, operand);
-  }
 };
 
 } // anonymous namespace
