@@ -125,9 +125,9 @@ DPASAnalysis::getDPASType(OpTy op) {
125125 if (aElemTy.isF32 () && op.getInputPrecision () == InputPrecision::TF32)
126126 return DPASEngineType::FP32_FP32_TF32_TF32;
127127 // For FP8XFP8->FP32, upcast to FP16
128- if (aElemTy. isFloat8E5M2 ( ))
128+ if (isa<Float8E5M2Type>(aElemTy ))
129129 return DPASEngineType::FP32_FP32_FP16_FP16;
130- if (aElemTy. isFloat8E4M3FN ( ))
130+ if (isa<Float8E4M3FNType>(aElemTy ))
131131 return DPASEngineType::FP32_FP32_FP16_FP16;
132132 } else if (dElemTy.isF16 ()) {
133133 if (aElemTy.isF16 ())
@@ -147,36 +147,32 @@ DPASAnalysis::getDPASType(OpTy op) {
147147
148148 if (isa<FloatType>(dElemTy)) {
149149 if (dElemTy.isF32 ()) {
150- if (aElemTy.isBF16 () &&
151- (bElemTy.isFloat8E4M3FN () || bElemTy.isFloat8E5M2 ()))
150+ if (aElemTy.isBF16 () && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
152151 return DPASEngineType::FP32_FP32_BF16_FP8;
153152 // 2 E2M1 are packed into 1 int8
154153 if (aElemTy.isBF16 () && bElemTy.isInteger (8 ))
155154 return DPASEngineType::FP32_FP32_BF16_FP4;
156- if ((aElemTy.isFloat8E4M3FN () || aElemTy.isFloat8E5M2 ()) &&
157- bElemTy.isBF16 ())
155+ if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isBF16 ())
158156 return DPASEngineType::FP32_FP32_FP8_BF16;
159- if (aElemTy.isF16 () &&
160- (bElemTy.isFloat8E4M3FN () || bElemTy.isFloat8E5M2 ()))
157+ if (aElemTy.isF16 () && isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy))
161158 return DPASEngineType::FP32_FP32_FP16_FP8;
162159 // 2 E2M1 are packed into 1 int8
163160 if (aElemTy.isF16 () && bElemTy.isInteger (8 ))
164161 return DPASEngineType::FP32_FP32_FP16_FP4;
165- if ((aElemTy.isFloat8E4M3FN () || aElemTy.isFloat8E5M2 ()) &&
166- bElemTy.isF16 ())
162+ if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy) && bElemTy.isF16 ())
167163 return DPASEngineType::FP32_FP32_FP8_FP16;
168- if ((aElemTy. isFloat8E4M3FN () || aElemTy. isFloat8E5M2 () ) &&
169- (bElemTy. isFloat8E4M3FN () || bElemTy. isFloat8E5M2 () ))
164+ if (isa<Float8E4M3FNType, Float8E5M2Type>(aElemTy ) &&
165+ isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy ))
170166 return DPASEngineType::FP32_FP32_FP8_FP8;
171- if ((aElemTy. isFloat8E4M3FN ( ) || aElemTy. isFloat8E5M2 ( )) &&
167+ if ((isa<Float8E4M3FNType>(aElemTy ) || isa<Float8E5M2Type>(aElemTy )) &&
172168 bElemTy.isInteger (8 ))
173169 return DPASEngineType::FP32_FP32_FP8_FP4;
174170 if (aElemTy.isInteger (8 ) && bElemTy.isBF16 ())
175171 return DPASEngineType::FP32_FP32_FP4_BF16;
176172 if (aElemTy.isInteger (8 ) && bElemTy.isF16 ())
177173 return DPASEngineType::FP32_FP32_FP4_FP16;
178174 if (aElemTy.isInteger (8 ) &&
179- (bElemTy. isFloat8E4M3FN () || bElemTy. isFloat8E5M2 () ))
175+ isa<Float8E4M3FNType, Float8E5M2Type>(bElemTy ))
180176 return DPASEngineType::FP32_FP32_FP4_FP8;
181177 }
182178 }
0 commit comments