Skip to content

Commit 8ab917a

Browse files
authored
Reland "[NVPTX] Legalize aext-load to zext-load to expose more DAG combines" (#155063)
The original version of this change inadvertently dropped b6e19b3. This version retains that fix, adds tests for it, and includes an explanation of why it is needed.
1 parent 2a586a8 commit 8ab917a

File tree

13 files changed

+877
-868
lines changed

13 files changed

+877
-868
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15137,7 +15137,7 @@ SDValue DAGCombiner::visitANY_EXTEND(SDNode *N) {
1513715137
return foldedExt;
1513815138
} else if (ISD::isNON_EXTLoad(N0.getNode()) &&
1513915139
ISD::isUNINDEXEDLoad(N0.getNode()) &&
15140-
TLI.isLoadExtLegal(ISD::EXTLOAD, VT, N0.getValueType())) {
15140+
TLI.isLoadExtLegalOrCustom(ISD::EXTLOAD, VT, N0.getValueType())) {
1514115141
bool DoXform = true;
1514215142
SmallVector<SDNode *, 4> SetCCs;
1514315143
if (!N0.hasOneUse())

llvm/lib/Target/NVPTX/NVPTXISelLowering.cpp

Lines changed: 74 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -702,57 +702,66 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
702702
// intrinsics.
703703
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
704704

705-
// Turn FP extload into load/fpextend
706-
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::f16, Expand);
707-
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f16, Expand);
708-
setLoadExtAction(ISD::EXTLOAD, MVT::f32, MVT::bf16, Expand);
709-
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::bf16, Expand);
710-
setLoadExtAction(ISD::EXTLOAD, MVT::f64, MVT::f32, Expand);
711-
setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2f16, Expand);
712-
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f16, Expand);
713-
setLoadExtAction(ISD::EXTLOAD, MVT::v2f32, MVT::v2bf16, Expand);
714-
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2bf16, Expand);
715-
setLoadExtAction(ISD::EXTLOAD, MVT::v2f64, MVT::v2f32, Expand);
716-
setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4f16, Expand);
717-
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f16, Expand);
718-
setLoadExtAction(ISD::EXTLOAD, MVT::v4f32, MVT::v4bf16, Expand);
719-
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4bf16, Expand);
720-
setLoadExtAction(ISD::EXTLOAD, MVT::v4f64, MVT::v4f32, Expand);
721-
setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8f16, Expand);
722-
setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8f16, Expand);
723-
setLoadExtAction(ISD::EXTLOAD, MVT::v8f32, MVT::v8bf16, Expand);
724-
setLoadExtAction(ISD::EXTLOAD, MVT::v8f64, MVT::v8bf16, Expand);
725-
// Turn FP truncstore into trunc + store.
726-
// FIXME: vector types should also be expanded
727-
setTruncStoreAction(MVT::f32, MVT::f16, Expand);
728-
setTruncStoreAction(MVT::f64, MVT::f16, Expand);
729-
setTruncStoreAction(MVT::f32, MVT::bf16, Expand);
730-
setTruncStoreAction(MVT::f64, MVT::bf16, Expand);
731-
setTruncStoreAction(MVT::f64, MVT::f32, Expand);
732-
setTruncStoreAction(MVT::v2f32, MVT::v2f16, Expand);
733-
setTruncStoreAction(MVT::v2f32, MVT::v2bf16, Expand);
705+
// FP extload/truncstore is not legal in PTX. We need to expand all these.
706+
for (auto FloatVTs :
707+
{MVT::fp_valuetypes(), MVT::fp_fixedlen_vector_valuetypes()}) {
708+
for (MVT ValVT : FloatVTs) {
709+
for (MVT MemVT : FloatVTs) {
710+
setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Expand);
711+
setTruncStoreAction(ValVT, MemVT, Expand);
712+
}
713+
}
714+
}
734715

735-
// PTX does not support load / store predicate registers
736-
setOperationAction(ISD::LOAD, MVT::i1, Custom);
737-
setOperationAction(ISD::STORE, MVT::i1, Custom);
716+
// To improve CodeGen we'll legalize any-extend loads to zext loads. This is
717+
// how they'll be lowered in ISel anyway, and by doing this a little earlier
718+
// we allow for more DAG combine opportunities.
719+
for (auto IntVTs :
720+
{MVT::integer_valuetypes(), MVT::integer_fixedlen_vector_valuetypes()})
721+
for (MVT ValVT : IntVTs)
722+
for (MVT MemVT : IntVTs)
723+
if (isTypeLegal(ValVT))
724+
setLoadExtAction(ISD::EXTLOAD, ValVT, MemVT, Custom);
738725

