Skip to content

Commit b2b8ee7

Browse files
committed
legalize v2f32 as i64 reg and add test cases
1 parent 0e3049c commit b2b8ee7

File tree

6 files changed

+416
-5
lines changed

6 files changed

+416
-5
lines changed

llvm/lib/Target/NVPTX/NVPTXISelDAGToDAG.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1041,6 +1041,7 @@ pickOpcodeForVT(MVT::SimpleValueType VT, unsigned Opcode_i8,
10411041
case MVT::i32:
10421042
return Opcode_i32;
10431043
case MVT::i64:
1044+
case MVT::v2f32:
10441045
return Opcode_i64;
10451046
case MVT::f16:
10461047
case MVT::bf16:

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,8 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
292292
// TargetLoweringBase::getVectorTypeBreakdown() which is invoked in
293293
// ComputePTXValueVTs() cannot currently break down non-power-of-2 sized
294294
// vectors.
295-
if ((Is16bitsType(EltVT.getSimpleVT())) && NumElts % 2 == 0 &&
296-
isPowerOf2_32(NumElts)) {
295+
if ((Is16bitsType(EltVT.getSimpleVT()) || EltVT == MVT::f32) &&
296+
NumElts % 2 == 0 && isPowerOf2_32(NumElts)) {
297297
// Vectors with an even number of f16 elements will be passed to
298298
// us as an array of v2f16/v2bf16 elements. We must match this so we
299299
// stay in sync with Ins/Outs.
@@ -307,6 +307,9 @@ static void ComputePTXValueVTs(const TargetLowering &TLI, const DataLayout &DL,
307307
case MVT::i16:
308308
EltVT = MVT::v2i16;
309309
break;
310+
case MVT::f32:
311+
EltVT = MVT::v2f32;
312+
break;
310313
default:
311314
llvm_unreachable("Unexpected type");
312315
}
@@ -580,6 +583,7 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
580583
addRegisterClass(MVT::v2f16, &NVPTX::Int32RegsRegClass);
581584
addRegisterClass(MVT::bf16, &NVPTX::Int16RegsRegClass);
582585
addRegisterClass(MVT::v2bf16, &NVPTX::Int32RegsRegClass);
586+
addRegisterClass(MVT::v2f32, &NVPTX::Int64RegsRegClass);
583587

584588
// Conversion to/from FP16/FP16x2 is always legal.
585589
setOperationAction(ISD::BUILD_VECTOR, MVT::v2f16, Custom);
@@ -845,6 +849,8 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
845849
setBF16OperationAction(Op, MVT::bf16, Legal, Promote);
846850
if (getOperationAction(Op, MVT::bf16) == Promote)
847851
AddPromotedToType(Op, MVT::bf16, MVT::f32);
852+
if (STI.hasF32x2Instructions())
853+
setOperationAction(Op, MVT::v2f32, Legal);
848854
}
849855

850856
// On SM80, we select add/mul/sub as fma to avoid promotion to float
@@ -3414,6 +3420,8 @@ SDValue NVPTXTargetLowering::LowerFormalArguments(
34143420
// vectors which contain v2f16 or v2bf16 elements. So we must load
34153421
// using i32 here and then bitcast back.
34163422
LoadVT = MVT::i32;
3423+
else if (EltVT == MVT::v2f32)
3424+
LoadVT = MVT::i64;
34173425

34183426
EVT VecVT = EVT::getVectorVT(F->getContext(), LoadVT, NumElts);
34193427
SDValue VecAddr =

llvm/lib/Target/NVPTX/NVPTXInstrInfo.td

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def hasHWROT32 : Predicate<"Subtarget->hasHWROT32()">;
160160
def noHWROT32 : Predicate<"!Subtarget->hasHWROT32()">;
161161
def hasDotInstructions : Predicate<"Subtarget->hasDotInstructions()">;
162162
def hasTcgen05Instructions : Predicate<"Subtarget->hasTcgen05Instructions()">;
163+
def hasF32x2Instructions : Predicate<"Subtarget->hasF32x2Instructions()">;
163164

164165
def True : Predicate<"true">;
165166
def False : Predicate<"false">;
@@ -2452,13 +2453,13 @@ class LastCallArgInstVT<NVPTXRegClass regclass, ValueType vt> :
24522453
NVPTXInst<(outs), (ins regclass:$a), "$a",
24532454
[(LastCallArg (i32 0), vt:$a)]>;
24542455

2455-
def CallArgI64 : CallArgInst<Int64Regs>;
2456+
def CallArgI64 : CallArgInstVT<Int64Regs, i64>;
24562457
def CallArgI32 : CallArgInstVT<Int32Regs, i32>;
24572458
def CallArgI16 : CallArgInstVT<Int16Regs, i16>;
24582459
def CallArgF64 : CallArgInst<Float64Regs>;
24592460
def CallArgF32 : CallArgInst<Float32Regs>;
24602461

2461-
def LastCallArgI64 : LastCallArgInst<Int64Regs>;
2462+
def LastCallArgI64 : LastCallArgInstVT<Int64Regs, i64>;
24622463
def LastCallArgI32 : LastCallArgInstVT<Int32Regs, i32>;
24632464
def LastCallArgI16 : LastCallArgInstVT<Int16Regs, i16>;
24642465
def LastCallArgF64 : LastCallArgInst<Float64Regs>;
@@ -2975,6 +2976,9 @@ let hasSideEffects = false in {
29752976
def V2F32toF64 : NVPTXInst<(outs Float64Regs:$d),
29762977
(ins Float32Regs:$s1, Float32Regs:$s2),
29772978
"mov.b64 \t$d, {{$s1, $s2}};", []>;
2979+
def V2F32toI64 : NVPTXInst<(outs Int64Regs:$d),
2980+
(ins Float32Regs:$s1, Float32Regs:$s2),
2981+
"mov.b64 \t$d, {{$s1, $s2}};", []>;
29782982

29792983
// unpack a larger int register to a set of smaller int registers
29802984
def I64toV4I16 : NVPTXInst<(outs Int16Regs:$d1, Int16Regs:$d2,
@@ -3039,6 +3043,8 @@ def : Pat<(v2bf16 (build_vector bf16:$a, bf16:$b)),
30393043
(V2I16toI32 $a, $b)>;
30403044
def : Pat<(v2i16 (build_vector i16:$a, i16:$b)),
30413045
(V2I16toI32 $a, $b)>;
3046+
def : Pat<(v2f32 (build_vector f32:$a, f32:$b)),
3047+
(V2F32toI64 $a, $b)>;
30423048

30433049
def: Pat<(v2i16 (scalar_to_vector i16:$a)),
30443050
(CVT_u32_u16 $a, CvtNONE)>;

llvm/lib/Target/NVPTX/NVPTXRegisterInfo.td

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,9 @@ def Int16Regs : NVPTXRegClass<[i16, f16, bf16], 16, (add (sequence "RS%u", 0, 4)
6262
def Int32Regs : NVPTXRegClass<[i32, v2f16, v2bf16, v2i16, v4i8], 32,
6363
(add (sequence "R%u", 0, 4),
6464
VRFrame32, VRFrameLocal32)>;
65-
def Int64Regs : NVPTXRegClass<[i64], 64, (add (sequence "RL%u", 0, 4), VRFrame64, VRFrameLocal64)>;
65+
def Int64Regs : NVPTXRegClass<[i64, v2f32], 64,
66+
(add (sequence "RL%u", 0, 4),
67+
VRFrame64, VRFrameLocal64)>;
6668
// 128-bit regs are not defined as general regs in NVPTX. They are used for inlineASM only.
6769
def Int128Regs : NVPTXRegClass<[i128], 128, (add (sequence "RQ%u", 0, 4))>;
6870
def Float32Regs : NVPTXRegClass<[f32], 32, (add (sequence "F%u", 0, 4))>;

llvm/lib/Target/NVPTX/NVPTXSubtarget.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ class NVPTXSubtarget : public NVPTXGenSubtargetInfo {
112112
return HasTcgen05 && PTXVersion >= 86;
113113
}
114114

115+
bool hasF32x2Instructions() const {
116+
return SmVersion >= 100 && PTXVersion >= 86;
117+
}
118+
115119
// Prior to CUDA 12.3 ptxas did not recognize that the trap instruction
116120
// terminates a basic block. Instead, it would assume that control flow
117121
// continued to the next instruction. The next instruction could be in the

0 commit comments

Comments
 (0)