Skip to content

Commit 225dd9a

Browse files
committed
support fadd, fsub, fmul, fma and load on v2f32
1 parent b2b8ee7 commit 225dd9a

File tree

2 files changed

+29
-6
lines changed

2 files changed

+29
-6
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,10 +1113,14 @@ bool NVPTXDAGToDAGISel::tryLoad(SDNode *N) {
11131113
// Vector Setting
11141114
unsigned VecType = NVPTX::PTXLdStInstCode::Scalar;
11151115
if (SimpleVT.isVector()) {
1116-
assert((Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8) &&
1117-
"Unexpected vector type");
1118-
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1119-
FromTypeWidth = 32;
1116+
if (Isv2x16VT(LoadedVT) || LoadedVT == MVT::v4i8)
1117+
// v2f16/v2bf16/v2i16 is loaded using ld.b32
1118+
FromTypeWidth = 32;
1119+
else if (LoadedVT == MVT::v2f32)
1120+
// v2f32 is loaded using ld.b64
1121+
FromTypeWidth = 64;
1122+
else
1123+
llvm_unreachable("Unexpected vector type");
11201124
}
11211125

11221126
if (PlainLoad && (PlainLoad->getExtensionType() == ISD::SEXTLOAD))

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -402,7 +402,18 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
402402
(ins Float32Regs:$a, f32imm:$b),
403403
op_str # ".f32 \t$dst, $a, $b;",
404404
[(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
405-
405+
def f32x2rr_ftz :
406+
NVPTXInst<(outs Int64Regs:$dst),
407+
(ins Int64Regs:$a, Int64Regs:$b),
408+
op_str # ".ftz.f32x2 \t$dst, $a, $b;",
409+
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
410+
Requires<[doF32FTZ, hasF32x2Instructions]>;
411+
def f32x2rr :
412+
NVPTXInst<(outs Int64Regs:$dst),
413+
(ins Int64Regs:$a, Int64Regs:$b),
414+
op_str # ".f32x2 \t$dst, $a, $b;",
415+
[(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
416+
Requires<[hasF32x2Instructions]>;
406417
def f16rr_ftz :
407418
NVPTXInst<(outs Int16Regs:$dst),
408419
(ins Int16Regs:$a, Int16Regs:$b),
@@ -434,7 +445,6 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
434445
op_str # ".bf16 \t$dst, $a, $b;",
435446
[(set bf16:$dst, (op_pat bf16:$a, bf16:$b))]>,
436447
Requires<[hasBF16Math]>;
437-
438448
def bf16x2rr :
439449
NVPTXInst<(outs Int32Regs:$dst),
440450
(ins Int32Regs:$a, Int32Regs:$b),
@@ -1348,6 +1358,13 @@ multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred
13481358
Requires<[hasBF16Math, Pred]>;
13491359
}
13501360

1361+
class FMA_F32x2<string OpcStr, Predicate Pred>
1362+
: NVPTXInst<(outs Int64Regs:$res),
1363+
(ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
1364+
OpcStr # ".f32x2 \t$res, $a, $b, $c;",
1365+
[(set v2f32:$res, (fma v2f32:$a, v2f32:$b, v2f32:$c))]>,
1366+
Requires<[hasF32x2Instructions, Pred]>;
1367+
13511368
defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
13521369
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
13531370
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
@@ -1356,6 +1373,8 @@ defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
13561373
defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
13571374
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
13581375
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
1376+
def FMA32x2_ftz : FMA_F32x2<"fma.rn.ftz", doF32FTZ>;
1377+
def FMA32x2 : FMA_F32x2<"fma.rn", True>;
13591378
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
13601379

13611380
// sin/cos

0 commit comments

Comments
 (0)