Skip to content

Commit 4c1a649

Browse files
committed
support fadd, fsub, fmul, fma and load on v2f32
1 parent 4e0c273 commit 4c1a649

File tree

2 files changed

+23
-6
lines changed

2 files changed

+23
-6
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1097,10 +1097,14 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
10971097
// Vector Setting
10981098
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
10991099
if (SimpleVT.isVector()) {
1100-
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
1101-
"Unexpected vector type");
1102-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1103-
FromTypeWidth = 32;
1100+
if (Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8)
1101+
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1102+
FromTypeWidth = 32;
1103+
else if (LoadedVT == MVT::v2f32)
1104+
// v2f32 is loaded using ld.b64
1105+
FromTypeWidth = 64;
1106+
else
1107+
llvm_unreachable("Unexpected vector type");
11041108
}
11051109

11061110
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,7 @@ def F64RT : RegTyInfo<f64, Float64Regs, f64imm, fpimm>;
237237
def F16RT : RegTyInfo<f16, Int16Regs, f16imm, fpimm, supports_imm = 0>;
238238
def BF16RT : RegTyInfo<bf16, Int16Regs, bf16imm, fpimm, supports_imm = 0>;
239239

240+
def F32X2RT : RegTyInfo<v2f32, Int64Regs, ?, ?, supports_imm = 0>;
240241
def F16X2RT : RegTyInfo<v2f16, Int32Regs, ?, ?, supports_imm = 0>;
241242
def BF16X2RT : RegTyInfo<v2bf16, Int32Regs, ?, ?, supports_imm = 0>;
242243

@@ -415,7 +416,18 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
415416
(ins Float32Regs:$a, f32imm:$b),
416417
op_str # ".f32 \t$dst, $a, $b;",
417418
[(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
418-
419+
def f32x2rr_ftz :
420+
NVPTXInst<(outs Int64Regs:$dst),
421+
(ins Int64Regs:$a, Int64Regs:$b),
422+
op_str # ".ftz.f32x2 \t$dst, $a, $b;",
423+
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
424+
Requires<[doF32FTZ, hasF32x2Instructions]>;
425+
def f32x2rr :
426+
NVPTXInst<(outs Int64Regs:$dst),
427+
(ins Int64Regs:$a, Int64Regs:$b),
428+
op_str # ".f32x2 \t$dst, $a, $b;",
429+
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
430+
Requires<[hasF32x2Instructions]>;
419431
def f16rr_ftz :
420432
NVPTXInst<(outs Int16Regs:$dst),
421433
(ins Int16Regs:$a, Int16Regs:$b),
@@ -447,7 +459,6 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
447459
op_str # ".bf16 \t$dst, $a, $b;",
448460
[(set bf16:$dst, (op_pat bf16:$a, bf16:$b))]>,
449461
Requires<[hasBF16Math]>;
450-
451462
def bf16x2rr :
452463
NVPTXInst<(outs Int32Regs:$dst),
453464
(ins Int32Regs:$a, Int32Regs:$b),
@@ -1370,6 +1381,8 @@ defm BFMA16 : FMA<"fma.rn.bf16", BF16RT, [hasBF16Math]>;
13701381
defm BFMA16x2 : FMA<"fma.rn.bf16x2", BF16X2RT, [hasBF16Math]>;
13711382
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", F32RT, [doF32FTZ]>;
13721383
defm FMA32 : FMA<"fma.rn.f32", F32RT>;
1384+
defm FMA32x2_ftz : FMA<"fma.rn.ftz.f32x2", F32X2RT, [doF32FTZ]>;
1385+
defm FMA32x2 : FMA<"fma.rn.f32x2", F32X2RT>;
13731386
defm FMA64 : FMA<"fma.rn.f64", F64RT>;
13741387

13751388
// sin/cos

0 commit comments

Comments
 (0)