Skip to content

Commit b8b84d2

Browse files
committed
[VectorCombine] Avoid inserting freeze when scalarizing extend-extract if all extracts would lead to UB on poison.
1 parent 553cfa8 commit b8b84d2

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,8 +2017,24 @@ bool VectorCombine::scalarizeExtExtract(Instruction &I) {
20172017

20182018
Value *ScalarV = Ext->getOperand(0);
20192019
if (!isGuaranteedNotToBePoison(ScalarV, &AC, dyn_cast<Instruction>(ScalarV),
2020-
&DT))
2021-
ScalarV = Builder.CreateFreeze(ScalarV);
2020+
&DT)) {
2021+
// Check if all lanes are extracted and all extracts trigger UB on poison.
2022+
// If so, we do not need to insert a freeze.
2023+
SmallDenseSet<uint64_t, 8> ExtractedLanes;
2024+
bool AllExtractsHaveUB = true;
2025+
for (User *U : Ext->users()) {
2026+
auto *Extract = cast<ExtractElementInst>(U);
2027+
uint64_t Idx =
2028+
cast<ConstantInt>(Extract->getIndexOperand())->getZExtValue();
2029+
ExtractedLanes.insert(Idx);
2030+
if (!programUndefinedIfPoison(Extract)) {
2031+
AllExtractsHaveUB = false;
2032+
break;
2033+
}
2034+
}
2035+
if (!AllExtractsHaveUB || ExtractedLanes.size() != SrcTy->getNumElements())
2036+
ScalarV = Builder.CreateFreeze(ScalarV);
2037+
}
20222038
ScalarV = Builder.CreateBitCast(
20232039
ScalarV,
20242040
IntegerType::get(SrcTy->getContext(), DL->getTypeSizeInBits(SrcTy)));

llvm/test/Transforms/VectorCombine/AArch64/ext-extract.ll

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -351,22 +351,21 @@ define noundef i32 @zext_v4i8_all_lanes_used_no_freeze(<4 x i8> %src) {
351351
; CHECK-LABEL: define noundef i32 @zext_v4i8_all_lanes_used_no_freeze(
352352
; CHECK-SAME: <4 x i8> [[SRC:%.*]]) {
353353
; CHECK-NEXT: [[ENTRY:.*:]]
354-
; CHECK-NEXT: [[TMP0:%.*]] = freeze <4 x i8> [[SRC]]
355-
; CHECK-NEXT: [[TMP1:%.*]] = bitcast <4 x i8> [[TMP0]] to i32
356-
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP1]], 24
357-
; CHECK-NEXT: [[TMP3:%.*]] = lshr i32 [[TMP1]], 16
358-
; CHECK-NEXT: [[TMP4:%.*]] = and i32 [[TMP3]], 255
359-
; CHECK-NEXT: [[TMP5:%.*]] = lshr i32 [[TMP1]], 8
360-
; CHECK-NEXT: [[TMP6:%.*]] = and i32 [[TMP5]], 255
361-
; CHECK-NEXT: [[TMP7:%.*]] = and i32 [[TMP1]], 255
354+
; CHECK-NEXT: [[TMP0:%.*]] = bitcast <4 x i8> [[SRC]] to i32
355+
; CHECK-NEXT: [[TMP1:%.*]] = lshr i32 [[TMP0]], 24
356+
; CHECK-NEXT: [[TMP2:%.*]] = lshr i32 [[TMP0]], 16
357+
; CHECK-NEXT: [[TMP3:%.*]] = and i32 [[TMP2]], 255
358+
; CHECK-NEXT: [[TMP4:%.*]] = lshr i32 [[TMP0]], 8
359+
; CHECK-NEXT: [[TMP5:%.*]] = and i32 [[TMP4]], 255
360+
; CHECK-NEXT: [[TMP6:%.*]] = and i32 [[TMP0]], 255
362361
; CHECK-NEXT: [[EXT:%.*]] = zext nneg <4 x i8> [[SRC]] to <4 x i32>
363362
; CHECK-NEXT: [[EXT_0:%.*]] = extractelement <4 x i32> [[EXT]], i64 0
364363
; CHECK-NEXT: [[EXT_1:%.*]] = extractelement <4 x i32> [[EXT]], i64 1
365364
; CHECK-NEXT: [[EXT_2:%.*]] = extractelement <4 x i32> [[EXT]], i64 2
366365
; CHECK-NEXT: [[EXT_3:%.*]] = extractelement <4 x i32> [[EXT]], i64 3
367-
; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[TMP7]], [[TMP6]]
368-
; CHECK-NEXT: [[ADD2:%.*]] = add i32 [[ADD1]], [[TMP4]]
369-
; CHECK-NEXT: [[ADD3:%.*]] = add i32 [[ADD2]], [[TMP2]]
366+
; CHECK-NEXT: [[ADD1:%.*]] = add i32 [[TMP6]], [[TMP5]]
367+
; CHECK-NEXT: [[ADD2:%.*]] = add i32 [[ADD1]], [[TMP3]]
368+
; CHECK-NEXT: [[ADD3:%.*]] = add i32 [[ADD2]], [[TMP1]]
370369
; CHECK-NEXT: ret i32 [[ADD3]]
371370
;
372371
entry:

0 commit comments

Comments
 (0)