Skip to content

Commit f3067cd

Browse files
authored
[AMD] NFC: simplify pass/pattern constructor declaration (#7665)
1 parent 0cd5b90 commit f3067cd

15 files changed

+55
-111
lines changed

lib/Dialect/TritonGPU/Transforms/CoalesceAsyncCopy.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,10 @@ struct ClipAsyncCopySizePerThread
105105
}
106106
};
107107

108-
class CoalesceAsyncCopyPass
109-
: public impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> {
110-
public:
108+
struct CoalesceAsyncCopyPass
109+
: impl::TritonGPUCoalesceAsyncCopyBase<CoalesceAsyncCopyPass> {
110+
using Base::Base;
111+
111112
void runOnOperation() override {
112113
ModuleOp m = getOperation();
113114
MLIRContext *context = &getContext();

third_party/amd/lib/TritonAMDGPUDialectToLLVM/ExtractSliceOpToLLVM.cpp

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,7 @@ namespace {
2121

2222
struct ExtractSliceOpConversion
2323
: public ConvertOpToLLVMPattern<amdgpu::ExtractSliceOp> {
24-
explicit ExtractSliceOpConversion(LLVMTypeConverter &typeConverter,
25-
PatternBenefit benefit = 1)
26-
: ConvertOpToLLVMPattern<amdgpu::ExtractSliceOp>(typeConverter, benefit) {
27-
}
24+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
2825

2926
LogicalResult processLayout(amdgpu::ExtractSliceOp op, OpAdaptor adaptor,
3027
ConversionPatternRewriter &rewriter) const {

third_party/amd/lib/TritonAMDGPUDialectToLLVM/InThreadTransposeOpToTTG.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@ namespace {
1010
struct InThreadTransposeOpConversion
1111
: public OpConversionPattern<triton::amdgpu::InThreadTransposeOp> {
1212
public:
13-
explicit InThreadTransposeOpConversion(MLIRContext *ctx,
14-
PatternBenefit benefit)
15-
: OpConversionPattern(ctx, benefit) {}
13+
using OpConversionPattern::OpConversionPattern;
1614

1715
LogicalResult
1816
matchAndRewrite(triton::amdgpu::InThreadTransposeOp op, OpAdaptor adaptor,

third_party/amd/lib/TritonAMDGPUToLLVM/BuiltinFuncToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace {
1919
class CallOpConversion : public OpRewritePattern<LLVM::CallOp> {
2020
public:
2121
CallOpConversion(mlir::MLIRContext *context, bool ftz)
22-
: OpRewritePattern<LLVM::CallOp>(context, 1), ftz(ftz) {}
22+
: OpRewritePattern(context, 1), ftz(ftz) {}
2323

2424
LogicalResult
2525
matchAndRewrite(LLVM::CallOp callOp,

third_party/amd/lib/TritonAMDGPUToLLVM/ConvertLayoutOpToLLVM.cpp

Lines changed: 23 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
#include "Analysis/AMDGPUAllocation.h"
22
#include "PatternTritonGPUOpToLLVM.h"
3-
#include "Utility.h"
43
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
54
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
6-
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
75

8-
using ::mlir::transferWithinBlockPadding;
9-
using ::mlir::transferWithinBlockSwizzling;
106
using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
117
using ::mlir::triton::gpu::ConvertLayoutOp;
128
using ::triton::gpu::LinearEncodingAttr;
@@ -27,15 +23,14 @@ static bool matchMFMAAndLinearLayoutCase(RankedTensorType srcTy,
2723
storeLL.value_or(LinearLayout::empty());
2824
};
2925

30-
struct ConvertLayoutOpMFMAToLinearConversion
26+
class ConvertLayoutOpMFMAToLinearConversion
3127
: public ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp> {
3228
public:
33-
explicit ConvertLayoutOpMFMAToLinearConversion(
34-
LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
35-
PatternBenefit benefit)
36-
: ConvertOpToLLVMPattern<triton::gpu::ConvertLayoutOp>(typeConverter,
37-
benefit),
38-
targetInfo(targetInfo) {}
29+
ConvertLayoutOpMFMAToLinearConversion(LLVMTypeConverter &typeConverter,
30+
const TargetInfoBase &targetInfo,
31+
PatternBenefit benefit)
32+
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
33+
}
3934

4035
LogicalResult
4136
matchAndRewrite(triton::gpu::ConvertLayoutOp op, OpAdaptor adaptor,
@@ -116,18 +111,18 @@ struct ConvertLayoutOpMFMAToLinearConversion
116111
return success();
117112
}
118113

119-
protected:
114+
private:
120115
const TargetInfoBase &targetInfo;
121116
};
122117

123-
struct ConvertLayoutForcedPadding
118+
class ConvertLayoutForcedPadding
124119
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
125-
126-
explicit ConvertLayoutForcedPadding(LLVMTypeConverter &typeConverter,
127-
const TargetInfoBase &targetInfo,
128-
PatternBenefit benefit)
129-
: ConvertOpToLLVMPattern<ConvertLayoutOp>(typeConverter, benefit),
130-
targetInfo(targetInfo) {}
120+
public:
121+
ConvertLayoutForcedPadding(LLVMTypeConverter &typeConverter,
122+
const TargetInfoBase &targetInfo,
123+
PatternBenefit benefit)
124+
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
125+
}
131126

132127
LogicalResult
133128
matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
@@ -145,18 +140,18 @@ struct ConvertLayoutForcedPadding
145140
return success();
146141
}
147142

148-
protected:
143+
private:
149144
const TargetInfoBase &targetInfo;
150145
};
151146

152-
struct ConvertLayoutForcedSwizzling
147+
class ConvertLayoutForcedSwizzling
153148
: public ConvertOpToLLVMPattern<ConvertLayoutOp> {
154-
155-
explicit ConvertLayoutForcedSwizzling(LLVMTypeConverter &typeConverter,
156-
const TargetInfoBase &targetInfo,
157-
PatternBenefit benefit)
158-
: ConvertOpToLLVMPattern<ConvertLayoutOp>(typeConverter, benefit),
159-
targetInfo(targetInfo) {}
149+
public:
150+
ConvertLayoutForcedSwizzling(LLVMTypeConverter &typeConverter,
151+
const TargetInfoBase &targetInfo,
152+
PatternBenefit benefit)
153+
: ConvertOpToLLVMPattern(typeConverter, benefit), targetInfo(targetInfo) {
154+
}
160155

161156
LogicalResult
162157
matchAndRewrite(ConvertLayoutOp op, OpAdaptor adaptor,
@@ -176,7 +171,7 @@ struct ConvertLayoutForcedSwizzling
176171
return success();
177172
}
178173

179-
protected:
174+
private:
180175
const TargetInfoBase &targetInfo;
181176
};
182177

third_party/amd/lib/TritonAMDGPUToLLVM/DotOpToLLVM.cpp

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -56,14 +56,6 @@ struct DotOpConversion : public ConvertOpToLLVMPattern<triton::DotOp> {
5656
struct ScaledDotOpConversion
5757
: public ConvertOpToLLVMPattern<triton::DotScaledOp> {
5858
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
59-
int mfmaVersion;
60-
int nonKDim;
61-
int kPack;
62-
63-
ScaledDotOpConversion(LLVMTypeConverter &typeConverter, int mfmaVersion,
64-
int nonKDim, int kPack, PatternBenefit benefit = 1)
65-
: ConvertOpToLLVMPattern(typeConverter, benefit),
66-
mfmaVersion(mfmaVersion), nonKDim(nonKDim), kPack(kPack) {}
6759

6860
LogicalResult
6961
matchAndRewrite(triton::DotScaledOp op, OpAdaptor adaptor,

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 9 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -465,13 +465,11 @@ struct DirectToLdsLoadConversionBase : public LoadStoreConversionBase {
465465

466466
struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
467467
public LoadStoreConversionBase {
468-
using ConvertOpToLLVMPattern<triton::LoadOp>::ConvertOpToLLVMPattern;
469-
470468
LoadOpConversion(LLVMTypeConverter &converter,
471469
const AMD::TargetInfo &targetInfo,
472470
ModuleAxisInfoAnalysis &axisAnalysisPass,
473471
PatternBenefit benefit)
474-
: ConvertOpToLLVMPattern<triton::LoadOp>(converter, benefit),
472+
: ConvertOpToLLVMPattern(converter, benefit),
475473
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
476474

477475
LogicalResult
@@ -562,15 +560,11 @@ struct LoadOpConversion : public ConvertOpToLLVMPattern<triton::LoadOp>,
562560
struct BufferLoadOpConversion
563561
: public ConvertOpToLLVMPattern<triton::amdgpu::BufferLoadOp>,
564562
public LoadStoreConversionBase {
565-
using ConvertOpToLLVMPattern<
566-
triton::amdgpu::BufferLoadOp>::ConvertOpToLLVMPattern;
567-
568563
BufferLoadOpConversion(LLVMTypeConverter &converter,
569564
const AMD::TargetInfo &targetInfo,
570565
ModuleAxisInfoAnalysis &axisAnalysisPass,
571566
PatternBenefit benefit)
572-
: ConvertOpToLLVMPattern<triton::amdgpu::BufferLoadOp>(converter,
573-
benefit),
567+
: ConvertOpToLLVMPattern(converter, benefit),
574568
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
575569

576570
LogicalResult
@@ -952,13 +946,11 @@ struct AsyncCopyGlobalToLocalOpConversion
952946

953947
struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
954948
public LoadStoreConversionBase {
955-
using ConvertOpToLLVMPattern<triton::StoreOp>::ConvertOpToLLVMPattern;
956-
957949
StoreOpConversion(LLVMTypeConverter &converter,
958950
const AMD::TargetInfo &targetInfo,
959951
ModuleAxisInfoAnalysis &axisAnalysisPass,
960952
PatternBenefit benefit)
961-
: ConvertOpToLLVMPattern<triton::StoreOp>(converter, benefit),
953+
: ConvertOpToLLVMPattern(converter, benefit),
962954
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
963955

964956
LogicalResult
@@ -1038,15 +1030,11 @@ struct StoreOpConversion : public ConvertOpToLLVMPattern<triton::StoreOp>,
10381030
struct BufferAtomicRMWOpConversion
10391031
: public ConvertOpToLLVMPattern<triton::amdgpu::BufferAtomicRMWOp>,
10401032
public LoadStoreConversionBase {
1041-
using ConvertOpToLLVMPattern<
1042-
triton::amdgpu::BufferAtomicRMWOp>::ConvertOpToLLVMPattern;
1043-
10441033
BufferAtomicRMWOpConversion(LLVMTypeConverter &converter,
10451034
const AMD::TargetInfo &targetInfo,
10461035
ModuleAxisInfoAnalysis &axisAnalysisPass,
10471036
PatternBenefit benefit)
1048-
: ConvertOpToLLVMPattern<triton::amdgpu::BufferAtomicRMWOp>(converter,
1049-
benefit),
1037+
: ConvertOpToLLVMPattern(converter, benefit),
10501038
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
10511039

10521040
LogicalResult
@@ -1176,15 +1164,11 @@ struct BufferAtomicRMWOpConversion
11761164
struct BufferAtomicCASOpConversion
11771165
: public ConvertOpToLLVMPattern<triton::amdgpu::BufferAtomicCASOp>,
11781166
public LoadStoreConversionBase {
1179-
using ConvertOpToLLVMPattern<
1180-
triton::amdgpu::BufferAtomicCASOp>::ConvertOpToLLVMPattern;
1181-
11821167
BufferAtomicCASOpConversion(LLVMTypeConverter &converter,
11831168
const AMD::TargetInfo &targetInfo,
11841169
ModuleAxisInfoAnalysis &axisAnalysisPass,
11851170
PatternBenefit benefit)
1186-
: ConvertOpToLLVMPattern<triton::amdgpu::BufferAtomicCASOp>(converter,
1187-
benefit),
1171+
: ConvertOpToLLVMPattern(converter, benefit),
11881172
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
11891173

11901174
LogicalResult
@@ -1290,15 +1274,11 @@ struct BufferAtomicCASOpConversion
12901274
struct BufferStoreOpConversion
12911275
: public ConvertOpToLLVMPattern<triton::amdgpu::BufferStoreOp>,
12921276
public LoadStoreConversionBase {
1293-
using ConvertOpToLLVMPattern<
1294-
triton::amdgpu::BufferStoreOp>::ConvertOpToLLVMPattern;
1295-
12961277
BufferStoreOpConversion(LLVMTypeConverter &converter,
12971278
const AMD::TargetInfo &targetInfo,
12981279
ModuleAxisInfoAnalysis &axisAnalysisPass,
12991280
PatternBenefit benefit)
1300-
: ConvertOpToLLVMPattern<triton::amdgpu::BufferStoreOp>(converter,
1301-
benefit),
1281+
: ConvertOpToLLVMPattern(converter, benefit),
13021282
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
13031283

13041284
LogicalResult
@@ -1371,13 +1351,11 @@ struct BufferStoreOpConversion
13711351
struct AtomicCASOpConversion
13721352
: public ConvertOpToLLVMPattern<triton::AtomicCASOp>,
13731353
public LoadStoreConversionBase {
1374-
using ConvertOpToLLVMPattern<triton::AtomicCASOp>::ConvertOpToLLVMPattern;
1375-
13761354
AtomicCASOpConversion(LLVMTypeConverter &converter,
13771355
const AMD::TargetInfo &targetInfo,
13781356
ModuleAxisInfoAnalysis &axisAnalysisPass,
13791357
PatternBenefit benefit)
1380-
: ConvertOpToLLVMPattern<triton::AtomicCASOp>(converter, benefit),
1358+
: ConvertOpToLLVMPattern(converter, benefit),
13811359
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
13821360

13831361
LogicalResult
@@ -1533,13 +1511,11 @@ bool supportsGlobalAtomicF16PackedAndDpp(ISAFamily isaFamily) {
15331511
struct AtomicRMWOpConversion
15341512
: public ConvertOpToLLVMPattern<triton::AtomicRMWOp>,
15351513
public LoadStoreConversionBase {
1536-
using ConvertOpToLLVMPattern<triton::AtomicRMWOp>::ConvertOpToLLVMPattern;
1537-
15381514
AtomicRMWOpConversion(LLVMTypeConverter &converter,
15391515
const AMD::TargetInfo &targetInfo,
15401516
ModuleAxisInfoAnalysis &axisAnalysisPass,
15411517
PatternBenefit benefit)
1542-
: ConvertOpToLLVMPattern<triton::AtomicRMWOp>(converter, benefit),
1518+
: ConvertOpToLLVMPattern(converter, benefit),
15431519
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
15441520

15451521
LogicalResult
@@ -1772,7 +1748,7 @@ struct AsyncWaitOpConversion : public ConvertOpToLLVMPattern<AsyncWaitOp> {
17721748

17731749
struct AsyncCommitGroupOpConversion
17741750
: public ConvertOpToLLVMPattern<AsyncCommitGroupOp> {
1775-
using ConvertOpToLLVMPattern<AsyncCommitGroupOp>::ConvertOpToLLVMPattern;
1751+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
17761752

17771753
LogicalResult
17781754
matchAndRewrite(AsyncCommitGroupOp op, OpAdaptor adaptor,

third_party/amd/lib/TritonAMDGPUToLLVM/MemoryOpToLLVM.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
#include "AsyncUtility.h"
22
#include "Dialect/TritonAMDGPU/IR/Dialect.h"
33
#include "PatternTritonGPUOpToLLVM.h"
4-
#include "Utility.h"
4+
#include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
55
#include "triton/Conversion/TritonGPUToLLVM/PatternTritonGPUOpToLLVM.h"
66
#include "triton/Conversion/TritonGPUToLLVM/Utility.h"
77
#include "triton/Dialect/TritonGPU/Transforms/Utility.h"
88

9-
using ::mlir::LLVM::AMD::isUsedByDotScaledOp;
109
using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
1110
using ::mlir::triton::gpu::DotOperandEncodingAttr;
1211
using ::mlir::triton::gpu::MemDescType;
1312

1413
namespace {
1514
template <typename LocalLoadOpType>
16-
struct TransLocalLoadOpConversion
15+
class TransLocalLoadOpConversion
1716
: public ConvertOpToLLVMPattern<LocalLoadOpType> {
1817
public:
1918
TransLocalLoadOpConversion(const LLVMTypeConverter &converter,

third_party/amd/lib/TritonAMDGPUToLLVM/SPMDOpToLLVM.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,7 @@ namespace {
99

1010
struct GetNumProgramsOpConversion
1111
: public ConvertOpToLLVMPattern<triton::GetNumProgramsOp> {
12-
using ConvertOpToLLVMPattern<
13-
triton::GetNumProgramsOp>::ConvertOpToLLVMPattern;
12+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
1413

1514
LogicalResult
1615
matchAndRewrite(triton::GetNumProgramsOp op, OpAdaptor adaptor,

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1293,12 +1293,9 @@ class AccelerateBlocked : public OpRewritePattern<DotOp> {
12931293
#define GEN_PASS_DEF_TRITONAMDGPUACCELERATEMATMUL
12941294
#include "TritonAMDGPUTransforms/Passes.h.inc"
12951295

1296-
class TritonAMDGPUAccelerateMatmulPass
1297-
: public impl::TritonAMDGPUAccelerateMatmulBase<
1298-
TritonAMDGPUAccelerateMatmulPass> {
1299-
public:
1300-
using impl::TritonAMDGPUAccelerateMatmulBase<
1301-
TritonAMDGPUAccelerateMatmulPass>::TritonAMDGPUAccelerateMatmulBase;
1296+
struct TritonAMDGPUAccelerateMatmulPass
1297+
: impl::TritonAMDGPUAccelerateMatmulBase<TritonAMDGPUAccelerateMatmulPass> {
1298+
using Base::Base;
13021299

13031300
void runOnOperation() override {
13041301

0 commit comments

Comments
 (0)