@@ -402,7 +402,18 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
402
402
(ins Float32Regs:$a, f32imm:$b),
403
403
op_str # ".f32 \t$dst, $a, $b;",
404
404
[(set f32:$dst, (op_pat f32:$a, fpimm:$b))]>;
405
-
405
+ def f32x2rr_ftz :
406
+ NVPTXInst<(outs Int64Regs:$dst),
407
+ (ins Int64Regs:$a, Int64Regs:$b),
408
+ op_str # ".ftz.f32x2 \t$dst, $a, $b;",
409
+ [(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
410
+ Requires<[doF32FTZ, hasF32x2Instructions]>;
411
+ def f32x2rr :
412
+ NVPTXInst<(outs Int64Regs:$dst),
413
+ (ins Int64Regs:$a, Int64Regs:$b),
414
+ op_str # ".f32x2 \t$dst, $a, $b;",
415
+ [(set v2f32:$dst, (op_pat v2f32:$a, v2f32:$b))]>,
416
+ Requires<[hasF32x2Instructions]>;
406
417
def f16rr_ftz :
407
418
NVPTXInst<(outs Int16Regs:$dst),
408
419
(ins Int16Regs:$a, Int16Regs:$b),
@@ -434,7 +445,6 @@ multiclass F3<string op_str, SDPatternOperator op_pat> {
434
445
op_str # ".bf16 \t$dst, $a, $b;",
435
446
[(set bf16:$dst, (op_pat bf16:$a, bf16:$b))]>,
436
447
Requires<[hasBF16Math]>;
437
-
438
448
def bf16x2rr :
439
449
NVPTXInst<(outs Int32Regs:$dst),
440
450
(ins Int32Regs:$a, Int32Regs:$b),
@@ -1348,6 +1358,13 @@ multiclass FMA_BF16<string OpcStr, ValueType T, RegisterClass RC, Predicate Pred
1348
1358
Requires<[hasBF16Math, Pred]>;
1349
1359
}
1350
1360
1361
+ class FMA_F32x2<string OpcStr, Predicate Pred>
1362
+ : NVPTXInst<(outs Int64Regs:$res),
1363
+ (ins Int64Regs:$a, Int64Regs:$b, Int64Regs:$c),
1364
+ OpcStr # ".f32x2 \t$res, $a, $b, $c;",
1365
+ [(set v2f32:$res, (fma v2f32:$a, v2f32:$b, v2f32:$c))]>,
1366
+ Requires<[hasF32x2Instructions, Pred]>;
1367
+
1351
1368
defm FMA16_ftz : FMA_F16<"fma.rn.ftz.f16", f16, Int16Regs, doF32FTZ>;
1352
1369
defm FMA16 : FMA_F16<"fma.rn.f16", f16, Int16Regs, True>;
1353
1370
defm FMA16x2_ftz : FMA_F16<"fma.rn.ftz.f16x2", v2f16, Int32Regs, doF32FTZ>;
@@ -1356,6 +1373,8 @@ defm BFMA16 : FMA_BF16<"fma.rn.bf16", bf16, Int16Regs, True>;
1356
1373
defm BFMA16x2 : FMA_BF16<"fma.rn.bf16x2", v2bf16, Int32Regs, True>;
1357
1374
defm FMA32_ftz : FMA<"fma.rn.ftz.f32", Float32Regs, f32imm, doF32FTZ>;
1358
1375
defm FMA32 : FMA<"fma.rn.f32", Float32Regs, f32imm, True>;
1376
+ def FMA32x2_ftz : FMA_F32x2<"fma.rn.ftz", doF32FTZ>;
1377
+ def FMA32x2 : FMA_F32x2<"fma.rn", True>;
1359
1378
defm FMA64 : FMA<"fma.rn.f64", Float64Regs, f64imm, True>;
1360
1379
1361
1380
// sin/cos
0 commit comments