Skip to content

Commit ec35883

Browse files
committed
Only adjust the insert point if the step was defined in the loop
1 parent 2803322 commit ec35883

File tree

2 files changed

+60
-2
lines changed

2 files changed

+60
-2
lines changed

llvm/lib/Target/RISCV/RISCVGatherScatterLowering.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,8 +317,9 @@ bool RISCVGatherScatterLowering::matchStridedRecurrence(Value *Index, Loop *L,
317317
}
318318
}
319319

320-
// Adjust the step value after its definition if it's an instruction.
321-
if (auto *StepI = dyn_cast<Instruction>(Step))
320+
// If the Step was defined inside the loop, adjust it before its definition
321+
// instead of in the preheader.
322+
if (auto *StepI = dyn_cast<Instruction>(Step); StepI && L->contains(StepI))
322323
Builder.SetInsertPoint(*StepI->getInsertionPointAfterDef());
323324

324325
switch (BO->getOpcode()) {

llvm/test/CodeGen/RISCV/rvv/strided-load-store.ll

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,63 @@ for.cond.cleanup: ; preds = %vector.body
208208
ret <vscale x 1 x i64> %accum.next
209209
}
210210

211+
; Check that the operand of the binary op (%scale.splat in shl) always dominates
212+
; the existing step value when we're adjusting it.
213+
define <vscale x 1 x i64> @gather_splat_op_after_step(ptr %a, ptr %b, i32 %len) {
214+
; CHECK-LABEL: @gather_splat_op_after_step(
215+
; CHECK-NEXT: vector.ph:
216+
; CHECK-NEXT: [[WIDE_TRIP_COUNT:%.*]] = zext i32 [[LEN:%.*]] to i64
217+
; CHECK-NEXT: [[TMP0:%.*]] = tail call i64 @llvm.vscale.i64()
218+
; CHECK-NEXT: [[SCALE:%.*]] = load i64, ptr [[B:%.*]], align 8
219+
; CHECK-NEXT: [[STRIDE:%.*]] = shl i64 1, [[SCALE]]
220+
; CHECK-NEXT: [[STEP:%.*]] = shl i64 [[TMP0]], [[SCALE]]
221+
; CHECK-NEXT: [[TMP1:%.*]] = mul i64 [[STRIDE]], 16
222+
; CHECK-NEXT: br label [[VECTOR_BODY:%.*]]
223+
; CHECK: vector.body:
224+
; CHECK-NEXT: [[INDEX:%.*]] = phi i64 [ 0, [[VECTOR_PH:%.*]] ], [ [[INDEX_NEXT:%.*]], [[VECTOR_BODY]] ]
225+
; CHECK-NEXT: [[VEC_IND_SCALAR:%.*]] = phi i64 [ 0, [[VECTOR_PH]] ], [ [[VEC_IND_NEXT_SCALAR:%.*]], [[VECTOR_BODY]] ]
226+
; CHECK-NEXT: [[ACCUM:%.*]] = phi <vscale x 1 x i64> [ zeroinitializer, [[VECTOR_PH]] ], [ [[ACCUM_NEXT:%.*]], [[VECTOR_BODY]] ]
227+
; CHECK-NEXT: [[TMP2:%.*]] = getelementptr [[STRUCT_FOO:%.*]], ptr [[A:%.*]], i64 [[VEC_IND_SCALAR]], i32 3
228+
; CHECK-NEXT: [[TMP3:%.*]] = call i32 @llvm.vscale.i32()
229+
; CHECK-NEXT: [[TMP4:%.*]] = call <vscale x 1 x i64> @llvm.experimental.vp.strided.load.nxv1i64.p0.i64(ptr [[TMP2]], i64 [[TMP1]], <vscale x 1 x i1> splat (i1 true), i32 [[TMP3]])
230+
; CHECK-NEXT: [[GATHER:%.*]] = call <vscale x 1 x i64> @llvm.vp.select.nxv1i64(<vscale x 1 x i1> splat (i1 true), <vscale x 1 x i64> [[TMP4]], <vscale x 1 x i64> undef, i32 [[TMP3]])
231+
; CHECK-NEXT: [[ACCUM_NEXT]] = add <vscale x 1 x i64> [[ACCUM]], [[GATHER]]
232+
; CHECK-NEXT: [[INDEX_NEXT]] = add nuw i64 [[INDEX]], [[TMP0]]
233+
; CHECK-NEXT: [[VEC_IND_NEXT_SCALAR]] = add i64 [[VEC_IND_SCALAR]], [[STEP]]
234+
; CHECK-NEXT: [[TMP5:%.*]] = icmp ne i64 [[INDEX_NEXT]], [[WIDE_TRIP_COUNT]]
235+
; CHECK-NEXT: br i1 [[TMP5]], label [[FOR_COND_CLEANUP:%.*]], label [[VECTOR_BODY]]
236+
; CHECK: for.cond.cleanup:
237+
; CHECK-NEXT: ret <vscale x 1 x i64> [[ACCUM_NEXT]]
238+
;
239+
vector.ph:
240+
%wide.trip.count = zext i32 %len to i64
241+
%0 = tail call i64 @llvm.vscale.i64()
242+
%1 = tail call <vscale x 1 x i64> @llvm.stepvector.nxv1i64()
243+
%.splatinsert = insertelement <vscale x 1 x i64> poison, i64 %0, i64 0
244+
%.splat = shufflevector <vscale x 1 x i64> %.splatinsert, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
245+
246+
%scale = load i64, ptr %b
247+
%scale.head = insertelement <vscale x 1 x i64> poison, i64 %scale, i64 0
248+
%scale.splat = shufflevector <vscale x 1 x i64> %scale.head, <vscale x 1 x i64> poison, <vscale x 1 x i32> zeroinitializer
249+
br label %vector.body
250+
251+
vector.body: ; preds = %vector.body, %vector.ph
252+
%index = phi i64 [ 0, %vector.ph ], [ %index.next, %vector.body ]
253+
%vec.ind = phi <vscale x 1 x i64> [ %1, %vector.ph ], [ %vec.ind.next, %vector.body ]
254+
%accum = phi <vscale x 1 x i64> [ zeroinitializer, %vector.ph ], [ %accum.next, %vector.body ]
255+
%vec.ind.shl = shl <vscale x 1 x i64> %vec.ind, %scale.splat
256+
%2 = getelementptr inbounds %struct.foo, ptr %a, <vscale x 1 x i64> %vec.ind.shl, i32 3
257+
%gather = call <vscale x 1 x i64> @llvm.masked.gather.nxv1i64.nxv1p0(<vscale x 1 x ptr> %2, i32 8, <vscale x 1 x i1> splat (i1 true), <vscale x 1 x i64> undef)
258+
%accum.next = add <vscale x 1 x i64> %accum, %gather
259+
%index.next = add nuw i64 %index, %0
260+
%vec.ind.next = add <vscale x 1 x i64> %vec.ind, %.splat
261+
%3 = icmp ne i64 %index.next, %wide.trip.count
262+
br i1 %3, label %for.cond.cleanup, label %vector.body
263+
264+
for.cond.cleanup: ; preds = %vector.body
265+
ret <vscale x 1 x i64> %accum.next
266+
}
267+
211268
define void @scatter(ptr %a, i32 %len) {
212269
; CHECK-LABEL: @scatter(
213270
; CHECK-NEXT: vector.ph:

0 commit comments

Comments
 (0)