Skip to content

Commit 824170b

Browse files
[Intel] Cleanup TritonGPUOpToLLVM patterns (#3691)
- Remove unused pattern declarations. - Reduce `TritonIntelGPUToLLVMTypeConverter` usage. - NFC code reordering. Signed-off-by: Whitney Tsang <[email protected]>
1 parent b6bb784 commit 824170b

File tree

5 files changed

+54
-69
lines changed

5 files changed

+54
-69
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ struct DotOpConversion : public ConvertTritonGPUOpToLLVMPattern<triton::DotOp> {
5353
} // namespace
5454

5555
void mlir::triton::intel::populateDotOpToLLVMPatterns(
56-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
57-
RewritePatternSet &patterns, PatternBenefit benefit) {
56+
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
57+
PatternBenefit benefit) {
5858
patterns.add<DotOpConversion>(typeConverter, benefit);
5959
}

third_party/intel/lib/TritonIntelGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -295,8 +295,7 @@ struct PrefetchOpConversion
295295
triton::gpu::intel::PrefetchOp>::ConvertTritonGPUOpToLLVMPattern;
296296

297297
PrefetchOpConversion(
298-
TritonGPUToLLVMTypeConverter &converter,
299-
const triton::intel::TargetInfo &targetInfo,
298+
LLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo,
300299
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
301300
PatternBenefit benefit)
302301
: ConvertTritonGPUOpToLLVMPattern<triton::gpu::intel::PrefetchOp>(
@@ -473,8 +472,7 @@ struct LoadOpToBlockIOConversion
473472
using ValueTable = std::map<std::pair<int, int>, Value>;
474473

475474
LoadOpToBlockIOConversion(
476-
TritonIntelGPUToLLVMTypeConverter &converter,
477-
const triton::intel::TargetInfo &targetInfo,
475+
LLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo,
478476
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
479477
PatternBenefit benefit)
480478
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
@@ -567,7 +565,7 @@ struct LoadOpToBlockIOConversion
567565
dpasInstShape[1]};
568566
unsigned elemsPerLanePerDPASInst =
569567
product<unsigned>(elemsPerDPASInst) / threadsPerWarp;
570-
TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
568+
LLVMTypeConverter *typeConverter = getTypeConverter();
571569
Type unpackedDPASOperandType = LLVM::getFixedVectorType(
572570
typeConverter->convertType(eltTy), elemsPerLanePerDPASInst);
573571

@@ -964,8 +962,7 @@ struct LoadOpConversion
964962
using ValueTable = std::map<std::pair<int, int>, Value>;
965963

966964
LoadOpConversion(
967-
TritonIntelGPUToLLVMTypeConverter &converter,
968-
const triton::intel::TargetInfo &targetInfo,
965+
LLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo,
969966
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
970967
PatternBenefit benefit, bool oneMatrixPerLoadForBT)
971968
: ConvertTritonGPUOpToLLVMPattern<triton::LoadOp>(converter, benefit),
@@ -1159,7 +1156,7 @@ struct LoadOpConversion
11591156
}
11601157
}
11611158

1162-
TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
1159+
LLVMTypeConverter *typeConverter = getTypeConverter();
11631160
Type llvmResultStructTy = typeConverter->convertType(op.getType());
11641161
Value resultStruct = packLLElements(
11651162
loc, typeConverter, unpackedLoadedVals, rewriter, llvmResultStructTy);
@@ -1176,7 +1173,7 @@ struct LoadOpConversion
11761173
dpasInstShape[1]};
11771174
unsigned elemsPerLanePerDPASInst =
11781175
product<unsigned>(elemsPerDPASInst) / threadsPerWarp;
1179-
TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
1176+
LLVMTypeConverter *typeConverter = getTypeConverter();
11801177
Type unpackedDPASOperandType = LLVM::getFixedVectorType(
11811178
typeConverter->convertType(eltTy), elemsPerLanePerDPASInst);
11821179

@@ -1640,8 +1637,7 @@ struct StoreOpConversion
16401637
triton::StoreOp>::ConvertTritonGPUOpToLLVMPattern;
16411638

16421639
StoreOpConversion(
1643-
TritonIntelGPUToLLVMTypeConverter &converter,
1644-
const triton::intel::TargetInfo &targetInfo,
1640+
LLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo,
16451641
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
16461642
PatternBenefit benefit)
16471643
: ConvertTritonGPUOpToLLVMPattern<triton::StoreOp>(converter, benefit),
@@ -1660,7 +1656,7 @@ struct StoreOpConversion
16601656
return failure();
16611657

