@@ -1079,7 +1079,7 @@ def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> {
10791079 let assemblyFormat = "`<` $value `>`";
10801080}
10811081
1082- def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
1082+ def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2 .to.f6x2"> {
10831083 let summary = "Convert a pair of float inputs to f6x2";
10841084 let description = [{
10851085 This Op converts each of the given float inputs to the specified fp6 type.
@@ -1096,6 +1096,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
10961096
10971097 [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
10981098 }];
1099+
10991100 let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
11001101 let arguments = (ins
11011102 CVTFP6TypeAttr:$type,
@@ -1110,7 +1111,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
11101111 }];
11111112
11121113 string llvmBuilder = [{
1113- auto intId = NVVM::CvtToF6x2Op ::getIntrinsicID($type, $relu);
1114+ auto intId = NVVM::CvtF32x2ToF6x2Op ::getIntrinsicID($type, $relu);
11141115 llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
11151116 if(op.getDst().getType().isInteger(16))
11161117 $dst = packedI16;
@@ -1120,6 +1121,153 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> {
11201121 }];
11211122}
11221123
// Enum cases naming the fp8 destination formats accepted by the
// cvt.*.to.f8x2 ops in this file. Each case carries its C++ symbol, its
// integer value, and the string used in the textual form.
1124+ def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">;
1125+ def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">;
1126+ def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">;
1127+
// 32-bit integer enum grouping the three cases above; the C++ enum is placed
// in the ::mlir::NVVM namespace.
// NOTE(review): genSpecializedAttr = 0 presumably suppresses the default
// attribute class so that CVTFP8TypeAttr below supplies it — confirm against
// the ODS enum documentation.
1128+ def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind",
1129+ [CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> {
1130+ let genSpecializedAttr = 0;
1131+ let cppNamespace = "::mlir::NVVM";
1132+ }
// Dialect attribute wrapping CVTFP8Type under the mnemonic "cvt_fp8_type";
// the assembly format places the enum value between `<` and `>`.
1133+ def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> {
1134+ let assemblyFormat = "`<` $value `>`";
1135+ }
1136+
// nvvm.cvt.f32x2.to.f8x2: converts two scalar f32 operands (`a`, `b`) to the
// fp8 format selected by `type`, producing either an i16 or a vector of two
// i8 values. The intrinsic is chosen from `type`, `rnd`, `sat` and `relu`.
1137+ def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> {
1138+ let summary = "Convert a pair of float inputs to f8x2";
1139+ let description = [{
1140+ This Op converts each of the given float inputs to the specified fp8 type.
1141+ The result `dst` is represented as an i16 type or as a vector
1142+ of two i8 types.
1143+ If `dst` is returned as an i16 type, the converted values are packed such
1144+ that the value converted from `a` is stored in the upper 8 bits of `dst`
1145+ and the value converted from `b` is stored in the lower 8 bits of `dst`.
1146+ If `dst` is returned as a vector type, each converted value is stored as an
1147+ i8 element in the vector.
1148+ The `rnd` and `sat` attributes specify the rounding and saturation modes respectively.
1149+ The `relu` attribute, when set, lowers to the '.relu' variant of
1150+ the cvt instruction.
1151+
1152+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1153+ }];
1154+
// The verifier body is defined outside this file section.
1155+ let hasVerifier = 1;
1156+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
// `rnd`, `sat` and `relu` are optional: they default to NONE, NONE and false.
1157+ let arguments = (ins
1158+ CVTFP8TypeAttr:$type,
1159+ F32:$a,
1160+ F32:$b,
1161+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
1162+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat,
1163+ DefaultValuedAttr<BoolAttr, "false">:$relu);
// Only the result type is spelled out in the textual form; `rnd`, `sat` and
// `relu` have no dedicated syntax and are carried in the attribute dictionary.
1164+ let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)";
1165+
// Declares the static helper that maps the attribute combination to an LLVM
// intrinsic ID; its definition is not part of this section.
1166+ let extraClassDeclaration = [{
1167+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
1168+ NVVM::FPRoundingMode rnd,
1169+ NVVM::SaturationMode sat,
1170+ bool hasRelu);
1171+ }];
1172+
// Translation to LLVM IR: call the selected intrinsic on {a, b}; use its
// result directly when `dst` is i16, otherwise bitcast it to a 2 x i8 vector.
1173+ string llvmBuilder = [{
1174+ auto intId = NVVM::CvtF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu);
1175+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b});
1176+ if(op.getDst().getType().isInteger(16))
1177+ $dst = packedI16;
1178+ else
1179+ $dst = builder.CreateBitCast(packedI16,
1180+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
1181+ }];
1182+ }
1183+
// nvvm.cvt.f16x2.to.f8x2: converts a single operand `a`, a vector of two f16
// values, to the fp8 format selected by `type`, producing either an i16 or a
// vector of two i8 values. The intrinsic is chosen from `type` and `relu`.
1184+ def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> {
1185+ let summary = "Convert an f16x2 input to f8x2";
1186+ let description = [{
1187+ This Op converts the given f16 inputs in an f16x2 vector to the specified
1188+ f8 type.
1189+ The result `dst` is represented as an i16 type or as a vector
1190+ of two i8 types.
1191+ If `dst` is returned as an i16 type, the converted values from `a`
1192+ are packed such that the value converted from the first element of `a`
1193+ is stored in the upper 8 bits of `dst` and the value converted from the
1194+ second element of `a` is stored in the lower 8 bits of `dst`.
1195+ If `dst` is returned as a vector type, each converted value is stored as an
1196+ i8 element in the vector.
1197+ The `relu` attribute, when set, lowers to the '.relu' variant of
1198+ the cvt instruction.
1199+
1200+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1201+ }];
1202+
// The verifier body is defined outside this file section.
1203+ let hasVerifier = 1;
1204+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
// This op declares no rounding or saturation attributes; `relu` is the only
// optional attribute and defaults to false.
1205+ let arguments = (ins
1206+ CVTFP8TypeAttr:$type,
1207+ VectorOfLengthAndType<[2], [F16]>:$a,
1208+ DefaultValuedAttr<BoolAttr, "false">:$relu);
// Both the operand type and the result type are written in the textual form.
1209+ let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
1210+
// Declares the static helper that maps (`type`, `relu`) to an LLVM intrinsic
// ID; its definition is not part of this section.
1211+ let extraClassDeclaration = [{
1212+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to,
1213+ bool hasRelu);
1214+ }];
1215+
// Translation to LLVM IR: call the selected intrinsic on {a}; use its result
// directly when `dst` is i16, otherwise bitcast it to a 2 x i8 vector.
1216+ string llvmBuilder = [{
1217+ auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu);
1218+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
1219+ if(op.getDst().getType().isInteger(16))
1220+ $dst = packedI16;
1221+ else
1222+ $dst = builder.CreateBitCast(packedI16,
1223+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
1224+ }];
1225+ }
1226+
// nvvm.cvt.bf16x2.to.f8x2: converts a single operand `a`, a vector of two
// bf16 values, to an fp8 format, producing either an i16 or a vector of two
// i8 values. The intrinsic is chosen from `rnd` and `sat` only.
1227+ def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> {
1228+ let summary = "Convert a bf16x2 input to f8x2";
1229+ let description = [{
1230+ This Op converts the given bf16 inputs in a bf16x2 vector to the specified
1231+ f8 type.
1232+ The result `dst` is represented as an i16 type or as a vector
1233+ of two i8 types.
1234+ If `dst` is returned as an i16 type, the converted values from `a`
1235+ are packed such that the value converted from the first element of `a`
1236+ is stored in the upper 8 bits of `dst` and the value converted from the
1237+ second element of `a` is stored in the lower 8 bits of `dst`.
1238+ If `dst` is returned as a vector type, each converted value is stored as an
1239+ i8 element in the vector.
1240+ The `rnd` and `sat` attributes specify the rounding and saturation modes
1241+ respectively.
1242+
1243+ [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt)
1244+ }];
1245+
// The verifier body is defined outside this file section.
1246+ let hasVerifier = 1;
1247+ let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst);
// This op declares no `relu` attribute; `rnd` and `sat` both default to NONE.
1248+ let arguments = (ins
1249+ CVTFP8TypeAttr:$type,
1250+ VectorOfLengthAndType<[2], [BF16]>:$a,
1251+ DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd,
1252+ DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat);
// Both the operand type and the result type are written in the textual form.
1253+ let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)";
1254+
// Declares the static helper that maps (`rnd`, `sat`) to an LLVM intrinsic
// ID. `type` is not a parameter of this helper.
// NOTE(review): presumably the verifier restricts `type` to the format(s)
// these intrinsics support — confirm in the verifier definition.
1255+ let extraClassDeclaration = [{
1256+ static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd,
1257+ NVVM::SaturationMode sat);
1258+ }];
1259+
// Translation to LLVM IR: call the selected intrinsic on {a}; use its result
// directly when `dst` is i16, otherwise bitcast it to a 2 x i8 vector.
1260+ string llvmBuilder = [{
1261+ auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat);
1262+ llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a});
1263+ if(op.getDst().getType().isInteger(16))
1264+ $dst = packedI16;
1265+ else
1266+ $dst = builder.CreateBitCast(packedI16,
1267+ llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2));
1268+ }];
1269+ }
1270+
11231271//===----------------------------------------------------------------------===//
11241272// NVVM MMA Ops
11251273//===----------------------------------------------------------------------===//
0 commit comments