Skip to content

Commit cb80fa7

Browse files
authored
[VectorCombine] Support pattern bitop(bitcast(x), C) -> bitcast(bitop(x, InvC)) (#155216)
Resolves #154797. This patch adds the fold `bitop(bitcast(x), C) -> bitop(bitcast(x), bitcast(InvC)) -> bitcast(bitop(x, InvC))`. The helper function `getLosslessInvCast` tries to compute a constant `InvC` satisfying `castop(InvC) == C`, and does its best to preserve the poison-generating flags of the cast operation.
1 parent 8dee9e4 commit cb80fa7

File tree

2 files changed

+303
-0
lines changed

2 files changed

+303
-0
lines changed

llvm/lib/Transforms/Vectorize/VectorCombine.cpp

Lines changed: 143 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ class VectorCombine {
122122
bool foldInsExtBinop(Instruction &I);
123123
bool foldInsExtVectorToShuffle(Instruction &I);
124124
bool foldBitOpOfCastops(Instruction &I);
125+
bool foldBitOpOfCastConstant(Instruction &I);
125126
bool foldBitcastShuffle(Instruction &I);
126127
bool scalarizeOpOrCmp(Instruction &I);
127128
bool scalarizeVPIntrinsic(Instruction &I);
@@ -937,6 +938,146 @@ bool VectorCombine::foldBitOpOfCastops(Instruction &I) {
937938
return true;
938939
}
939940

941+
/// Cast flags reported by getLosslessInvCast through its Flags out-parameter.
/// All fields default to false; the helper only writes the ones relevant to
/// the cast opcode it was asked to invert.
struct PreservedCastFlags {
942+
// Written on the zext path: true when zext(InvC) and sext(InvC) agree.
bool NNeg = false;
943+
// Written on the trunc path, where it is always set to true.
bool NUW = false;
944+
// Written on the trunc path: true when zext(C) and sext(C) agree.
bool NSW = false;
945+
};
946+
947+
/// Try to find a constant InvC of type InvCastTo such that CastOp(InvC) == C,
/// i.e. invert the cast on the constant without losing information.
948+
/// On the trunc and zext paths, Flags records which cast flags (nuw/nsw,
/// nneg) still hold for CastOp(InvC). Returns nullptr when no lossless
/// inverse is found or when CastOp is not one of the handled opcodes.
949+
static Constant *getLosslessInvCast(Constant *C, Type *InvCastTo,
950+
Instruction::CastOps CastOp,
951+
const DataLayout &DL,
952+
PreservedCastFlags &Flags) {
953+
switch (CastOp) {
954+
case Instruction::BitCast:
955+
// Bitcast is always lossless, so the inverse is the bitcast back to
// InvCastTo. Flags is left untouched on this path.
956+
return ConstantFoldCastOperand(Instruction::BitCast, C, InvCastTo, DL);
957+
case Instruction::Trunc: {
958+
// The inverse of a truncation is an extension; the zero-extended constant
// is the one returned, the sign-extended one is only used for the NSW test.
auto *ZExtC = ConstantFoldCastOperand(Instruction::ZExt, C, InvCastTo, DL);
959+
auto *SExtC = ConstantFoldCastOperand(Instruction::SExt, C, InvCastTo, DL);
960+
// Truncating a zero-extended value back is always NUW.
961+
Flags.NUW = true;
962+
// C is non-negative exactly when its zero- and sign-extensions agree; only
// then is the truncation also NSW.
// NOTE(review): the flags are written even if the folds above returned
// null (NSW would then compare null == null) -- confirm whether
// ConstantFoldCastOperand can fail here.
963+
Flags.NSW = ZExtC == SExtC;
964+
return ZExtC;
965+
}
966+
case Instruction::SExt:
967+
case Instruction::ZExt: {
968+
// Candidate inverse: truncate C to the narrower type, then verify below
// that extending it again reproduces C.
auto *InvC = ConstantExpr::getTrunc(C, InvCastTo);
969+
auto *CastInvC = ConstantFoldCastOperand(CastOp, InvC, C->getType(), DL);
970+
// Must satisfy CastOp(InvC) == C; otherwise the truncation lost bits.
971+
if (!CastInvC || CastInvC != C)
972+
return nullptr;
973+
// Only zext gets a flag; the sext path leaves Flags untouched.
if (CastOp == Instruction::ZExt) {
974+
auto *SExtInvC =
975+
ConstantFoldCastOperand(Instruction::SExt, InvC, C->getType(), DL);
976+
// InvC is non-negative exactly when zext(InvC) and sext(InvC) agree.
977+
Flags.NNeg = CastInvC == SExtInvC;
978+
}
979+
return InvC;
980+
}
981+
// Every other cast opcode is unsupported.
default:
982+
return nullptr;
983+
}
984+
}
985+
986+
/// Fold a bitwise logic op of a cast and a constant by moving the op in
/// front of the cast:
987+
///   bitop(castop(x), C) ->
988+
///   bitop(castop(x), castop(InvC)) ->
989+
///   castop(bitop(x, InvC))
990+
/// Currently only bitcast is accepted as the cast. Returns true if I was
/// replaced, false if the pattern did not match or the cost check failed.
991+
bool VectorCombine::foldBitOpOfCastConstant(Instruction &I) {
992+
Instruction *LHS;
993+
Constant *C;
994+
995+
// I must be a bitwise logic operation with an instruction on one side and a
// constant on the other; the matcher is commutative, so either operand
// order is accepted.
996+
if (!match(&I, m_c_BitwiseLogic(m_Instruction(LHS), m_Constant(C))))
997+
return false;
998+
999+
// The non-constant operand must be a cast instruction.
1000+
auto *LHSCast = dyn_cast<CastInst>(LHS);
1001+
if (!LHSCast)
1002+
return false;
1003+
1004+
Instruction::CastOps CastOpcode = LHSCast->getOpcode();
1005+
1006+
// Only handle supported cast operations: just bitcast for now, every other
// cast opcode bails out.
1007+
switch (CastOpcode) {
1008+
case Instruction::BitCast:
1009+
break;
1010+
default:
1011+
return false;
1012+
}
1013+
1014+
Value *LHSSrc = LHSCast->getOperand(0);
1015+
1016+
// Only handle fixed vector types with integer elements, both for the cast's
// source operand and for the result of I.
1017+
auto *SrcVecTy = dyn_cast<FixedVectorType>(LHSSrc->getType());
1018+
auto *DstVecTy = dyn_cast<FixedVectorType>(I.getType());
1019+
if (!SrcVecTy || !DstVecTy)
1020+
return false;
1021+
1022+
if (!SrcVecTy->getScalarType()->isIntegerTy() ||
1023+
!DstVecTy->getScalarType()->isIntegerTy())
1024+
return false;
1025+
1026+
// Find the constant InvC such that castop(InvC) equals C.
// RHSFlags is filled in by the helper but is not read again in this
// function.
1027+
PreservedCastFlags RHSFlags;
1028+
Constant *InvC = getLosslessInvCast(C, SrcVecTy, CastOpcode, *DL, RHSFlags);
1029+
if (!InvC)
1030+
return false;
1031+
1032+
// Cost check:
1033+
//   OldCost = bitlogic on the destination type + the existing cast
1034+
//   NewCost = bitlogic on the source type + a new cast
1035+
1036+
// Cost of the existing cast, queried with the instruction as context.
1037+
InstructionCost LHSCastCost =
1038+
TTI.getCastInstrCost(CastOpcode, DstVecTy, SrcVecTy,
1039+
TTI::CastContextHint::None, CostKind, LHSCast);
1040+
1041+
InstructionCost OldCost =
1042+
TTI.getArithmeticInstrCost(I.getOpcode(), DstVecTy, CostKind) +
1043+
LHSCastCost;
1044+
1045+
// For the new cost we can't provide an instruction (it doesn't exist yet),
// so the cast is costed without instruction context.
1046+
InstructionCost GenericCastCost = TTI.getCastInstrCost(
1047+
CastOpcode, DstVecTy, SrcVecTy, TTI::CastContextHint::None, CostKind);
1048+
1049+
InstructionCost NewCost =
1050+
TTI.getArithmeticInstrCost(I.getOpcode(), SrcVecTy, CostKind) +
1051+
GenericCastCost;
1052+
1053+
// A cast with more than one use also charges its own cost to the new
// sequence (presumably because the original cast remains for its other
// users).
1054+
if (!LHSCast->hasOneUse())
1055+
NewCost += LHSCastCost;
1056+
1057+
LLVM_DEBUG(dbgs() << "foldBitOpOfCastConstant: OldCost=" << OldCost
1058+
<< " NewCost=" << NewCost << "\n");
1059+
1060+
// Bail out only when the new sequence is strictly more expensive; a tie
// still performs the fold.
if (NewCost > OldCost)
1061+
return false;
1062+
1063+
// Create the bitwise operation on the source type, using the inverse
// constant as the second operand.
1064+
Value *NewOp = Builder.CreateBinOp((Instruction::BinaryOps)I.getOpcode(),
1065+
LHSSrc, InvC, I.getName() + ".inner");
1066+
// Copy I's IR flags only when the builder produced a BinaryOperator.
if (auto *NewBinOp = dyn_cast<BinaryOperator>(NewOp))
1067+
NewBinOp->copyIRFlags(&I);
1068+
1069+
Worklist.pushValue(NewOp);
1070+
1071+
// Create the cast operation directly to ensure we get a new instruction
1072+
Instruction *NewCast = CastInst::Create(CastOpcode, NewOp, I.getType());
1073+
1074+
// Insert the new cast through the builder.
1075+
Value *Result = Builder.Insert(NewCast);
1076+
1077+
replaceValue(I, *Result);
1078+
return true;
1079+
}
1080+
9401081
/// If this is a bitcast of a shuffle, try to bitcast the source vector to the
9411082
/// destination type followed by shuffle. This can enable further transforms by
9421083
/// moving bitcasts or shuffles together.
@@ -4474,6 +4615,8 @@ bool VectorCombine::run() {
44744615
case Instruction::Xor:
44754616
if (foldBitOpOfCastops(I))
44764617
return true;
4618+
if (foldBitOpOfCastConstant(I))
4619+
return true;
44774620
break;
44784621
case Instruction::PHI:
44794622
if (shrinkPhiOfShuffles(I))

llvm/test/Transforms/VectorCombine/X86/bitop-of-castops.ll

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -260,3 +260,163 @@ define <4 x i32> @or_zext_nneg(<4 x i16> %a, <4 x i16> %b) {
260260
%or = or <4 x i32> %z1, %z2
261261
ret <4 x i32> %or
262262
}
263+
264+
; Test bitwise operations with integer-to-integer bitcast with one constant
265+
define <2 x i32> @or_bitcast_v4i16_to_v2i32_constant(<4 x i16> %a) {
266+
; CHECK-LABEL: @or_bitcast_v4i16_to_v2i32_constant(
267+
; CHECK-NEXT: [[A:%.*]] = or <4 x i16> [[A1:%.*]], <i16 16960, i16 15, i16 -31616, i16 30>
268+
; CHECK-NEXT: [[BC1:%.*]] = bitcast <4 x i16> [[A]] to <2 x i32>
269+
; CHECK-NEXT: ret <2 x i32> [[BC1]]
270+
;
271+
%bc1 = bitcast <4 x i16> %a to <2 x i32>
272+
%or = or <2 x i32> %bc1, <i32 1000000, i32 2000000>
273+
ret <2 x i32> %or
274+
}
275+
276+
define <2 x i32> @or_bitcast_v4i16_to_v2i32_constant_commuted(<4 x i16> %a) {
277+
; CHECK-LABEL: @or_bitcast_v4i16_to_v2i32_constant_commuted(
278+
; CHECK-NEXT: [[A:%.*]] = or <4 x i16> [[A1:%.*]], <i16 16960, i16 15, i16 -31616, i16 30>
279+
; CHECK-NEXT: [[BC1:%.*]] = bitcast <4 x i16> [[A]] to <2 x i32>
280+
; CHECK-NEXT: ret <2 x i32> [[BC1]]
281+
;
282+
%bc1 = bitcast <4 x i16> %a to <2 x i32>
283+
%or = or <2 x i32> <i32 1000000, i32 2000000>, %bc1
284+
ret <2 x i32> %or
285+
}
286+
287+
; Test bitwise operations with truncate and one constant (currently not folded)
288+
define <4 x i16> @or_trunc_v4i32_to_v4i16_constant(<4 x i32> %a) {
289+
; CHECK-LABEL: @or_trunc_v4i32_to_v4i16_constant(
290+
; CHECK-NEXT: [[T1:%.*]] = trunc <4 x i32> [[A:%.*]] to <4 x i16>
291+
; CHECK-NEXT: [[OR:%.*]] = or <4 x i16> [[T1]], <i16 1, i16 2, i16 3, i16 4>
292+
; CHECK-NEXT: ret <4 x i16> [[OR]]
293+
;
294+
%t1 = trunc <4 x i32> %a to <4 x i16>
295+
%or = or <4 x i16> %t1, <i16 1, i16 2, i16 3, i16 4>
296+
ret <4 x i16> %or
297+
}
298+
299+
; Test bitwise operations with zero extend and one constant (currently not folded)
300+
define <4 x i32> @or_zext_v4i16_to_v4i32_constant(<4 x i16> %a) {
301+
; CHECK-LABEL: @or_zext_v4i16_to_v4i32_constant(
302+
; CHECK-NEXT: [[Z1:%.*]] = zext <4 x i16> [[A:%.*]] to <4 x i32>
303+
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 1, i32 2, i32 3, i32 4>
304+
; CHECK-NEXT: ret <4 x i32> [[OR]]
305+
;
306+
%z1 = zext <4 x i16> %a to <4 x i32>
307+
%or = or <4 x i32> %z1, <i32 1, i32 2, i32 3, i32 4>
308+
ret <4 x i32> %or
309+
}
310+
311+
define <4 x i32> @or_zext_v4i8_to_v4i32_constant_with_loss(<4 x i8> %a) {
312+
; CHECK-LABEL: @or_zext_v4i8_to_v4i32_constant_with_loss(
313+
; CHECK-NEXT: [[Z1:%.*]] = zext <4 x i8> [[A:%.*]] to <4 x i32>
314+
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 1024, i32 129, i32 3, i32 4>
315+
; CHECK-NEXT: ret <4 x i32> [[OR]]
316+
;
317+
%z1 = zext <4 x i8> %a to <4 x i32>
318+
%or = or <4 x i32> %z1, <i32 1024, i32 129, i32 3, i32 4>
319+
ret <4 x i32> %or
320+
}
321+
322+
; Test bitwise operations with sign extend and one constant (currently not folded)
323+
define <4 x i32> @or_sext_v4i8_to_v4i32_positive_constant(<4 x i8> %a) {
324+
; CHECK-LABEL: @or_sext_v4i8_to_v4i32_positive_constant(
325+
; CHECK-NEXT: [[S1:%.*]] = sext <4 x i8> [[A:%.*]] to <4 x i32>
326+
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[S1]], <i32 1, i32 2, i32 3, i32 4>
327+
; CHECK-NEXT: ret <4 x i32> [[OR]]
328+
;
329+
%s1 = sext <4 x i8> %a to <4 x i32>
330+
%or = or <4 x i32> %s1, <i32 1, i32 2, i32 3, i32 4>
331+
ret <4 x i32> %or
332+
}
333+
334+
define <4 x i32> @or_sext_v4i8_to_v4i32_minus_constant(<4 x i8> %a) {
335+
; CHECK-LABEL: @or_sext_v4i8_to_v4i32_minus_constant(
336+
; CHECK-NEXT: [[S1:%.*]] = sext <4 x i8> [[A:%.*]] to <4 x i32>
337+
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[S1]], <i32 -1, i32 -2, i32 -3, i32 -4>
338+
; CHECK-NEXT: ret <4 x i32> [[OR]]
339+
;
340+
%s1 = sext <4 x i8> %a to <4 x i32>
341+
%or = or <4 x i32> %s1, <i32 -1, i32 -2, i32 -3, i32 -4>
342+
ret <4 x i32> %or
343+
}
344+
345+
define <4 x i32> @or_sext_v4i8_to_v4i32_constant_with_loss(<4 x i8> %a) {
346+
; CHECK-LABEL: @or_sext_v4i8_to_v4i32_constant_with_loss(
347+
; CHECK-NEXT: [[Z1:%.*]] = sext <4 x i8> [[A:%.*]] to <4 x i32>
348+
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 -10000, i32 2, i32 3, i32 4>
349+
; CHECK-NEXT: ret <4 x i32> [[OR]]
350+
;
351+
%z1 = sext <4 x i8> %a to <4 x i32>
352+
%or = or <4 x i32> %z1, <i32 -10000, i32 2, i32 3, i32 4>
353+
ret <4 x i32> %or
354+
}
355+
356+
; Test truncate with nuw/nsw flags and one constant (currently not folded; the flags must be kept)
357+
define <4 x i16> @and_trunc_nuw_nsw_constant(<4 x i32> %a) {
358+
; CHECK-LABEL: @and_trunc_nuw_nsw_constant(
359+
; CHECK-NEXT: [[T1:%.*]] = trunc nuw nsw <4 x i32> [[A:%.*]] to <4 x i16>
360+
; CHECK-NEXT: [[AND:%.*]] = and <4 x i16> [[T1]], <i16 1, i16 2, i16 3, i16 4>
361+
; CHECK-NEXT: ret <4 x i16> [[AND]]
362+
;
363+
%t1 = trunc nuw nsw <4 x i32> %a to <4 x i16>
364+
%and = and <4 x i16> %t1, <i16 1, i16 2, i16 3, i16 4>
365+
ret <4 x i16> %and
366+
}
367+
368+
define <4 x i8> @and_trunc_nuw_nsw_minus_constant(<4 x i32> %a) {
369+
; CHECK-LABEL: @and_trunc_nuw_nsw_minus_constant(
370+
; CHECK-NEXT: [[T1:%.*]] = trunc nuw nsw <4 x i32> [[A:%.*]] to <4 x i8>
371+
; CHECK-NEXT: [[AND:%.*]] = and <4 x i8> [[T1]], <i8 -16, i8 -15, i8 -14, i8 -13>
372+
; CHECK-NEXT: ret <4 x i8> [[AND]]
373+
;
374+
%t1 = trunc nuw nsw <4 x i32> %a to <4 x i8>
375+
%and = and <4 x i8> %t1, <i8 240, i8 241, i8 242, i8 243>
376+
ret <4 x i8> %and
377+
}
378+
379+
define <4 x i8> @and_trunc_nuw_nsw_multiconstant(<4 x i32> %a) {
380+
; CHECK-LABEL: @and_trunc_nuw_nsw_multiconstant(
381+
; CHECK-NEXT: [[T1:%.*]] = trunc nuw nsw <4 x i32> [[A:%.*]] to <4 x i8>
382+
; CHECK-NEXT: [[AND:%.*]] = and <4 x i8> [[T1]], <i8 -16, i8 1, i8 -14, i8 3>
383+
; CHECK-NEXT: ret <4 x i8> [[AND]]
384+
;
385+
%t1 = trunc nuw nsw <4 x i32> %a to <4 x i8>
386+
%and = and <4 x i8> %t1, <i8 240, i8 1, i8 242, i8 3>
387+
ret <4 x i8> %and
388+
}
389+
390+
; Test zero extend with nneg flag and one constant (currently not folded; the flag must be kept)
391+
define <4 x i32> @or_zext_nneg_constant(<4 x i16> %a) {
392+
; CHECK-LABEL: @or_zext_nneg_constant(
393+
; CHECK-NEXT: [[Z1:%.*]] = zext nneg <4 x i16> [[A:%.*]] to <4 x i32>
394+
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 1, i32 2, i32 3, i32 4>
395+
; CHECK-NEXT: ret <4 x i32> [[OR]]
396+
;
397+
%z1 = zext nneg <4 x i16> %a to <4 x i32>
398+
%or = or <4 x i32> %z1, <i32 1, i32 2, i32 3, i32 4>
399+
ret <4 x i32> %or
400+
}
401+
402+
define <4 x i32> @or_zext_nneg_minus_constant(<4 x i8> %a) {
403+
; CHECK-LABEL: @or_zext_nneg_minus_constant(
404+
; CHECK-NEXT: [[Z1:%.*]] = zext nneg <4 x i8> [[A:%.*]] to <4 x i32>
405+
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 240, i32 241, i32 242, i32 243>
406+
; CHECK-NEXT: ret <4 x i32> [[OR]]
407+
;
408+
%z1 = zext nneg <4 x i8> %a to <4 x i32>
409+
%or = or <4 x i32> %z1, <i32 240, i32 241, i32 242, i32 243>
410+
ret <4 x i32> %or
411+
}
412+
413+
define <4 x i32> @or_zext_nneg_multiconstant(<4 x i8> %a) {
414+
; CHECK-LABEL: @or_zext_nneg_multiconstant(
415+
; CHECK-NEXT: [[Z1:%.*]] = zext nneg <4 x i8> [[A:%.*]] to <4 x i32>
416+
; CHECK-NEXT: [[OR:%.*]] = or <4 x i32> [[Z1]], <i32 240, i32 1, i32 242, i32 3>
417+
; CHECK-NEXT: ret <4 x i32> [[OR]]
418+
;
419+
%z1 = zext nneg <4 x i8> %a to <4 x i32>
420+
%or = or <4 x i32> %z1, <i32 240, i32 1, i32 242, i32 3>
421+
ret <4 x i32> %or
422+
}

0 commit comments

Comments
 (0)