-
Notifications
You must be signed in to change notification settings - Fork 15.3k
[MLIR][NVVM] Update support for conversions to f8x2 and f6x2 types #137781
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1079,7 +1079,7 @@ def CVTFP6TypeAttr : EnumAttr<NVVM_Dialect, CVTFP6Type, "cvt_fp6_type"> { | |
| let assemblyFormat = "`<` $value `>`"; | ||
| } | ||
|
|
||
| def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { | ||
| def NVVM_CvtF32x2ToF6x2Op : NVVM_Op<"cvt.f32x2.to.f6x2"> { | ||
| let summary = "Convert a pair of float inputs to f6x2"; | ||
| let description = [{ | ||
| This Op converts each of the given float inputs to the specified fp6 type. | ||
|
|
@@ -1096,6 +1096,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { | |
|
|
||
| [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) | ||
| }]; | ||
|
|
||
| let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); | ||
| let arguments = (ins | ||
| CVTFP6TypeAttr:$type, | ||
|
|
@@ -1110,7 +1111,7 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { | |
| }]; | ||
|
|
||
| string llvmBuilder = [{ | ||
| auto intId = NVVM::CvtToF6x2Op::getIntrinsicID($type, $relu); | ||
| auto intId = NVVM::CvtF32x2ToF6x2Op::getIntrinsicID($type, $relu); | ||
| llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); | ||
| if(op.getDst().getType().isInteger(16)) | ||
| $dst = packedI16; | ||
|
|
@@ -1120,6 +1121,153 @@ def NVVM_CvtToF6x2Op : NVVM_Op<"cvt.to.f6x2"> { | |
| }]; | ||
| } | ||
|
|
||
| def CVTFP8E4M3 : I32EnumAttrCase<"E4M3", 0, "e4m3">; | ||
| def CVTFP8E5M2 : I32EnumAttrCase<"E5M2", 1, "e5m2">; | ||
| def CVTFP8UE8M0 : I32EnumAttrCase<"UE8M0", 2, "ue8m0">; | ||
|
Comment on lines
+1124
to
+1126
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Creating type for each op won't really scale. I think should use existing types.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah, that makes sense. I did it this way here since the other existing type enums were prefixed with the corresponding Op name. But we should probably unify all of them and rename them with an
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes Guray, I brought this up on the first fp6 cvt Op review itself (more from a re-use perspective, though) With a unified enum (let's say, for all the FP types), we may need to update/tighten the
Member
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we really need enum embedded in NVVM dialect? I'm asking can we just reuse existing MLIR builtin types. At this point, I assume we have all the exotic types. What do you think?
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I actually tried using the builtin types ( |
||
|
|
||
| def CVTFP8Type : I32EnumAttr<"CVTFP8Type", "NVVM CVTFP8Type kind", | ||
| [CVTFP8E4M3, CVTFP8E5M2, CVTFP8UE8M0]> { | ||
| let genSpecializedAttr = 0; | ||
| let cppNamespace = "::mlir::NVVM"; | ||
| } | ||
| def CVTFP8TypeAttr : EnumAttr<NVVM_Dialect, CVTFP8Type, "cvt_fp8_type"> { | ||
| let assemblyFormat = "`<` $value `>`"; | ||
| } | ||
|
|
||
| def NVVM_CvtF32x2ToF8x2Op : NVVM_Op<"cvt.f32x2.to.f8x2"> { | ||
| let summary = "Convert a pair of float inputs to f8x2"; | ||
| let description = [{ | ||
| This Op converts each of the given float inputs to the specified fp8 type. | ||
| The result `dst` is represented as an i16 type or as a vector | ||
| of two i8 types. | ||
| If `dst` is returned as an i16 type, the converted values are packed such | ||
| that the value converted from `a` is stored in the upper 8 bits of `dst` | ||
| and the value converted from `b` is stored in the lower 8 bits of `dst`. | ||
| If `dst` is returned as a vector type, each converted value is stored as an | ||
| i8 element in the vector. | ||
| The `rnd` and `sat` attributes specify the rounding and saturation modes respectively. | ||
| The `relu` attribute, when set, lowers to the '.relu' variant of | ||
| the cvt instruction. | ||
|
|
||
| [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) | ||
| }]; | ||
|
|
||
| let hasVerifier = 1; | ||
| let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); | ||
| let arguments = (ins | ||
| CVTFP8TypeAttr:$type, | ||
| F32:$a, | ||
| F32:$b, | ||
| DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd, | ||
| DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat, | ||
| DefaultValuedAttr<BoolAttr, "false">:$relu); | ||
| let assemblyFormat = "$type $a `,` $b attr-dict `:` type($dst)"; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to, | ||
| NVVM::FPRoundingMode rnd, | ||
| NVVM::SaturationMode sat, | ||
| bool hasRelu); | ||
| }]; | ||
|
|
||
| string llvmBuilder = [{ | ||
| auto intId = NVVM::CvtF32x2ToF8x2Op::getIntrinsicID($type, $rnd, $sat, $relu); | ||
| llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a, $b}); | ||
| if(op.getDst().getType().isInteger(16)) | ||
| $dst = packedI16; | ||
| else | ||
| $dst = builder.CreateBitCast(packedI16, | ||
| llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); | ||
| }]; | ||
| } | ||
|
|
||
| def NVVM_CvtF16x2ToF8x2Op : NVVM_Op<"cvt.f16x2.to.f8x2"> { | ||
| let summary = "Convert an f16x2 input to f8x2"; | ||
| let description = [{ | ||
| This Op converts the given f16 inputs in an f16x2 vector to the specified | ||
| f8 type. | ||
| The result `dst` is represented as an i16 type or as a vector | ||
| of two i8 types. | ||
| If `dst` is returned as an i16 type, the converted values from `a` | ||
| are packed such that the value converted from the first element of `a` | ||
| is stored in the upper 8 bits of `dst` and the value converted from the | ||
| second element of `a` is stored in the lower 8 bits of `dst`. | ||
| If `dst` is returned as a vector type, each converted value is stored as an | ||
| i8 element in the vector. | ||
| The `relu` attribute, when set, lowers to the '.relu' variant of | ||
| the cvt instruction. | ||
|
|
||
| [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) | ||
| }]; | ||
|
|
||
| let hasVerifier = 1; | ||
| let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); | ||
| let arguments = (ins | ||
| CVTFP8TypeAttr:$type, | ||
| VectorOfLengthAndType<[2], [F16]>:$a, | ||
| DefaultValuedAttr<BoolAttr, "false">:$relu); | ||
| let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)"; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| static llvm::Intrinsic::ID getIntrinsicID(NVVM::CVTFP8Type to, | ||
| bool hasRelu); | ||
| }]; | ||
|
|
||
| string llvmBuilder = [{ | ||
| auto intId = NVVM::CvtF16x2ToF8x2Op::getIntrinsicID($type, $relu); | ||
| llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a}); | ||
| if(op.getDst().getType().isInteger(16)) | ||
| $dst = packedI16; | ||
| else | ||
| $dst = builder.CreateBitCast(packedI16, | ||
| llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); | ||
| }]; | ||
| } | ||
|
|
||
| def NVVM_CvtBF16x2ToF8x2Op : NVVM_Op<"cvt.bf16x2.to.f8x2"> { | ||
| let summary = "Convert a pair of bf16 inputs to f8x2"; | ||
| let description = [{ | ||
| This Op converts the given bf16 inputs in a bf16x2 vector to the specified | ||
| f8 type. | ||
| The result `dst` is represented as an i16 type or as a vector | ||
| of two i8 types. | ||
| If `dst` is returned as an i16 type, the converted values from `a` | ||
| are packed such that the value converted from the first element of `a` | ||
| is stored in the upper 8 bits of `dst` and the value converted from the | ||
| second element of `a` is stored in the lower 8 bits of `dst`. | ||
| If `dst` is returned as a vector type, each converted value is stored as an | ||
| i8 element in the vector. | ||
| The `rnd` and `sat` attributes specify the rounding and saturation modes | ||
| respectively. | ||
|
|
||
| [For more information, see PTX ISA](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cvt) | ||
| }]; | ||
|
|
||
| let hasVerifier = 1; | ||
| let results = (outs AnyTypeOf<[I16, VectorOfLengthAndType<[2], [I8]>]>:$dst); | ||
| let arguments = (ins | ||
| CVTFP8TypeAttr:$type, | ||
| VectorOfLengthAndType<[2], [BF16]>:$a, | ||
| DefaultValuedAttr<FPRoundingModeAttr, "FPRoundingMode::NONE">:$rnd, | ||
| DefaultValuedAttr<SaturationModeAttr, "SaturationMode::NONE">:$sat); | ||
| let assemblyFormat = "$type $a attr-dict `:` type($a) `->` type($dst)"; | ||
|
|
||
| let extraClassDeclaration = [{ | ||
| static llvm::Intrinsic::ID getIntrinsicID(NVVM::FPRoundingMode rnd, | ||
| NVVM::SaturationMode sat); | ||
| }]; | ||
|
|
||
| string llvmBuilder = [{ | ||
| auto intId = NVVM::CvtBF16x2ToF8x2Op::getIntrinsicID($rnd, $sat); | ||
| llvm::Value *packedI16 = createIntrinsicCall(builder, intId, {$a}); | ||
| if(op.getDst().getType().isInteger(16)) | ||
| $dst = packedI16; | ||
| else | ||
| $dst = builder.CreateBitCast(packedI16, | ||
| llvm::FixedVectorType::get(llvm::Type::getInt8Ty(builder.getContext()), 2)); | ||
| }]; | ||
| } | ||
|
|
||
| //===----------------------------------------------------------------------===// | ||
| // NVVM MMA Ops | ||
| //===----------------------------------------------------------------------===// | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,22 +1,21 @@ | ||
| // RUN: mlir-translate -mlir-to-llvmir %s | FileCheck %s | ||
|
|
||
| // CHECK-LABEL: @convert_float_to_fp6x2_packed | ||
| llvm.func @convert_float_to_fp6x2_packed(%srcA : f32, %srcB : f32) { | ||
| // CHECK-LABEL: @convert_f32x2_to_fp6x2_packed | ||
| llvm.func @convert_f32x2_to_fp6x2_packed(%srcA : f32, %srcB : f32) { | ||
| //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : i16 | ||
| %res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : i16 | ||
| //CHECK: %{{.*}} = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
| %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : i16 | ||
| %res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : i16 | ||
| llvm.return | ||
| } | ||
|
|
||
| // CHECK-LABEL: @convert_float_to_fp6x2_vector | ||
| llvm.func @convert_float_to_fp6x2_vector(%srcA : f32, %srcB : f32) { | ||
| // CHECK-LABEL: @convert_f32x2_to_fp6x2_vector | ||
| llvm.func @convert_f32x2_to_fp6x2_vector(%srcA : f32, %srcB : f32) { | ||
| //CHECK: %[[res0:.*]] = call i16 @llvm.nvvm.ff.to.e2m3x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
| //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res0]] to <2 x i8> | ||
| %res1 = nvvm.cvt.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8> | ||
| %res1 = nvvm.cvt.f32x2.to.f6x2 <e2m3> %srcA, %srcB : vector<2xi8> | ||
| //CHECK: %[[res1:.*]] = call i16 @llvm.nvvm.ff.to.e3m2x2.rn.satfinite(float %{{.*}}, float %{{.*}}) | ||
| //CHECK-NEXT: %{{.*}} = bitcast i16 %[[res1]] to <2 x i8> | ||
| %res2 = nvvm.cvt.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8> | ||
| %res2 = nvvm.cvt.f32x2.to.f6x2 <e3m2> %srcA, %srcB : vector<2xi8> | ||
| llvm.return | ||
| } | ||
|
|
Uh oh!
There was an error while loading. Please reload this page.