diff --git a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
index 0e8e4c9618bb2..38353a57370e4 100644
--- a/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
+++ b/llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp
@@ -7633,7 +7633,7 @@ SDValue DAGCombiner::visitAND(SDNode *N) {
 
     if (SDValue(GN0, 0).hasOneUse() &&
         isConstantSplatVectorMaskForType(N1.getNode(), ScalarVT) &&
-        TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
+        TLI.isVectorLoadExtDesirable(SDValue(N, 0))) {
       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
 
@@ -15724,7 +15724,7 @@ SDValue DAGCombiner::visitSIGN_EXTEND_INREG(SDNode *N) {
   // fold (sext_inreg (masked_gather x)) -> (sext_masked_gather x)
   if (auto *GN0 = dyn_cast<MaskedGatherSDNode>(N0)) {
     if (SDValue(GN0, 0).hasOneUse() && ExtVT == GN0->getMemoryVT() &&
-        TLI.isVectorLoadExtDesirable(SDValue(SDValue(GN0, 0)))) {
+        TLI.isVectorLoadExtDesirable(SDValue(N, 0))) {
       SDValue Ops[] = {GN0->getChain(),   GN0->getPassThru(), GN0->getMask(),
                        GN0->getBasePtr(), GN0->getIndex(),    GN0->getScale()};
 
diff --git a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
index 4f13a14d24649..2775bdb175556 100644
--- a/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
+++ b/llvm/lib/Target/AArch64/AArch64ISelLowering.cpp
@@ -6439,7 +6439,9 @@ bool AArch64TargetLowering::isVectorLoadExtDesirable(SDValue ExtVal) const {
     }
   }
 
-  return true;
+  EVT PreExtScalarVT = ExtVal->getOperand(0).getValueType().getScalarType();
+  return PreExtScalarVT == MVT::i8 || PreExtScalarVT == MVT::i16 ||
+         PreExtScalarVT == MVT::i32 || PreExtScalarVT == MVT::i64;
 }
 
 unsigned getGatherVecOpcode(bool IsScaled, bool IsSigned, bool NeedsExtend) {
diff --git a/llvm/test/CodeGen/AArch64/sve-intrinsics-ldst-ext.ll b/llvm/test/CodeGen/AArch64/sve-intrinsics-ldst-ext.ll
index 4153f0be611a1..9698f1a6768fd 100644
--- a/llvm/test/CodeGen/AArch64/sve-intrinsics-ldst-ext.ll
+++ b/llvm/test/CodeGen/AArch64/sve-intrinsics-ldst-ext.ll
@@ -231,3 +231,27 @@ define <vscale x 8 x i64> @sload_8i8_8i64(ptr %a) {
   %aext = sext <vscale x 8 x i8> %aval to <vscale x 8 x i64>
   ret <vscale x 8 x i64> %aext
 }
+
+; Ensure we don't try to promote a predicate load to a sign-extended load.
+define <vscale x 16 x i8> @sload_16i1_16i8(ptr %addr) {
+; CHECK-LABEL: sload_16i1_16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr p0, [x0]
+; CHECK-NEXT:    mov z0.b, p0/z, #-1 // =0xffffffffffffffff
+; CHECK-NEXT:    ret
+  %load = load <vscale x 16 x i1>, ptr %addr
+  %zext = sext <vscale x 16 x i1> %load to <vscale x 16 x i8>
+  ret <vscale x 16 x i8> %zext
+}
+
+; Ensure we don't try to promote a predicate load to a zero-extended load.
+define <vscale x 16 x i8> @zload_16i1_16i8(ptr %addr) {
+; CHECK-LABEL: zload_16i1_16i8:
+; CHECK:       // %bb.0:
+; CHECK-NEXT:    ldr p0, [x0]
+; CHECK-NEXT:    mov z0.b, p0/z, #1 // =0x1
+; CHECK-NEXT:    ret
+  %load = load <vscale x 16 x i1>, ptr %addr
+  %zext = zext <vscale x 16 x i1> %load to <vscale x 16 x i8>
+  ret <vscale x 16 x i8> %zext
+}