@@ -158,6 +158,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
158158def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
159159def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
160160def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
161+ def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
161162
162163def True : Predicate<"true">;
163164def False : Predicate<"false">;
@@ -193,6 +194,7 @@ class ValueToRegClass<ValueType T> {
193194 !eq(name, "bf16"): Int16Regs,
194195 !eq(name, "v2bf16"): Int32Regs,
195196 !eq(name, "f32"): Float32Regs,
197+ !eq(name, "v2f32"): Int64Regs,
196198 !eq(name, "f64"): Float64Regs,
197199 !eq(name, "ai32"): Int32ArgRegs,
198200 !eq(name, "ai64"): Int64ArgRegs,
@@ -239,6 +241,7 @@ def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
239241
240242def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
241243def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
244+ def F32X2RT : RegTyInfo<v2f32, Int64Regs, ?, ?, supports_imm = 0>;
242245
243246
244247// This class provides a basic wrapper around an NVPTXInst that abstracts the
@@ -461,6 +464,18 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
461464 [(set f16:$dst, (op_pat f16:$a, f16:$b))]>,
462465 Requires<[useFP16Math]>;
463466
467+ def f32x2rr_ftz :
468+ BasicNVPTXInst<(outs Int64Regs:$dst),
469+ (ins Int64Regs:$a, Int64Regs:$b),
470+ op_str # ".ftz.f32x2",
471+ [(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
472+ Requires<[hasF32x2Instructions, doF32FTZ]>;
473+ def f32x2rr :
474+ BasicNVPTXInst<(outs Int64Regs:$dst),
475+ (ins Int64Regs:$a, Int64Regs:$b),
476+ op_str # ".f32x2",
477+ [(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
478+ Requires<[hasF32x2Instructions]>;
464479 def f16x2rr_ftz :
465480 BasicNVPTXInst<(outs Int32Regs:$dst),
466481 (ins Int32Regs:$a, Int32Regs:$b),
@@ -839,6 +854,9 @@ def : Pat<(vt (select i1:$p, vt:$a, vt:$b)),
839854 (SELP_b32rr $a, $b, $p)>;
840855}
841856
857+ def : Pat<(v2f32 (select i1:$p, v2f32:$a, v2f32:$b)),
858+ (SELP_b64rr $a, $b, $p)>;
859+
842860//-----------------------------------
843861// Test Instructions
844862//-----------------------------------
@@ -1387,6 +1405,8 @@ defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
13871405defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
13881406defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
13891407defm FMA32 : FMA<"fma.rn.f32", F32RT>;
1408+ defm FMA32x2_ftz : FMA<"fma.rn.ftz.f32x2", F32X2RT, [hasF32x2Instructions, doF32FTZ]>;
1409+ defm FMA32x2 : FMA<"fma.rn.f32x2", F32X2RT, [hasF32x2Instructions]>;
13901410defm FMA64 : FMA<"fma.rn.f64", F64RT>;
13911411
13921412// sin/cos
@@ -2739,6 +2759,7 @@ def : Pat<(i32 (trunc (sra i64:$s, (i32 32)))), (I64toI32H $s)>;
27392759def: Pat<(i32 (sext (extractelt v2i16:$src, 0))),
27402760 (CVT_INREG_s32_s16 $src)>;
27412761
2762+ // Handle extracting one element from the pair (32-bit types)
27422763foreach vt = [v2f16, v2bf16, v2i16] in {
27432764 def : Pat<(extractelt vt:$src, 0), (I32toI16L_Sink $src)>, Requires<[hasPTX<71>]>;
27442765 def : Pat<(extractelt vt:$src, 1), (I32toI16H_Sink $src)>, Requires<[hasPTX<71>]>;
@@ -2750,10 +2771,21 @@ foreach vt = [v2f16, v2bf16, v2i16] in {
27502771 (V2I16toI32 $a, $b)>;
27512772}
27522773
2774+ // Same thing for the 64-bit type v2f32.
2775+ foreach vt = [v2f32] in {
2776+ def : Pat<(extractelt vt:$src, 0), (I64toI32L_Sink $src)>, Requires<[hasPTX<71>]>;
2777+ def : Pat<(extractelt vt:$src, 1), (I64toI32H_Sink $src)>, Requires<[hasPTX<71>]>;
2778+
2779+ def : Pat<(extractelt vt:$src, 0), (I64toI32L $src)>;
2780+ def : Pat<(extractelt vt:$src, 1), (I64toI32H $src)>;
2781+
2782+ def : Pat<(vt (build_vector vt.ElementType:$a, vt.ElementType:$b)),
2783+ (V2I32toI64 $a, $b)>;
2784+ }
2785+
27532786def: Pat<(v2i16 (scalar_to_vector i16:$a)),
27542787 (CVT_u32_u16 $a, CvtNONE)>;
27552788
2756-
27572789def nvptx_build_vector : SDNode<"NVPTXISD::BUILD_VECTOR", SDTypeProfile<1, 2, []>, []>;
27582790
27592791def : Pat<(i64 (nvptx_build_vector i32:$a, i32:$b)),
0 commit comments