Skip to content

Commit b4a0d7e

Browse files
authored
[NVPTX] Fix PTX and SM conditions for narrow FP conversions (#168680)
This change fixes the PTX and SM conditions for narrow FP conversion intrinsics and adds support for family-conditionals.
1 parent 5c5c83d commit b4a0d7e

File tree

2 files changed

+51
-28
lines changed

2 files changed

+51
-28
lines changed

llvm/lib/Target/NVPTX/NVPTXIntrinsics.td

Lines changed: 30 additions & 28 deletions
Original file line number | Diff line number | Diff line change
@@ -2071,34 +2071,36 @@ def : Pat<(int_nvvm_ull2d_rp i64:$a), (CVT_f64_u64 $a, CvtRP)>;
2071 2071
def : Pat<(int_nvvm_f2h_rn_ftz f32:$a), (CVT_f16_f32 $a, CvtRN_FTZ)>;
2072 2072
def : Pat<(int_nvvm_f2h_rn f32:$a), (CVT_f16_f32 $a, CvtRN)>;
2073 2073

2074-
def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b),
2075-
(CVT_e4m3x2_f32 $a, $b, CvtRN)>;
2076-
def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b),
2077-
(CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>;
2078-
def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b),
2079-
(CVT_e5m2x2_f32 $a, $b, CvtRN)>;
2080-
def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b),
2081-
(CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>;
2082-
2083-
def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a),
2084-
(CVT_e4m3x2_f16x2 $a, CvtRN)>;
2085-
def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a),
2086-
(CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>;
2087-
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a),
2088-
(CVT_e5m2x2_f16x2 $a, CvtRN)>;
2089-
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a),
2090-
(CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>;
2091-
2092-
def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a),
2093-
(CVT_f16x2_e4m3x2 $a, CvtRN)>;
2094-
def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a),
2095-
(CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>;
2096-
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a),
2097-
(CVT_f16x2_e5m2x2 $a, CvtRN)>;
2098-
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a),
2099-
(CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
2100-
2101-
let Predicates = [hasPTX<86>, hasSM<100>, hasArchAccelFeatures] in {
2074+
let Predicates = [callSubtarget<"hasFP8ConversionSupport">] in {
2075+
def : Pat<(int_nvvm_ff_to_e4m3x2_rn f32:$a, f32:$b),
2076+
(CVT_e4m3x2_f32 $a, $b, CvtRN)>;
2077+
def : Pat<(int_nvvm_ff_to_e4m3x2_rn_relu f32:$a, f32:$b),
2078+
(CVT_e4m3x2_f32 $a, $b, CvtRN_RELU)>;
2079+
def : Pat<(int_nvvm_ff_to_e5m2x2_rn f32:$a, f32:$b),
2080+
(CVT_e5m2x2_f32 $a, $b, CvtRN)>;
2081+
def : Pat<(int_nvvm_ff_to_e5m2x2_rn_relu f32:$a, f32:$b),
2082+
(CVT_e5m2x2_f32 $a, $b, CvtRN_RELU)>;
2083+
2084+
def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn v2f16:$a),
2085+
(CVT_e4m3x2_f16x2 $a, CvtRN)>;
2086+
def : Pat<(int_nvvm_f16x2_to_e4m3x2_rn_relu v2f16:$a),
2087+
(CVT_e4m3x2_f16x2 $a, CvtRN_RELU)>;
2088+
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn v2f16:$a),
2089+
(CVT_e5m2x2_f16x2 $a, CvtRN)>;
2090+
def : Pat<(int_nvvm_f16x2_to_e5m2x2_rn_relu v2f16:$a),
2091+
(CVT_e5m2x2_f16x2 $a, CvtRN_RELU)>;
2092+
2093+
def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn i16:$a),
2094+
(CVT_f16x2_e4m3x2 $a, CvtRN)>;
2095+
def : Pat<(int_nvvm_e4m3x2_to_f16x2_rn_relu i16:$a),
2096+
(CVT_f16x2_e4m3x2 $a, CvtRN_RELU)>;
2097+
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn i16:$a),
2098+
(CVT_f16x2_e5m2x2 $a, CvtRN)>;
2099+
def : Pat<(int_nvvm_e5m2x2_to_f16x2_rn_relu i16:$a),
2100+
(CVT_f16x2_e5m2x2 $a, CvtRN_RELU)>;
2101+
}
2102+
2103+
let Predicates = [callSubtarget<"hasNarrowFPConversionSupport">] in {
2102 2104
def : Pat<(int_nvvm_ff_to_e2m3x2_rn_satfinite f32:$a, f32:$b),
2103 2105
(CVT_e2m3x2_f32_sf $a, $b, CvtRN)>;
2104 2106
def : Pat<(int_nvvm_ff_to_e2m3x2_rn_relu_satfinite f32:$a, f32:$b),

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,27 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
177 177
hasPTXWithAccelSMs(86, {100, 101});
178 178
}
179 179

180+
// Checks support for conversions involving e4m3x2 and e5m2x2.
181+
bool hasFP8ConversionSupport() const {
182+
if (PTXVersion >= 81)
183+
return SmVersion >= 89;
184+
185+
if (PTXVersion >= 78)
186+
return SmVersion >= 90;
187+
188+
return false;
189+
}
190+
191+
// Checks support for conversions involving the following types:
192+
// - e2m3x2/e3m2x2
193+
// - e2m1x2
194+
// - ue8m0x2
195+
bool hasNarrowFPConversionSupport() const {
196+
return hasPTXWithFamilySMs(90, {100, 110, 120}) ||
197+
hasPTXWithFamilySMs(88, {100, 101, 120}) ||
198+
hasPTXWithAccelSMs(86, {100, 101, 120});
199+
}
200+
180 201
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
181 202
// terminates a basic block. Instead, it would assume that control flow
182 203
// continued to the next instruction. The next instruction could be in the

0 commit comments

Comments
 (0)