Skip to content

Commit 4a99671

Browse files
AndreyPavlenkowhitneywhtsang
authored andcommitted
Replace isF...() LLVM API calls with the corresponding isa<...>() (#3268)
1 parent d812777 commit 4a99671

File tree

4 files changed

+14
-19
lines changed

4 files changed

+14
-19
lines changed

third_party/intel/lib/Analysis/DPAS.cpp

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,9 @@ DPASAnalysis::getDPASType(OpTy op) {
125125
if (aElemTy.isF32() && op.getInputPrecision() == InputPrecision::TF32)
126126
return DPASEngineType::FP32_FP32_TF32_TF32;
127127
// For FP8XFP8->FP32, upcast to FP16
128-
if (aElemTy.isFloat8E5M2())
128+
if (isa<Float8E5M2Type>(aElemTy))
129129
return DPASEngineType::FP32_FP32_FP16_FP16;
130-
if (aElemTy.isFloat8E4M3FN())
130+
if (isa<Float8E4M3FNType>(aElemTy))
131131
return DPASEngineType::FP32_FP32_FP16_FP16;
132132
} else if (dElemTy.isF16()) {
133133
if (aElemTy.isF16())
@@ -147,36 +147,32 @@ DPASAnalysis::getDPASType(OpTy op) {
147147

148148
if (isa<FloatType>(dElemTy)) {
149149
if (dElemTy.isF32()) {
150-
if (aElemTy.isBF16() &&
151-
(bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
150+
if (aElemTy.isBF16() && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
152151
return DPASEngineType::FP32_FP32_BF16_FP8;
153152
// 2 E2M1 are packed into 1 int8
154153
if (aElemTy.isBF16() && bElemTy.isInteger(8))
155154
return DPASEngineType::FP32_FP32_BF16_FP4;
156-
if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
157-
bElemTy.isBF16())
155+
if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isBF16())
158156
return DPASEngineType::FP32_FP32_FP8_BF16;
159-
if (aElemTy.isF16() &&
160-
(bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
157+
if (aElemTy.isF16() && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
161158
return DPASEngineType::FP32_FP32_FP16_FP8;
162159
// 2 E2M1 are packed into 1 int8
163160
if (aElemTy.isF16() && bElemTy.isInteger(8))
164161
return DPASEngineType::FP32_FP32_FP16_FP4;
165-
if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
166-
bElemTy.isF16())
162+
if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isF16())
167163
return DPASEngineType::FP32_FP32_FP8_FP16;
168-
if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
169-
(bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
164+
if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) &&
165+
isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
170166
return DPASEngineType::FP32_FP32_FP8_FP8;
171-
if ((aElemTy.isFloat8E4M3FN() || aElemTy.isFloat8E5M2()) &&
167+
if ((isa<Float8E4M3FNType>(aElemTy) || isa<Float8E5M2Type>(aElemTy)) &&
172168
bElemTy.isInteger(8))
173169
return DPASEngineType::FP32_FP32_FP8_FP4;
174170
if (aElemTy.isInteger(8) && bElemTy.isBF16())
175171
return DPASEngineType::FP32_FP32_FP4_BF16;
176172
if (aElemTy.isInteger(8) && bElemTy.isF16())
177173
return DPASEngineType::FP32_FP32_FP4_FP16;
178174
if (aElemTy.isInteger(8) &&
179-
(bElemTy.isFloat8E4M3FN() || bElemTy.isFloat8E5M2()))
175+
isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
180176
return DPASEngineType::FP32_FP32_FP4_FP8;
181177
}
182178
}

third_party/intel/lib/Dialect/TritonIntelGPU/IR/Dialect.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -405,7 +405,7 @@ unsigned DpasEncodingAttr::getOpsPerChannel(Type elemType) {
405405
assert(elemType.isIntOrFloat() && "unsupported type for DpasEncodingAttr");
406406

407407
unsigned dpasElemBitWidths = elemType.getIntOrFloatBitWidth();
408-
if (elemType.isFloat8E5M2() || elemType.isFloat8E4M3FN())
408+
if (llvm::isa<Float8E5M2Type, Float8E4M3FNType>(elemType))
409409
dpasElemBitWidths *= 2; // We are upcasting FP8 to FP16.
410410

411411
return DPASCapability::opsChanBitWidths / dpasElemBitWidths;

third_party/intel/lib/TritonIntelGPUToLLVM/ElementwiseOpToLLVM.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -996,7 +996,7 @@ struct FpToFpOpConversion
996996
auto dstElementType = getElementType(op.getResult());
997997
auto roundingMode = op.getRounding();
998998

999-
if (dstElementType.isFloat8E5M2() || dstElementType.isFloat8E4M3FN()) {
999+
if (isa<Float8E5M2Type, Float8E4M3FNType>(dstElementType)) {
10001000
assert(roundingMode.has_value() &&
10011001
"Rounding mode must be specified for conversions to fp8");
10021002

third_party/intel/lib/TritonIntelGPUTransforms/AccelerateMatmul.cpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,7 @@ class BlockedToDPAS : public OpRewritePattern<tt::DotOp> {
132132
oldAType.getElementType().getIntOrFloatBitWidth();
133133

134134
// We are upcasting FP8 to FP16
135-
if (oldAType.getElementType().isFloat8E5M2() ||
136-
oldAType.getElementType().isFloat8E4M3FN())
135+
if (isa<Float8E5M2Type, Float8E4M3FNType>(oldAType.getElementType()))
137136
dpasElemBitWidths = 2 * dpasElemBitWidths;
138137

139138
// Enlarge the repCluster size to use the large 2D load for A and B
@@ -488,7 +487,7 @@ static void decomposeMixedModeDotOp(ModuleOp mod) {
488487

489488
Type promoteType;
490489
if (dpasLayout) {
491-
bool isNativeFP8 = AElType.isFloat8E5M2() || AElType.isFloat8E4M3FN();
490+
bool isNativeFP8 = isa<Float8E5M2Type, Float8E4M3FNType>(AElType);
492491
// fp8 is not natively supported by the the DPAS instruction, promote it
493492
// to fp16.
494493
if (!isNativeFP8)

0 commit comments

Comments
 (0)