diff --git a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
index a6a14837bf473..41d0b669989cd 100644
--- a/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
+++ b/llvm/lib/Target/RISCV/RISCVTargetTransformInfo.cpp
@@ -2847,8 +2847,11 @@ bool RISCVTTIImpl::isProfitableToSinkOperands(
     if (!Op || any_of(Ops, [&](Use *U) { return U->get() == Op; }))
       continue;
 
-    // We are looking for a splat that can be sunk.
-    if (!match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
+    // We are looking for a splat/vp.splat that can be sunk.
+    bool IsVPSplat = match(Op, m_Intrinsic<Intrinsic::experimental_vp_splat>(
+                                   m_Value(), m_Value(), m_Value()));
+    if (!IsVPSplat &&
+        !match(Op, m_Shuffle(m_InsertElt(m_Undef(), m_Value(), m_ZeroInt()),
                              m_Undef(), m_ZeroMask())))
       continue;
 
@@ -2864,12 +2867,17 @@ bool RISCVTTIImpl::isProfitableToSinkOperands(
         return false;
     }
 
-    Use *InsertEltUse = &Op->getOperandUse(0);
     // Sink any fpexts since they might be used in a widening fp pattern.
-    auto *InsertElt = cast<Instruction>(InsertEltUse);
-    if (isa<FPExtInst>(InsertElt->getOperand(1)))
-      Ops.push_back(&InsertElt->getOperandUse(1));
-    Ops.push_back(InsertEltUse);
+    if (IsVPSplat) {
+      if (isa<FPExtInst>(Op->getOperand(0)))
+        Ops.push_back(&Op->getOperandUse(0));
+    } else {
+      Use *InsertEltUse = &Op->getOperandUse(0);
+      auto *InsertElt = cast<Instruction>(InsertEltUse);
+      if (isa<FPExtInst>(InsertElt->getOperand(1)))
+        Ops.push_back(&InsertElt->getOperandUse(1));
+      Ops.push_back(InsertEltUse);
+    }
     Ops.push_back(&OpIdx.value());
   }
   return true;
diff --git a/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll b/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
index 9b794538c404e..c216fb65a6a5b 100644
--- a/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
+++ b/llvm/test/CodeGen/RISCV/rvv/sink-splat-operands.ll
@@ -5890,3 +5890,132 @@ vector.body: ; preds = %vector.body, %entry
 for.cond.cleanup: ; preds = %vector.body
   ret void
 }
+
+define void @sink_vp_splat(ptr nocapture %out, ptr nocapture %in) {
+; CHECK-LABEL: sink_vp_splat:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a2, 0
+; CHECK-NEXT:    li a3, 1024
+; CHECK-NEXT:    li a4, 3
+; CHECK-NEXT:    lui a5, 1
+; CHECK-NEXT:  .LBB129_1: # %vector.body
+; CHECK-NEXT:    # =>This Loop Header: Depth=1
+; CHECK-NEXT:    # Child Loop BB129_2 Depth 2
+; CHECK-NEXT:    vsetvli a6, a3, e32, m4, ta, ma
+; CHECK-NEXT:    slli a7, a2, 2
+; CHECK-NEXT:    vmv.v.i v8, 0
+; CHECK-NEXT:    add t0, a1, a7
+; CHECK-NEXT:    li t1, 1024
+; CHECK-NEXT:  .LBB129_2: # %for.body424
+; CHECK-NEXT:    # Parent Loop BB129_1 Depth=1
+; CHECK-NEXT:    # => This Inner Loop Header: Depth=2
+; CHECK-NEXT:    vle32.v v12, (t0)
+; CHECK-NEXT:    addi t1, t1, -1
+; CHECK-NEXT:    vmacc.vx v8, a4, v12
+; CHECK-NEXT:    add t0, t0, a5
+; CHECK-NEXT:    bnez t1, .LBB129_2
+; CHECK-NEXT:  # %bb.3: # %vector.latch
+; CHECK-NEXT:    # in Loop: Header=BB129_1 Depth=1
+; CHECK-NEXT:    add a7, a0, a7
+; CHECK-NEXT:    sub a3, a3, a6
+; CHECK-NEXT:    vse32.v v8, (a7)
+; CHECK-NEXT:    add a2, a2, a6
+; CHECK-NEXT:    bnez a3, .LBB129_1
+; CHECK-NEXT:  # %bb.4: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  br label %vector.body
+
+vector.body: ; preds = %vector.latch, %entry
+  %scalar.ind = phi i64 [ 0, %entry ], [ %next.ind, %vector.latch ]
+  %trip.count = phi i64 [ 1024, %entry ], [ %remaining.trip.count, %vector.latch ]
+  %evl = tail call i32 @llvm.experimental.get.vector.length.i64(i64 %trip.count, i32 8, i1 true)
+  %vp.splat1 = tail call <vscale x 8 x i32> @llvm.experimental.vp.splat.nxv8i32(i32 0, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.splat2 = tail call <vscale x 8 x i32> @llvm.experimental.vp.splat.nxv8i32(i32 3, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %evl.cast = zext i32 %evl to i64
+  br label %for.body424
+
+for.body424: ; preds = %for.body424, %vector.body
+  %scalar.phi = phi i64 [ 0, %vector.body ], [ %indvars.iv.next27, %for.body424 ]
+  %vector.phi = phi <vscale x 8 x i32> [ %vp.splat1, %vector.body ], [ %vp.binary26, %for.body424 ]
+  %arrayidx625 = getelementptr inbounds [1024 x i32], ptr %in, i64 %scalar.phi, i64 %scalar.ind
+  %widen.load = tail call <vscale x 8 x i32> @llvm.vp.load.nxv8i32.p0(ptr %arrayidx625, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.binary = tail call <vscale x 8 x i32> @llvm.vp.mul.nxv8i32(<vscale x 8 x i32> %widen.load, <vscale x 8 x i32> %vp.splat2, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.binary26 = tail call <vscale x 8 x i32> @llvm.vp.add.nxv8i32(<vscale x 8 x i32> %vector.phi, <vscale x 8 x i32> %vp.binary, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %indvars.iv.next27 = add nuw nsw i64 %scalar.phi, 1
+  %exitcond.not28 = icmp eq i64 %indvars.iv.next27, 1024
+  br i1 %exitcond.not28, label %vector.latch, label %for.body424
+
+vector.latch: ; preds = %for.body424
+  %arrayidx830 = getelementptr inbounds i32, ptr %out, i64 %scalar.ind
+  tail call void @llvm.vp.store.nxv8i32.p0(<vscale x 8 x i32> %vp.binary26, ptr %arrayidx830, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %remaining.trip.count = sub nuw i64 %trip.count, %evl.cast
+  %next.ind = add i64 %scalar.ind, %evl.cast
+  %6 = icmp eq i64 %remaining.trip.count, 0
+  br i1 %6, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup: ; preds = %vector.latch
+  ret void
+}
+
+define void @sink_vp_splat_vfwadd_wf(ptr nocapture %in, float %f) {
+; CHECK-LABEL: sink_vp_splat_vfwadd_wf:
+; CHECK:       # %bb.0: # %entry
+; CHECK-NEXT:    li a1, 0
+; CHECK-NEXT:    li a2, 1024
+; CHECK-NEXT:    lui a3, 2
+; CHECK-NEXT:  .LBB130_1: # %vector.body
+; CHECK-NEXT:    # =>This Loop Header: Depth=1
+; CHECK-NEXT:    # Child Loop BB130_2 Depth 2
+; CHECK-NEXT:    vsetvli a4, a2, e8, m1, ta, ma
+; CHECK-NEXT:    slli a5, a1, 3
+; CHECK-NEXT:    add a5, a0, a5
+; CHECK-NEXT:    li a6, 1024
+; CHECK-NEXT:  .LBB130_2: # %for.body419
+; CHECK-NEXT:    # Parent Loop BB130_1 Depth=1
+; CHECK-NEXT:    # => This Inner Loop Header: Depth=2
+; CHECK-NEXT:    vsetvli zero, zero, e32, m4, ta, ma
+; CHECK-NEXT:    vle64.v v8, (a5)
+; CHECK-NEXT:    addi a6, a6, -1
+; CHECK-NEXT:    vfwadd.wf v8, v8, fa0
+; CHECK-NEXT:    vse64.v v8, (a5)
+; CHECK-NEXT:    add a5, a5, a3
+; CHECK-NEXT:    bnez a6, .LBB130_2
+; CHECK-NEXT:  # %bb.3: # %vector.latch
+; CHECK-NEXT:    # in Loop: Header=BB130_1 Depth=1
+; CHECK-NEXT:    sub a2, a2, a4
+; CHECK-NEXT:    add a1, a1, a4
+; CHECK-NEXT:    bnez a2, .LBB130_1
+; CHECK-NEXT:  # %bb.4: # %for.cond.cleanup
+; CHECK-NEXT:    ret
+entry:
+  %conv = fpext float %f to double
+  br label %vector.body
+
+vector.body: ; preds = %vector.latch, %entry
+  %scalar.ind = phi i64 [ 0, %entry ], [ %next.ind, %vector.latch ]
+  %trip.count = phi i64 [ 1024, %entry ], [ %remaining.trip.count, %vector.latch ]
+  %evl = call i32 @llvm.experimental.get.vector.length.i64(i64 %trip.count, i32 8, i1 true)
+  %vp.splat = call <vscale x 8 x double> @llvm.experimental.vp.splat.nxv8f64(double %conv, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %evl.cast = zext i32 %evl to i64
+  br label %for.body419
+
+for.body419: ; preds = %for.body419, %vector.body
+  %scalar.phi = phi i64 [ 0, %vector.body ], [ %indvars.iv.next21, %for.body419 ]
+  %arrayidx620 = getelementptr inbounds [1024 x double], ptr %in, i64 %scalar.phi, i64 %scalar.ind
+  %widen.load = call <vscale x 8 x double> @llvm.vp.load.nxv8f64.p0(ptr %arrayidx620, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %vp.binary = call <vscale x 8 x double> @llvm.vp.fadd.nxv8f64(<vscale x 8 x double> %widen.load, <vscale x 8 x double> %vp.splat, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  call void @llvm.vp.store.nxv8f64.p0(<vscale x 8 x double> %vp.binary, ptr %arrayidx620, <vscale x 8 x i1> splat (i1 true), i32 %evl)
+  %indvars.iv.next21 = add nuw nsw i64 %scalar.phi, 1
+  %exitcond.not22 = icmp eq i64 %indvars.iv.next21, 1024
+  br i1 %exitcond.not22, label %vector.latch, label %for.body419
+
+vector.latch: ; preds = %for.body419
+  %remaining.trip.count = sub nuw i64 %trip.count, %evl.cast
+  %next.ind = add i64 %scalar.ind, %evl.cast
+  %cond = icmp eq i64 %remaining.trip.count, 0
+  br i1 %cond, label %for.cond.cleanup, label %vector.body
+
+for.cond.cleanup: ; preds = %vector.latch
+  ret void
+}
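
Note for reviewers: the vfwadd.wf test also exercises the fpext handling in the new IsVPSplat branch. The scalar fpext feeding the vp.splat is sunk together with the splat, so the widening add reads fa0 directly instead of keeping a splatted double live in a vector register across both loops. For reference, a minimal self-contained sketch of the new match, assuming an LLVM tree that provides the llvm.experimental.vp.splat intrinsic (the helper name here is illustrative, not part of the patch):

#include "llvm/IR/Instruction.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/IR/PatternMatch.h"

using namespace llvm;
using namespace llvm::PatternMatch;

// llvm.experimental.vp.splat has the form (scalar, mask, evl); the match
// accepts any operand values. Whether sinking the scalar operand is
// actually profitable is decided separately, as in the hunk above.
static bool isVPSplat(Instruction *Op) {
  return match(Op, m_Intrinsic<Intrinsic::experimental_vp_splat>(
                       m_Value(), m_Value(), m_Value()));
}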