726+
// PTX does not support load / store predicate registers
727+
setOperationAction({ISD::LOAD, ISD::STORE}, MVT::i1, Custom);
739728
for (MVT VT : MVT::integer_valuetypes()) {
740-
setLoadExtAction(ISD::SEXTLOAD, VT, MVT::i1, Promote);
741-
setLoadExtAction(ISD::ZEXTLOAD, VT, MVT::i1, Promote);
742-
setLoadExtAction(ISD::EXTLOAD, VT, MVT::i1, Promote);
729+
setLoadExtAction({ISD::SEXTLOAD, ISD::ZEXTLOAD, ISD::EXTLOAD}, VT, MVT::i1,
730+
Promote);
743731
setTruncStoreAction(VT, MVT::i1, Expand);
744732
}
745733

734+
// Disable generations of extload/truncstore for v2i16/v2i8. The generic
735+
// expansion for these nodes when they are unaligned is incorrect if the
736+
// type is a vector.
737+
//
738+
// TODO: Fix the generic expansion for these nodes found in
739+
// TargetLowering::expandUnalignedLoad/Store.
740+
setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16,
741+
MVT::v2i8, Expand);
742+
setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
743+
744+
// Register custom handling for illegal type loads/stores. We'll try to custom
745+
// lower almost all illegal types and logic in the lowering will discard cases
746+
// we can't handle.
747+
setOperationAction({ISD::LOAD, ISD::STORE}, {MVT::i128, MVT::f128}, Custom);
748+
for (MVT VT : MVT::fixedlen_vector_valuetypes())
749+
if (!isTypeLegal(VT) && VT.getStoreSizeInBits() <= 256)
750+
setOperationAction({ISD::STORE, ISD::LOAD}, VT, Custom);
751+
752+
// Custom legalization for LDU intrinsics.
753+
// TODO: The logic to lower these is not very robust and we should rewrite it.
754+
// Perhaps LDU should not be represented as an intrinsic at all.
755+
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
756+
for (MVT VT : MVT::fixedlen_vector_valuetypes())
757+
if (IsPTXVectorType(VT))
758+
setOperationAction(ISD::INTRINSIC_W_CHAIN, VT, Custom);
759+
746760
setCondCodeAction({ISD::SETNE, ISD::SETEQ, ISD::SETUGE, ISD::SETULE,
747761
ISD::SETUGT, ISD::SETULT, ISD::SETGT, ISD::SETLT,
748762
ISD::SETGE, ISD::SETLE},
749763
MVT::i1, Expand);
750764

751-
// expand extload of vector of integers.
752-
setLoadExtAction({ISD::EXTLOAD, ISD::SEXTLOAD, ISD::ZEXTLOAD}, MVT::v2i16,
753-
MVT::v2i8, Expand);
754-
setTruncStoreAction(MVT::v2i16, MVT::v2i8, Expand);
755-
756765
// This is legal in NVPTX
757766
setOperationAction(ISD::ConstantFP, MVT::f64, Legal);
758767
setOperationAction(ISD::ConstantFP, MVT::f32, Legal);
@@ -767,24 +776,12 @@ NVPTXTargetLowering::NVPTXTargetLowering(const NVPTXTargetMachine &TM,
767776
// DEBUGTRAP can be lowered to PTX brkpt
768777
setOperationAction(ISD::DEBUGTRAP, MVT::Other, Legal);
769778

770-
// Register custom handling for vector loads/stores
771-
for (MVT VT : MVT::fixedlen_vector_valuetypes())
772-
if (IsPTXVectorType(VT))
773-
setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN}, VT,
774-
Custom);
775-
776-
setOperationAction({ISD::LOAD, ISD::STORE, ISD::INTRINSIC_W_CHAIN},
777-
{MVT::i128, MVT::f128}, Custom);
778-
779779
// Support varargs.
780780
setOperationAction(ISD::VASTART, MVT::Other, Custom);
781781
setOperationAction(ISD::VAARG, MVT::Other, Custom);
782782
setOperationAction(ISD::VACOPY, MVT::Other, Expand);
783783
setOperationAction(ISD::VAEND, MVT::Other, Expand);
784784

785-
// Custom handling for i8 intrinsics
786-
setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::i8, Custom);
787-
788785
setOperationAction({ISD::ABS, ISD::SMIN, ISD::SMAX, ISD::UMIN, ISD::UMAX},
789786
{MVT::i16, MVT::i32, MVT::i64}, Legal);
790787