16621658
auto dpasLayout = cast<DpasEncodingAttr>(tensorType.getEncoding());
1663-
TritonGPUToLLVMTypeConverter *typeConverter = getTypeConverter();
1659+
LLVMTypeConverter *typeConverter = getTypeConverter();
16641660
MLIRContext *ctx = rewriter.getContext();
16651661

16661662
Type eltTy = tensorType.getElementType();
@@ -1919,8 +1915,7 @@ struct AtomicCASOpConversion
19191915
triton::AtomicCASOp>::ConvertTritonGPUOpToLLVMPattern;
19201916

19211917
AtomicCASOpConversion(
1922-
TritonIntelGPUToLLVMTypeConverter &converter,
1923-
const triton::intel::TargetInfo &targetInfo,
1918+
LLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo,
19241919
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
19251920
PatternBenefit benefit)
19261921
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicCASOp>(converter,
@@ -2044,8 +2039,7 @@ struct AtomicRMWOpConversion
20442039
triton::AtomicRMWOp>::ConvertTritonGPUOpToLLVMPattern;
20452040

20462041
AtomicRMWOpConversion(
2047-
TritonIntelGPUToLLVMTypeConverter &converter,
2048-
const triton::intel::TargetInfo &targetInfo,
2042+
LLVMTypeConverter &converter, const triton::intel::TargetInfo &targetInfo,
20492043
const triton::intel::ModuleAxisInfoAnalysis &axisAnalysisPass,
20502044
PatternBenefit benefit)
20512045
: ConvertTritonGPUOpToLLVMPattern<triton::AtomicRMWOp>(converter,
@@ -2312,8 +2306,8 @@ struct AtomicRMWOpConversion
23122306
} // namespace
23132307

23142308
void mlir::triton::intel::populateLoadStoreOpToLLVMPatterns(
2315-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
2316-
const TargetInfo &targetInfo, RewritePatternSet &patterns,
2309+
LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
2310+
RewritePatternSet &patterns,
23172311
const intel::ModuleAxisInfoAnalysis &axisInfoAnalysis,
23182312
PatternBenefit benefit, bool oneMatrixPerLoadForBT) {
23192313
patterns.add<AtomicCASOpConversion, AtomicRMWOpConversion, StoreOpConversion,

third_party/intel/lib/TritonIntelGPUToLLVM/PatternTritonGPUOpToLLVM.h

Lines changed: 36 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -11,84 +11,76 @@ namespace mlir::triton::intel {
1111

1212
constexpr int patternBenefitAddSPIRVEnv = 30;
1313

14-
// Custom Arith Dialect patterns.
14+
/* Advanced path custom patterns start */
15+
1516
void populateArithOpsToLLVMPatterns(
1617
TritonIntelGPUToLLVMTypeConverter &typeConverter,
1718
RewritePatternSet &patterns, PatternBenefit benefit);
1819

19-
void populateTritonOpsToLLVMPatterns(
20-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
21-
RewritePatternSet &patterns, PatternBenefit benefit);
20+
void populateBF16CastsLLVMPatterns(LLVMTypeConverter &typeConverter,
21+
RewritePatternSet &patterns,
22+
PatternBenefit benefit);
2223

23-
void populateBarrierOpToLLVMPatterns(
24+
void populateTritonOpsToLLVMPatterns(
2425
TritonIntelGPUToLLVMTypeConverter &typeConverter,
2526
RewritePatternSet &patterns, PatternBenefit benefit);
2627

27-
void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
28-
const TargetInfo &targetInfo,
29-
RewritePatternSet &patterns,
30-
PatternBenefit benefit);
28+
/* Advanced path custom patterns end */
3129

32-
void populateDotOpToLLVMPatterns(
33-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
34-
RewritePatternSet &patterns, PatternBenefit benefit);
30+
/* Specialized common patterns start */
3531

3632
void populateElementwiseOpToLLVMPatterns(
3733
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
3834
ModuleAxisInfoAnalysis &axisInfoAnalysis, const TargetInfoBase &targetInfo,
3935
PatternBenefit benefit);
4036

41-
void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter,
42-
RewritePatternSet &patterns,
43-
PatternBenefit benefit);
44-
45-
void populateBF16CastsLLVMPatterns(LLVMTypeConverter &typeConverter,
46-
RewritePatternSet &patterns,
47-
PatternBenefit benefit);
48-
4937
void populateHistogramOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
5038
RewritePatternSet &patterns,
5139
const TargetInfoBase &targetInfo,
5240
PatternBenefit benefit);
5341

54-
void populateLoadStoreOpToLLVMPatterns(
55-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
56-
const TargetInfo &targetInfo, RewritePatternSet &patterns,
57-
const ModuleAxisInfoAnalysis &axisInfoAnalysis, PatternBenefit benefit,
58-
bool oneMatrixPerLoadForBT);
59-
6042
void populateReduceOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
6143
RewritePatternSet &patterns,
6244
const TargetInfoBase &targetInfo,
6345
PatternBenefit benefit);
6446

