Skip to content

Commit 7e67848

Browse files
committed
Move conversion logic two two helper functions
1 parent e616cf6 commit 7e67848

File tree

1 file changed

+70
-48
lines changed

1 file changed

+70
-48
lines changed

mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Lines changed: 70 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -989,26 +989,17 @@ mfmaOpToScaledIntrinsic(ScaledMFMAOp smfma, Chipset chipset) {
989989
smfma.getN(), smfma.getK(), 1u, chipset);
990990
}
991991

992-
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
993-
/// if one exists. This includes checking to ensure the intrinsic is supported
994-
/// on the architecture you are compiling for.
995-
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
996-
Chipset chipset) {
997-
auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
998-
auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
999-
auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1000-
Type elemSourceType = sourceVectorType.getElementType();
1001-
Type elemBSourceType = sourceBVectorType.getElementType();
1002-
Type elemDestType = destVectorType.getElementType();
1003-
1004-
const uint32_t k = wmma.getK();
1005-
const bool isRDNA3 = chipset.majorVersion == 11;
1006-
const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
992+
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
993+
/// for RDNA3/4 architectures.
994+
static std::optional<StringRef>
995+
wmmaOpToIntrinsicRDNA(Type elemSourceType, Type elemBSourceType,
996+
Type elemDestType, uint32_t k, bool isRDNA3) {
997+
using fp8 = Float8E4M3FNType;
998+
using bf8 = Float8E5M2Type;
1007999

1000+
// Handle k == 16 for RDNA3/4.
10081001
if (k == 16) {
1009-
if (!isRDNA3 && !isRDNA4) // gfx1250 does not have any wmma ops with k=16.
1010-
return std::nullopt;
1011-
1002+
// Common patterns for RDNA3 and RDNA4.
10121003
if (elemSourceType.isF16() && elemDestType.isF32())
10131004
return ROCDL::wmma_f32_16x16x16_f16::getOperationName();
10141005
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1019,22 +1010,15 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10191010
return ROCDL::wmma_bf16_16x16x16_bf16::getOperationName();
10201011
if (elemSourceType.isInteger(8) && elemDestType.isInteger(32))
10211012
return ROCDL::wmma_i32_16x16x16_iu8::getOperationName();
1022-
if (chipset.majorVersion == 11) {
1013+
1014+
// RDNA3 specific patterns.
1015+
if (isRDNA3) {
10231016
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
10241017
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
1025-
}
1026-
}
1027-
if (isRDNA3)
1028-
return std::nullopt;
1029-
1030-
using fp8 = Float8E4M3FNType;
1031-
using bf8 = Float8E5M2Type;
1032-
1033-
// gfx12+
1034-
if (k == 16) {
1035-
if (!isRDNA4) // gfx1250 does not have any wmma ops with k=16.
10361018
return std::nullopt;
1019+
}
10371020

1021+
// RDNA4 specific patterns (fp8/bf8).
10381022
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
10391023
elemDestType.isF32())
10401024
return ROCDL::wmma_f32_16x16x16_fp8_fp8::getOperationName();
@@ -1047,20 +1031,38 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10471031
if (isa<bf8>(elemSourceType) && isa<fp8>(elemBSourceType) &&
10481032
elemDestType.isF32())
10491033
return ROCDL::wmma_f32_16x16x16_bf8_fp8::getOperationName();
1050-
10511034
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
10521035
return ROCDL::wmma_i32_16x16x16_iu4::getOperationName();
10531036

10541037
return std::nullopt;
10551038
}
1056-
if (k == 32) {
1057-
if (isRDNA4) {
1058-
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1059-
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1060-
return std::nullopt;
1061-
}
10621039

1063-
// gfx1250
1040+
// Handle k == 32 for RDNA4.
1041+
if (k == 32 && !isRDNA3) {
1042+
if (elemSourceType.isInteger(4) && elemDestType.isInteger(32))
1043+
return ROCDL::wmma_i32_16x16x32_iu4::getOperationName();
1044+
}
1045+
1046+
llvm_unreachable("Unsupported k value");
1047+
}
1048+
1049+
/// Return the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1050+
/// for the gfx1250 architecture.
1051+
static std::optional<StringRef> wmmaOpToIntrinsicGfx1250(Type elemSourceType,
1052+
Type elemBSourceType,
1053+
Type elemDestType,
1054+
uint32_t k) {
1055+
using fp8 = Float8E4M3FNType;
1056+
using bf8 = Float8E5M2Type;
1057+
1058+
if (k == 4) {
1059+
if (elemSourceType.isF32() && elemDestType.isF32())
1060+
return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1061+
1062+
return std::nullopt;
1063+
}
1064+
1065+
if (k == 32) {
10641066
if (elemSourceType.isF16() && elemDestType.isF32())
10651067
return ROCDL::wmma_f32_16x16x32_f16::getOperationName();
10661068
if (elemSourceType.isBF16() && elemDestType.isF32())
@@ -1073,16 +1075,6 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
10731075
return std::nullopt;
10741076
}
10751077

1076-
if (isRDNA4)
1077-
return std::nullopt;
1078-
1079-
// gfx1250
1080-
if (k == 4) {
1081-
if (elemSourceType.isF32() && elemDestType.isF32())
1082-
return ROCDL::wmma_f32_16x16x4_f32::getOperationName();
1083-
return std::nullopt;
1084-
}
1085-
10861078
if (k == 64) {
10871079
if (isa<fp8>(elemSourceType) && isa<fp8>(elemBSourceType)) {
10881080
if (elemDestType.isF32())
@@ -1142,6 +1134,36 @@ static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
11421134

11431135
return std::nullopt;
11441136
}
1137+
1138+
llvm_unreachable("Unsupported k value");
1139+
}
1140+
1141+
/// Returns the `rocdl` intrinsic corresponding to a WMMA operation `wmma`
1142+
/// if one exists. This includes checking to ensure the intrinsic is supported
1143+
/// on the architecture you are compiling for.
1144+
static std::optional<StringRef> wmmaOpToIntrinsic(WMMAOp wmma,
1145+
Chipset chipset) {
1146+
auto sourceVectorType = cast<VectorType>(wmma.getSourceA().getType());
1147+
auto sourceBVectorType = cast<VectorType>(wmma.getSourceB().getType());
1148+
auto destVectorType = cast<VectorType>(wmma.getDestC().getType());
1149+
Type elemSourceType = sourceVectorType.getElementType();
1150+
Type elemBSourceType = sourceBVectorType.getElementType();
1151+
Type elemDestType = destVectorType.getElementType();
1152+
1153+
const uint32_t k = wmma.getK();
1154+
const bool isRDNA3 = chipset.majorVersion == 11;
1155+
const bool isRDNA4 = chipset.majorVersion == 12 && chipset.minorVersion == 0;
1156+
1157+
// Handle RDNA3 and RDNA4.
1158+
if (isRDNA3 || isRDNA4)
1159+
return wmmaOpToIntrinsicRDNA(elemSourceType, elemBSourceType, elemDestType,
1160+
k, isRDNA3);
1161+
1162+
// Handle gfx1250.
1163+
if (chipset == Chipset{12, 5, 0})
1164+
return wmmaOpToIntrinsicGfx1250(elemSourceType, elemBSourceType,
1165+
elemDestType, k);
1166+
11451167
llvm_unreachable("unhandled WMMA case");
11461168
}
11471169

0 commit comments

Comments
 (0)