Skip to content

Commit c489597

Browse files
maksleventalAndreyPavlenko
authored andcommitted
Syncronized with triton-lang/triton#5684
1 parent f9e8aa5 commit c489597

File tree

1 file changed

+5
-4
lines changed
  • third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM

1 file changed

+5
-4
lines changed

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ enum class mxfpKind { mxf8f6f4 = 0, mxf4 = 1, mxf4nvf4 = 2 };
6060
inline mxfpKind getMXFPKind(ScaleDotElemType typeA, ScaleDotElemType typeB,
6161
Type scaleAType, Type scaleBType) {
6262
if (typeA == ScaleDotElemType::E2M1 && typeB == ScaleDotElemType::E2M1) {
63-
if (scaleAType.isFloat8E4M3FN() && scaleBType.isFloat8E4M3FN()) {
63+
if (llvm::isa<Float8E4M3FNType>(scaleAType) &&
64+
llvm::isa<Float8E4M3FNType>(scaleBType)) {
6465
return mxfpKind::mxf4nvf4;
6566
}
6667
return mxfpKind::mxf4;
@@ -100,9 +101,9 @@ static Value createInstDescriptor(ConversionPatternRewriter &rewriter,
100101
return 1;
101102
if (type.isF32())
102103
return 2;
103-
if (type.isFloat8E4M3FN())
104+
if (llvm::isa<Float8E4M3FNType>(type))
104105
return 0;
105-
if (type.isFloat8E5M2())
106+
if (llvm::isa<Float8E5M2Type>(type))
106107
return 1;
107108
llvm_unreachable("Unsupported type.");
108109
};
@@ -224,7 +225,7 @@ static void createGen5MMA(ConversionPatternRewriter &rewriter, Location loc,
224225
opcode += "f16";
225226
else if (srcElementTy.isF32())
226227
opcode += "tf32";
227-
else if (srcElementTy.isFloat8E4M3FN() || srcElementTy.isFloat8E5M2())
228+
else if (llvm::isa<Float8E4M3FNType, Float8E5M2Type>(srcElementTy))
228229
opcode += "f8f6f4";
229230
else
230231
assert(0 && "Unsupported type.");

0 commit comments

Comments
 (0)