@@ -3092,39 +3089,14 @@ static void replaceLoadVector(SDNode *N, SelectionDAG &DAG,
30923089
SmallVectorImpl<SDValue> &Results,
30933090
const NVPTXSubtarget &STI);
30943091

3095-
SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3096-
if (Op.getValueType() == MVT::i1)
3097-
return LowerLOADi1(Op, DAG);
3098-
3099-
EVT VT = Op.getValueType();
3100-
3101-
if (NVPTX::isPackedVectorTy(VT)) {
3102-
// v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
3103-
// handle unaligned loads and have to handle it here.
3104-
LoadSDNode *Load = cast<LoadSDNode>(Op);
3105-
EVT MemVT = Load->getMemoryVT();
3106-
if (!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
3107-
MemVT, *Load->getMemOperand())) {
3108-
SDValue Ops[2];
3109-
std::tie(Ops[0], Ops[1]) = expandUnalignedLoad(Load, DAG);
3110-
return DAG.getMergeValues(Ops, SDLoc(Op));
3111-
}
3112-
}
3113-
3114-
return SDValue();
3115-
}
3116-
31173092
// v = ld i1* addr
31183093
// =>
31193094
// v1 = ld i8* addr (-> i16)
31203095
// v = trunc i16 to i1
3121-
SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
3122-
SDNode *Node = Op.getNode();
3123-
LoadSDNode *LD = cast<LoadSDNode>(Node);
3124-
SDLoc dl(Node);
3096+
static SDValue lowerLOADi1(LoadSDNode *LD, SelectionDAG &DAG) {
3097+
SDLoc dl(LD);
31253098
assert(LD->getExtensionType() == ISD::NON_EXTLOAD);
3126-
assert(Node->getValueType(0) == MVT::i1 &&
3127-
"Custom lowering for i1 load only");
3099+
assert(LD->getValueType(0) == MVT::i1 && "Custom lowering for i1 load only");
31283100
SDValue newLD = DAG.getExtLoad(ISD::ZEXTLOAD, dl, MVT::i16, LD->getChain(),
31293101
LD->getBasePtr(), LD->getPointerInfo(),
31303102
MVT::i8, LD->getAlign(),
@@ -3133,8 +3105,27 @@ SDValue NVPTXTargetLowering::LowerLOADi1(SDValue Op, SelectionDAG &DAG) const {
31333105
// The legalizer (the caller) is expecting two values from the legalized
31343106
// load, so we build a MergeValues node for it. See ExpandUnalignedLoad()
31353107
// in LegalizeDAG.cpp which also uses MergeValues.
3136-
SDValue Ops[] = { result, LD->getChain() };
3137-
return DAG.getMergeValues(Ops, dl);
3108+
return DAG.getMergeValues({result, LD->getChain()}, dl);
3109+
}
3110+
3111+
SDValue NVPTXTargetLowering::LowerLOAD(SDValue Op, SelectionDAG &DAG) const {
3112+
LoadSDNode *LD = cast<LoadSDNode>(Op);
3113+
3114+
if (Op.getValueType() == MVT::i1)
3115+
return lowerLOADi1(LD, DAG);
3116+
3117+
// To improve CodeGen we'll legalize any-extend loads to zext loads. This is
3118+
// how they'll be lowered in ISel anyway, and by doing this a little earlier
3119+
// we allow for more DAG combine opportunities.
3120+
if (LD->getExtensionType() == ISD::EXTLOAD) {
3121+
assert(LD->getValueType(0).isInteger() && LD->getMemoryVT().isInteger() &&
3122+
"Unexpected fpext-load");
3123+
return DAG.getExtLoad(ISD::ZEXTLOAD, SDLoc(Op), Op.getValueType(),
3124+
LD->getChain(), LD->getBasePtr(), LD->getMemoryVT(),
3125+
LD->getMemOperand());
3126+
}
3127+
3128+
llvm_unreachable("Unexpected custom lowering for load");
31383129
}
31393130

31403131
SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
@@ -3144,17 +3135,6 @@ SDValue NVPTXTargetLowering::LowerSTORE(SDValue Op, SelectionDAG &DAG) const {
31443135
if (VT == MVT::i1)
31453136
return LowerSTOREi1(Op, DAG);
31463137

3147-
// v2f32/v2f16/v2bf16/v2i16/v4i8 are legal, so we can't rely on legalizer to
3148-
// handle unaligned stores and have to handle it here.
3149-
if (NVPTX::isPackedVectorTy(VT) &&
3150-
!allowsMemoryAccessForAlignment(*DAG.getContext(), DAG.getDataLayout(),
3151-
VT, *Store->getMemOperand()))
3152-
return expandUnalignedStore(Store, DAG);
3153-
3154-
// v2f16/v2bf16/v2i16 don't need special handling.
3155-
if (NVPTX::isPackedVectorTy(VT) && VT.is32BitVector())
3156-
return SDValue();
3157-
31583138
// Lower store of any other vector type, including v2f32 as we want to break
31593139
// it apart since this is not a widely-supported type.
31603140
return LowerSTOREVector(Op, DAG);
@@ -4010,14 +3990,8 @@ bool NVPTXTargetLowering::getTgtMemIntrinsic(
40103990
case Intrinsic::nvvm_ldu_global_i:
40113991
case Intrinsic::nvvm_ldu_global_f:
40123992
case Intrinsic::nvvm_ldu_global_p: {
4013-
auto &DL = I.getDataLayout();
40143993
Info.opc = ISD::INTRINSIC_W_CHAIN;
4015-
if (Intrinsic == Intrinsic::nvvm_ldu_global_i)
4016-
Info.memVT = getValueType(DL, I.getType());
4017-
else if(Intrinsic == Intrinsic::nvvm_ldu_global_p)
4018-
Info.memVT = getPointerTy(DL);
4019-
else
4020-
Info.memVT = getValueType(DL, I.getType());
3994+
Info.memVT = getValueType(I.getDataLayout(), I.getType());
40213995
Info.ptrVal = I.getArgOperand(0);
40223996
Info.offset = 0;
40233997
Info.flags = MachineMemOperand::MOLoad;

llvm/lib/Target/NVPTX/NVPTXISelLowering.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -309,8 +309,6 @@ class NVPTXTargetLowering : public TargetLowering {
309309
SDValue LowerFP_EXTEND(SDValue Op, SelectionDAG &DAG) const;
310310

311311
SDValue LowerLOAD(SDValue Op, SelectionDAG &DAG) const;
312-
SDValue LowerLOADi1(SDValue Op, SelectionDAG &DAG) const;
313-
314312
SDValue LowerSTORE(SDValue Op, SelectionDAG &DAG) const;
315313
SDValue LowerSTOREi1(SDValue Op, SelectionDAG &DAG) const;
316314
SDValue LowerSTOREVector(SDValue Op, SelectionDAG &DAG) const;

llvm/test/CodeGen/Mips/implicit-sret.ll

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,7 @@ define internal void @test() unnamed_addr nounwind {
1919
; CHECK-NEXT: ld $6, 24($sp)
2020
; CHECK-NEXT: ld $5, 16($sp)
2121
; CHECK-NEXT: ld $7, 32($sp)
22-
; CHECK-NEXT: lw $1, 0($sp)
23-
; CHECK-NEXT: # implicit-def: $a0_64
24-
; CHECK-NEXT: move $4, $1
22+
; CHECK-NEXT: lw $4, 0($sp)
2523
; CHECK-NEXT: jal use_sret
2624
; CHECK-NEXT: nop
2725
; CHECK-NEXT: ld $ra, 56($sp) # 8-byte Folded Reload
@@ -64,15 +62,9 @@ define internal void @test2() unnamed_addr nounwind {
6462
; CHECK-NEXT: daddiu $4, $sp, 0
6563
; CHECK-NEXT: jal implicit_sret_decl2
6664
; CHECK-NEXT: nop
67-
; CHECK-NEXT: lw $1, 20($sp)
68-
; CHECK-NEXT: lw $2, 12($sp)
69-
; CHECK-NEXT: lw $3, 4($sp)
70-
; CHECK-NEXT: # implicit-def: $a0_64
71-
; CHECK-NEXT: move $4, $3
72-
; CHECK-NEXT: # implicit-def: $a1_64
73-
; CHECK-NEXT: move $5, $2
74-
; CHECK-NEXT: # implicit-def: $a2_64
75-
; CHECK-NEXT: move $6, $1
65+
; CHECK-NEXT: lw $6, 20($sp)
66+
; CHECK-NEXT: lw $5, 12($sp)
67+
; CHECK-NEXT: lw $4, 4($sp)
7668
; CHECK-NEXT: jal use_sret2
7769
; CHECK-NEXT: nop
7870
; CHECK-NEXT: ld $ra, 24($sp) # 8-byte Folded Reload

llvm/test/CodeGen/Mips/msa/basic_operations.ll

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1904,7 +1904,7 @@ define void @insert_v16i8_vidx(i32 signext %a) nounwind {
19041904
; N64-NEXT: daddu $1, $1, $25
19051905
; N64-NEXT: daddiu $1, $1, %lo(%neg(%gp_rel(insert_v16i8_vidx)))
19061906
; N64-NEXT: ld $2, %got_disp(i32)($1)
1907-
; N64-NEXT: lw $2, 0($2)
1907+
; N64-NEXT: lwu $2, 0($2)
19081908
; N64-NEXT: andi $2, $2, 15
19091909
; N64-NEXT: ld $1, %got_disp(v16i8)($1)
19101910
; N64-NEXT: daddu $1, $1, $2
@@ -1953,7 +1953,7 @@ define void @insert_v8i16_vidx(i32 signext %a) nounwind {
19531953
; N64-NEXT: daddu $1, $1, $25
19541954
; N64-NEXT: daddiu $1, $1, %lo(%neg(%gp_rel(insert_v8i16_vidx)))
19551955
; N64-NEXT: ld $2, %got_disp(i32)($1)
1956-
; N64-NEXT: lw $2, 0($2)
1956+
; N64-NEXT: lwu $2, 0($2)
19571957
; N64-NEXT: andi $2, $2, 7
19581958
; N64-NEXT: ld $1, %got_disp(v8i16)($1)
19591959
; N64-NEXT: dlsa $1, $2, $1, 1
@@ -2002,7 +2002,7 @@ define void @insert_v4i32_vidx(i32 signext %a) nounwind {
20022002
; N64-NEXT: daddu $1, $1, $25
20032003
; N64-NEXT: daddiu $1, $1, %lo(%neg(%gp_rel(insert_v4i32_vidx)))
20042004
; N64-NEXT: ld $2, %got_disp(i32)($1)
2005-
; N64-NEXT: lw $2, 0($2)
2005+
; N64-NEXT: lwu $2, 0($2)
20062006
; N64-NEXT: andi $2, $2, 3
20072007
; N64-NEXT: ld $1, %got_disp(v4i32)($1)
20082008
; N64-NEXT: dlsa $1, $2, $1, 2
@@ -2053,7 +2053,7 @@ define void @insert_v2i64_vidx(i64 signext %a) nounwind {
20532053
; N64-NEXT: daddu $1, $1, $25
20542054
; N64-NEXT: daddiu $1, $1, %lo(%neg(%gp_rel(insert_v2i64_vidx)))
20552055
; N64-NEXT: ld $2, %got_disp(i32)($1)
2056-
; N64-NEXT: lw $2, 0($2)
2056+
; N64-NEXT: lwu $2, 0($2)
20572057
; N64-NEXT: andi $2, $2, 1
20582058
; N64-NEXT: ld $1, %got_disp(v2i64)($1)
20592059
; N64-NEXT: dlsa $1, $2, $1, 3

llvm/test/CodeGen/NVPTX/bf16x2-instructions.ll

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -711,11 +711,11 @@ define <2 x bfloat> @test_copysign(<2 x bfloat> %a, <2 x bfloat> %b) #0 {
711711
; CHECK-NEXT: .reg .b32 %r<6>;
712712
; CHECK-EMPTY:
713713
; CHECK-NEXT: // %bb.0:
714-
; CHECK-NEXT: ld.param.b32 %r1, [test_copysign_param_0];
715-
; CHECK-NEXT: ld.param.b32 %r2, [test_copysign_param_1];
716-
; CHECK-NEXT: and.b32 %r3, %r2, -2147450880;
717-
; CHECK-NEXT: and.b32 %r4, %r1, 2147450879;
718-
; CHECK-NEXT: or.b32 %r5, %r4, %r3;
714+
; CHECK-NEXT: ld.param.b32 %r1, [test_copysign_param_1];
715+
; CHECK-NEXT: and.b32 %r2, %r1, -2147450880;
716+
; CHECK-NEXT: ld.param.b32 %r3, [test_copysign_param_0];
717+
; CHECK-NEXT: and.b32 %r4, %r3, 2147450879;
718+
; CHECK-NEXT: or.b32 %r5, %r4, %r2;
719719
; CHECK-NEXT: st.param.b32 [func_retval0], %r5;
720720
; CHECK-NEXT: ret;
721721
%r = call <2 x bfloat> @llvm.copysign.f16(<2 x bfloat> %a, <2 x bfloat> %b)

0 commit comments

Comments (0)