
Commit 3a5bddf

Use isa<> with multiple templates, synchronized with triton-lang/triton#5684
1 parent f27188e commit 3a5bddf
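
For context: llvm::isa<> (from llvm/Support/Casting.h) accepts several type parameters and returns true if the value is an instance of any of them, which is what lets the OR-chains in the hunks below collapse into single calls. A minimal sketch of the pattern follows; the helper names and includes are illustrative and not part of this commit.

// Illustrative sketch only; the helpers below are not from the commit.
#include "mlir/IR/BuiltinTypes.h"  // assumes the MLIR builtin FP8 float types
#include "llvm/Support/Casting.h"

// Before: one isa<> call per candidate type, chained with ||.
static bool isFp8Chained(mlir::Type type) {
  return llvm::isa<mlir::Float8E5M2Type>(type) ||
         llvm::isa<mlir::Float8E4M3FNType>(type);
}

// After: a single isa<> call with multiple template arguments; it returns
// true when the type matches any of the listed classes.
static bool isFp8Variadic(mlir::Type type) {
  return llvm::isa<mlir::Float8E5M2Type, mlir::Float8E4M3FNType>(type);
}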

File tree

15 files changed: +62 -77 lines changed


include/triton/Conversion/MLIRTypes.h

Lines changed: 7 additions & 9 deletions
@@ -26,17 +26,15 @@ inline Type f32Ty(MLIRContext *ctx) { return Float32Type::get(ctx); }
 inline Type f64Ty(MLIRContext *ctx) { return Float64Type::get(ctx); }
 inline Type bf16Ty(MLIRContext *ctx) { return BFloat16Type::get(ctx); }
 
-inline bool isFloat(Type type) {
-  return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
-         type.isBF16() || isa<Float8E4M3B11FNUZType>(type) ||
-         isa<Float8E4M3FNType>(type) || isa<Float8E4M3FNUZType>(type) ||
-         isa<Float8E5M2Type>(type) || isa<Float8E5M2FNUZType>(type);
+inline bool isFloat8(Type type) {
+  return isa<Float8E4M3B11FNUZType, Float8E4M3FNType, Float8E4M3FNUZType,
+             Float8E5M2Type, Float8E5M2FNUZType>(type);
 }
 
-inline bool isFloat8(Type type) {
-  return isa<Float8E4M3B11FNUZType>(type) || isa<Float8E4M3FNType>(type) ||
-         isa<Float8E4M3FNUZType>(type) || isa<Float8E5M2Type>(type) ||
-         isa<Float8E5M2FNUZType>(type);
+inline bool isFloat(Type type) {
+  return type.isF32() || type.isF64() || type.isF16() || type.isF128() ||
+         type.isBF16() || llvm::isa<Float8E4M3B11FNUZType>(type) ||
+         isFloat8(type);
 }
 
 inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }

lib/Analysis/Utility.cpp

Lines changed: 4 additions & 5 deletions
@@ -738,14 +738,14 @@ bool supportMMA(triton::DotOp op, int version) {
       return false;
     if (!(numWarps % 4 == 0 && retShapePerCTA[rank - 2] % 64 == 0 &&
           retShapePerCTA[rank - 1] % 8 == 0 &&
-          (isa<Float8E5M2Type>(aElemTy) || isa<Float8E4M3FNType>(aElemTy) ||
+          (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy) ||
            aElemTy.isInteger(8) || aElemTy.isF16() || aElemTy.isBF16() ||
           aElemTy.isF32()))) {
      return false;
    }
    // We cannot use MMA_V3 if we need to accumulate in F32 within the MMA op.
    if (op.getMaxNumImpreciseAcc() < 32 &&
-        (isa<Float8E5M2Type>(aElemTy) || isa<Float8E4M3FNType>(aElemTy)) &&
+        (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(aElemTy)) &&
        cast<RankedTensorType>(op.getType()).getElementType().isF32()) {
      return false;
    }
@@ -766,9 +766,8 @@ bool supportMMA(Value value, int version) {
       cast<triton::gpu::TensorOrMemDesc>(value.getType()).getElementType();
   // FP8 is not natively supported on all mma versions but it can always be
   // promoted to fp16 therefore we can always support it.
-  bool isFP8 = isa<Float8E5M2Type>(elemTy) || isa<Float8E4M3FNType>(elemTy) ||
-               isa<Float8E5M2FNUZType>(elemTy) ||
-               isa<Float8E4M3FNUZType>(elemTy);
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+               Float8E4M3FNUZType>(elemTy);
   return isFP8 || elemTy.isF16() || elemTy.isBF16() ||
          (elemTy.isF32() && version >= 2) ||
          (elemTy.isInteger(8) && version >= 2);

lib/Dialect/TritonGPU/Transforms/AccelerateMatmul.cpp

Lines changed: 1 addition & 2 deletions
@@ -344,8 +344,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod, int computeCapability) {
     NvidiaMmaEncodingAttr mmaLayout =
         dyn_cast<NvidiaMmaEncodingAttr>(D.getType().getEncoding());
     if (mmaLayout) {
-      bool isNativeFP8 =
-          isa<Float8E5M2Type>(AElType) || isa<Float8E4M3FNType>(AElType);
+      bool isNativeFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
       // promote operands for sm < 89 since fp8 mma is not natively supported
       // promote operands for sm >= 90 when mma is not v3
       if (!isNativeFP8 ||

lib/Dialect/TritonGPU/Transforms/Utility.cpp

Lines changed: 3 additions & 3 deletions
@@ -44,9 +44,9 @@ SmallVector<unsigned, 3> mmaVersionToInstrShape(int version,
   SmallVector<unsigned> validN;
 
   // MMAv3 with larger instruction shape is preferred.
-  if (isa<Float8E5M2Type>(eltType) || isa<Float8E4M3FNType>(eltType) ||
-      isa<Float8E4M3FNUZType>(eltType) || eltType.isF16() ||
-      eltType.isBF16() || eltType.isF32()) {
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E4M3FNUZType>(
+          eltType) ||
+      eltType.isF16() || eltType.isBF16() || eltType.isF32()) {
     validN.assign({256, 248, 240, 232, 224, 216, 208, 200, 192, 184, 176,
                    168, 160, 152, 144, 136, 128, 120, 112, 104, 96, 88,
                    80, 72, 64, 56, 48, 40, 32, 24, 16, 8});

lib/Dialect/TritonNvidiaGPU/IR/Ops.cpp

Lines changed: 2 additions & 2 deletions
@@ -77,8 +77,8 @@ bool WarpGroupDotOp::needsPartialAccumulator() {
   const auto &d = getD();
   auto aTensorTy = cast<triton::gpu::TensorOrMemDesc>(a.getType());
   auto aElTy = cast<triton::gpu::TensorOrMemDesc>(a.getType()).getElementType();
-  bool isFP8 = isa<Float8E5M2Type>(aElTy) || isa<Float8E4M3FNType>(aElTy) ||
-               isa<Float8E5M2FNUZType>(aElTy) || isa<Float8E4M3FNUZType>(aElTy);
+  bool isFP8 = llvm::isa<Float8E5M2Type, Float8E4M3FNType, Float8E5M2FNUZType,
+                         Float8E4M3FNUZType>(aElTy);
   bool accFP32 =
       cast<triton::gpu::TensorOrMemDesc>(d.getType()).getElementType().isF32();
   uint32_t maxNumImpreciseAcc = getMaxNumImpreciseAcc();

third_party/amd/lib/TritonAMDGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 7 additions & 9 deletions
@@ -985,18 +985,16 @@ struct FpToFpOpConversion
       return outVals;
     }
     size_t numElements = 4;
-    if (isa<Float8E4M3FNType>(srcElementType) ||
-        isa<Float8E4M3FNType>(dstElementType) ||
-        isa<Float8E4M3FNUZType>(srcElementType) ||
-        isa<Float8E4M3FNUZType>(dstElementType) ||
-        isa<Float8E5M2FNUZType>(srcElementType) ||
-        isa<Float8E5M2FNUZType>(dstElementType)) {
+    if (llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            srcElementType) ||
+        llvm::isa<Float8E4M3FNType, Float8E4M3FNUZType, Float8E5M2FNUZType>(
+            dstElementType)) {
       numElements = 2;
     }
     bool useFP16IntermediateSrc =
-        srcElementType.isF32() && !(isaFamily == AMD::ISAFamily::CDNA3 &&
-                                    (isa<Float8E4M3FNUZType>(dstElementType) ||
-                                     isa<Float8E5M2FNUZType>(dstElementType)));
+        srcElementType.isF32() &&
+        !(isaFamily == AMD::ISAFamily::CDNA3 &&
+          (llvm::isa<Float8E4M3FNUZType, Float8E5M2FNUZType>(dstElementType)));
     bool isDstFP32 = dstElementType.isF32();
     Type srcType = useFP16IntermediateSrc ? f16_ty : srcElementType;
     Type dstType = isDstFP32 ? f16_ty : dstElementType;

third_party/amd/lib/TritonAMDGPUTransforms/AccelerateAMDMatmul.cpp

Lines changed: 1 addition & 2 deletions
@@ -416,8 +416,7 @@ class BlockedToMFMA : public OpRewritePattern<tt::DotOp> {
     // store instructions, except for fp8 matmul kernels due to regression
     // TODO (lixun): investigate the regression and enable this feature again
     auto aElemTy = mfmaInstr.getElementTypeA();
-    bool isFP8 =
-        isa<Float8E5M2FNUZType>(aElemTy) || isa<Float8E4M3FNUZType>(aElemTy);
+    bool isFP8 = llvm::isa<Float8E5M2FNUZType, Float8E4M3FNUZType>(aElemTy);
     bool isTransposed = isChainDot(dotOp) || !isFP8;
     mfmaEnc = ttg::AMDMfmaEncodingAttr::get(
         oldRetType.getContext(),

third_party/amd/lib/TritonAMDGPUTransforms/MfmaGroup.cpp

Lines changed: 10 additions & 9 deletions
@@ -20,23 +20,24 @@ static MfmaTypeId chooseAppropriateMfmaId(mlir::Type dataTypeA,
   if (dataTypeA.isInteger(8) && dataTypeB.isInteger(8)) {
     return MfmaTypeId::I8TyId;
   }
-  if (isa<Float8E4M3FNUZType>(dataTypeA) &&
-      isa<Float8E4M3FNUZType>(dataTypeB)) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Fp8TyId;
   }
-  if (isa<Float8E4M3FNUZType>(dataTypeA) &&
-      isa<Float8E5M2FNUZType>(dataTypeB)) {
+  if (llvm::isa<Float8E4M3FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Fp8Bf8TyId;
   }
-  if (isa<Float8E5M2FNUZType>(dataTypeA) &&
-      isa<Float8E4M3FNUZType>(dataTypeB)) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E4M3FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Fp8TyId;
   }
-  if (isa<Float8E5M2FNUZType>(dataTypeA) &&
-      isa<Float8E5M2FNUZType>(dataTypeB)) {
+  if (llvm::isa<Float8E5M2FNUZType>(dataTypeA) &&
+      llvm::isa<Float8E5M2FNUZType>(dataTypeB)) {
     return MfmaTypeId::Bf8Bf8TyId;
   }
-  if (isa<Float8E5M2Type>(dataTypeA) && isa<Float8E5M2Type>(dataTypeB)) {
+  if (llvm::isa<Float8E5M2Type>(dataTypeA) &&
+      llvm::isa<Float8E5M2Type>(dataTypeB)) {
     return MfmaTypeId::Fp16TyId;
   }
   llvm_unreachable("Unsupported input argument type.");

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 8 additions & 12 deletions
@@ -147,36 +147,32 @@ DPASAnalysis::getDPASType(OpTy op) {
 
   if (isa<FloatType>(dElemTy)) {
     if (dElemTy.isF32()) {
-      if (aElemTy.isBF16() &&
-          (isa<Float8E4M3FNType>(bElemTy) || isa<Float8E5M2Type>(bElemTy)))
+      if (aElemTy.isBF16() && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_BF16_FP8;
       // 2 E2M1 are packed into 1 int8
       if (aElemTy.isBF16() && bElemTy.isInteger(8))
         return DPASEngineType::FP32_FP32_BF16_FP4;
-      if ((isa<Float8E4M3FNType>(aElemTy) || isa<Float8E5M2Type>(aElemTy)) &&
-          bElemTy.isBF16())
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isBF16())
         return DPASEngineType::FP32_FP32_FP8_BF16;
-      if (aElemTy.isF16() &&
-          (isa<Float8E4M3FNType>(bElemTy) || isa<Float8E5M2Type>(bElemTy)))
+      if (aElemTy.isF16() && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_FP16_FP8;
       // 2 E2M1 are packed into 1 int8
       if (aElemTy.isF16() && bElemTy.isInteger(8))
         return DPASEngineType::FP32_FP32_FP16_FP4;
-      if ((isa<Float8E4M3FNType>(aElemTy) || isa<Float8E5M2Type>(aElemTy)) &&
-          bElemTy.isF16())
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isF16())
        return DPASEngineType::FP32_FP32_FP8_FP16;
-      if ((isa<Float8E4M3FNType>(aElemTy) || isa<Float8E5M2Type>(aElemTy)) &&
-          (isa<Float8E4M3FNType>(bElemTy) || isa<Float8E5M2Type>(bElemTy)))
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) &&
+          isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_FP8_FP8;
-      if ((isa<Float8E4M3FNType>(aElemTy) || isa<Float8E5M2Type>(aElemTy)) &&
+      if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) &&
           bElemTy.isInteger(8))
         return DPASEngineType::FP32_FP32_FP8_FP4;
       if (aElemTy.isInteger(8) && bElemTy.isBF16())
         return DPASEngineType::FP32_FP32_FP4_BF16;
       if (aElemTy.isInteger(8) && bElemTy.isF16())
         return DPASEngineType::FP32_FP32_FP4_FP16;
       if (aElemTy.isInteger(8) &&
-          (isa<Float8E4M3FNType>(bElemTy) || isa<Float8E5M2Type>(bElemTy)))
+          isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
         return DPASEngineType::FP32_FP32_FP4_FP8;
     }
   }

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 1 addition & 2 deletions
@@ -405,8 +405,7 @@ unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
   assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr");
 
   unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
-  if (llvm::isa<Float8E5M2Type>(elemType) ||
-      llvm::isa<Float8E4M3FNType>(elemType))
+  if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(elemType))
     dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16.
 
   return DPASCapability::opsChanBitWidths / dpasElemBitWidths;