65-
void populateTensorPtrOpsToLLVMPatterns(
66-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
67-
RewritePatternSet &patterns, PatternBenefit benefit);
47+
void populateConvertLayoutOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
48+
const TargetInfo &targetInfo,
49+
RewritePatternSet &patterns,
50+
PatternBenefit benefit);
6851

69-
void populateTritonGPUToLLVMPatterns(
70-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
71-
RewritePatternSet &patterns, PatternBenefit benefit);
52+
void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
53+
RewritePatternSet &patterns,
54+
const TargetInfoBase &targetInfo,
55+
PatternBenefit benefit);
7256

73-
void populatePrintOpToLLVMPattern(
74-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
75-
RewritePatternSet &patterns, const TargetInfoBase &targetInfo,
76-
PatternBenefit benefit);
57+
void populatePrintOpToLLVMPattern(LLVMTypeConverter &typeConverter,
58+
RewritePatternSet &patterns,
59+
const TargetInfoBase &targetInfo,
60+
PatternBenefit benefit);
7761

78-
void populateMemoryOpToLLVMPattern(LLVMTypeConverter &typeConverter,
79-
const TargetInfoBase &targetInfo,
62+
/* Specialized common patterns end */
63+
64+
/* Third party patterns start */
65+
66+
void populateDotOpToLLVMPatterns(LLVMTypeConverter &typeConverter,
67+
RewritePatternSet &patterns,
68+
PatternBenefit benefit);
69+
70+
void populateFp4ToFpToLLVMPatterns(LLVMTypeConverter &typeConverter,
8071
RewritePatternSet &patterns,
8172
PatternBenefit benefit);
8273

83-
void populateControlFlowOpToLLVMPattern(LLVMTypeConverter &typeConverter,
74+
void populateLoadStoreOpToLLVMPatterns(
75+
LLVMTypeConverter &typeConverter, const TargetInfo &targetInfo,
76+
RewritePatternSet &patterns, const ModuleAxisInfoAnalysis &axisInfoAnalysis,
77+
PatternBenefit benefit, bool oneMatrixPerLoadForBT);
78+
79+
void populateTensorPtrOpsToLLVMPatterns(LLVMTypeConverter &typeConverter,
8480
RewritePatternSet &patterns,
85-
const TargetInfoBase &targetInfo,
8681
PatternBenefit benefit);
8782

88-
void populateSPMDOpToLLVMPattern(LLVMTypeConverter &typeConverter,
89-
RewritePatternSet &patterns,
90-
const TargetInfoBase &targetInfo,
91-
PatternBenefit benefit);
83+
/* Third party patterns end */
9284

9385
} // namespace mlir::triton::intel
9486

third_party/intel/lib/TritonIntelGPUToLLVM/PrintOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -241,8 +241,7 @@ struct PrintOpConversion
241241
} // namespace
242242

243243
void mlir::triton::intel::populatePrintOpToLLVMPattern(
244-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
245-
RewritePatternSet &patterns, const TargetInfoBase &targetInfo,
246-
PatternBenefit benefit) {
244+
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
245+
const TargetInfoBase &targetInfo, PatternBenefit benefit) {
247246
patterns.add<PrintOpConversion>(typeConverter, targetInfo, benefit);
248247
}

third_party/intel/lib/TritonIntelGPUToLLVM/TensorPtrOpsToLLVM.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ struct AdvanceOpConversion
9797
} // namespace
9898

9999
void mlir::triton::intel::populateTensorPtrOpsToLLVMPatterns(
100-
TritonIntelGPUToLLVMTypeConverter &typeConverter,
101-
RewritePatternSet &patterns, PatternBenefit benefit) {
100+
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns,
101+
PatternBenefit benefit) {
102102
patterns.add<MakeTensorPtrOpConversion>(typeConverter, benefit);
103103
patterns.add<AdvanceOpConversion>(typeConverter, benefit);
104104
return;

0 commit comments

Comments
 (0)