Skip to content

Commit ddaa2c3

Browse files
authored
[Clang] Generalize interp__builtin_ia32_shuffle_generic to handle single op permute shuffles. (llvm#167236)
This patch extends `interp__builtin_ia32_shuffle_generic` and `evalShuffleGeneric` to handle both 2-argument and 3-argument patterns, replacing specialized shuffle functions with the unified handler. Resolves llvm#166342
1 parent 125b6b5 commit ddaa2c3

File tree

2 files changed

+211
-295
lines changed

2 files changed

+211
-295
lines changed

clang/lib/AST/ByteCode/InterpBuiltin.cpp

Lines changed: 91 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -2841,76 +2841,6 @@ static bool interp__builtin_blend(InterpState &S, CodePtr OpPC,
28412841
return true;
28422842
}
28432843

2844-
static bool interp__builtin_ia32_pshufb(InterpState &S, CodePtr OpPC,
2845-
const CallExpr *Call) {
2846-
assert(Call->getNumArgs() == 2 && "masked forms handled via select*");
2847-
const Pointer &Control = S.Stk.pop<Pointer>();
2848-
const Pointer &Src = S.Stk.pop<Pointer>();
2849-
const Pointer &Dst = S.Stk.peek<Pointer>();
2850-
2851-
unsigned NumElems = Dst.getNumElems();
2852-
assert(NumElems == Control.getNumElems());
2853-
assert(NumElems == Dst.getNumElems());
2854-
2855-
for (unsigned Idx = 0; Idx != NumElems; ++Idx) {
2856-
uint8_t Ctlb = static_cast<uint8_t>(Control.elem<int8_t>(Idx));
2857-
2858-
if (Ctlb & 0x80) {
2859-
Dst.elem<int8_t>(Idx) = 0;
2860-
} else {
2861-
unsigned LaneBase = (Idx / 16) * 16;
2862-
unsigned SrcOffset = Ctlb & 0x0F;
2863-
unsigned SrcIdx = LaneBase + SrcOffset;
2864-
2865-
Dst.elem<int8_t>(Idx) = Src.elem<int8_t>(SrcIdx);
2866-
}
2867-
}
2868-
Dst.initializeAllElements();
2869-
return true;
2870-
}
2871-
2872-
static bool interp__builtin_ia32_pshuf(InterpState &S, CodePtr OpPC,
2873-
const CallExpr *Call, bool IsShufHW) {
2874-
assert(Call->getNumArgs() == 2 && "masked forms handled via select*");
2875-
APSInt ControlImm = popToAPSInt(S, Call->getArg(1));
2876-
const Pointer &Src = S.Stk.pop<Pointer>();
2877-
const Pointer &Dst = S.Stk.peek<Pointer>();
2878-
2879-
unsigned NumElems = Dst.getNumElems();
2880-
PrimType ElemT = Dst.getFieldDesc()->getPrimType();
2881-
2882-
unsigned ElemBits = static_cast<unsigned>(primSize(ElemT) * 8);
2883-
if (ElemBits != 16 && ElemBits != 32)
2884-
return false;
2885-
2886-
unsigned LaneElts = 128u / ElemBits;
2887-
assert(LaneElts && (NumElems % LaneElts == 0));
2888-
2889-
uint8_t Ctl = static_cast<uint8_t>(ControlImm.getZExtValue());
2890-
2891-
for (unsigned Idx = 0; Idx != NumElems; Idx++) {
2892-
unsigned LaneBase = (Idx / LaneElts) * LaneElts;
2893-
unsigned LaneIdx = Idx % LaneElts;
2894-
unsigned SrcIdx = Idx;
2895-
unsigned Sel = (Ctl >> (2 * (LaneIdx & 0x3))) & 0x3;
2896-
if (ElemBits == 32) {
2897-
SrcIdx = LaneBase + Sel;
2898-
} else {
2899-
constexpr unsigned HalfSize = 4;
2900-
bool InHigh = LaneIdx >= HalfSize;
2901-
if (!IsShufHW && !InHigh) {
2902-
SrcIdx = LaneBase + Sel;
2903-
} else if (IsShufHW && InHigh) {
2904-
SrcIdx = LaneBase + HalfSize + Sel;
2905-
}
2906-
}
2907-
2908-
INT_TYPE_SWITCH_NO_BOOL(ElemT, { Dst.elem<T>(Idx) = Src.elem<T>(SrcIdx); });
2909-
}
2910-
Dst.initializeAllElements();
2911-
return true;
2912-
}
2913-
29142844
static bool interp__builtin_ia32_test_op(
29152845
InterpState &S, CodePtr OpPC, const CallExpr *Call,
29162846
llvm::function_ref<bool(const APInt &A, const APInt &B)> Fn) {
@@ -3377,61 +3307,46 @@ static bool interp__builtin_ia32_vpconflict(InterpState &S, CodePtr OpPC,
33773307
return true;
33783308
}
33793309

3380-
static bool interp__builtin_x86_byteshift(
3381-
InterpState &S, CodePtr OpPC, const CallExpr *Call, unsigned ID,
3382-
llvm::function_ref<APInt(const Pointer &, unsigned Lane, unsigned I,
3383-
unsigned Shift)>
3384-
Fn) {
3385-
assert(Call->getNumArgs() == 2);
3386-
3387-
APSInt ImmAPS = popToAPSInt(S, Call->getArg(1));
3388-
uint64_t Shift = ImmAPS.getZExtValue() & 0xff;
3389-
3390-
const Pointer &Src = S.Stk.pop<Pointer>();
3391-
if (!Src.getFieldDesc()->isPrimitiveArray())
3392-
return false;
3393-
3394-
unsigned NumElems = Src.getNumElems();
3395-
const Pointer &Dst = S.Stk.peek<Pointer>();
3396-
PrimType ElemT = Src.getFieldDesc()->getPrimType();
3397-
3398-
for (unsigned Lane = 0; Lane != NumElems; Lane += 16) {
3399-
for (unsigned I = 0; I != 16; ++I) {
3400-
unsigned Base = Lane + I;
3401-
APSInt Result = APSInt(Fn(Src, Lane, I, Shift));
3402-
INT_TYPE_SWITCH_NO_BOOL(ElemT,
3403-
{ Dst.elem<T>(Base) = static_cast<T>(Result); });
3404-
}
3405-
}
3406-
3407-
Dst.initializeAllElements();
3408-
3409-
return true;
3410-
}
3411-
34123310
static bool interp__builtin_ia32_shuffle_generic(
34133311
InterpState &S, CodePtr OpPC, const CallExpr *Call,
34143312
llvm::function_ref<std::pair<unsigned, int>(unsigned, unsigned)>
34153313
GetSourceIndex) {
34163314

3417-
assert(Call->getNumArgs() == 3);
3315+
assert(Call->getNumArgs() == 2 || Call->getNumArgs() == 3);
34183316

34193317
unsigned ShuffleMask = 0;
34203318
Pointer A, MaskVector, B;
3421-
3422-
QualType Arg2Type = Call->getArg(2)->getType();
34233319
bool IsVectorMask = false;
3424-
if (Arg2Type->isVectorType()) {
3425-
IsVectorMask = true;
3426-
B = S.Stk.pop<Pointer>();
3427-
MaskVector = S.Stk.pop<Pointer>();
3428-
A = S.Stk.pop<Pointer>();
3429-
} else if (Arg2Type->isIntegerType()) {
3430-
ShuffleMask = popToAPSInt(S, Call->getArg(2)).getZExtValue();
3431-
B = S.Stk.pop<Pointer>();
3432-
A = S.Stk.pop<Pointer>();
3320+
bool IsSingleOperand = (Call->getNumArgs() == 2);
3321+
3322+
if (IsSingleOperand) {
3323+
QualType MaskType = Call->getArg(1)->getType();
3324+
if (MaskType->isVectorType()) {
3325+
IsVectorMask = true;
3326+
MaskVector = S.Stk.pop<Pointer>();
3327+
A = S.Stk.pop<Pointer>();
3328+
B = A;
3329+
} else if (MaskType->isIntegerType()) {
3330+
ShuffleMask = popToAPSInt(S, Call->getArg(1)).getZExtValue();
3331+
A = S.Stk.pop<Pointer>();
3332+
B = A;
3333+
} else {
3334+
return false;
3335+
}
34333336
} else {
3434-
return false;
3337+
QualType Arg2Type = Call->getArg(2)->getType();
3338+
if (Arg2Type->isVectorType()) {
3339+
IsVectorMask = true;
3340+
B = S.Stk.pop<Pointer>();
3341+
MaskVector = S.Stk.pop<Pointer>();
3342+
A = S.Stk.pop<Pointer>();
3343+
} else if (Arg2Type->isIntegerType()) {
3344+
ShuffleMask = popToAPSInt(S, Call->getArg(2)).getZExtValue();
3345+
B = S.Stk.pop<Pointer>();
3346+
A = S.Stk.pop<Pointer>();
3347+
} else {
3348+
return false;
3349+
}
34353350
}
34363351

34373352
QualType Arg0Type = Call->getArg(0)->getType();
@@ -3455,6 +3370,7 @@ static bool interp__builtin_ia32_shuffle_generic(
34553370
ShuffleMask = static_cast<unsigned>(MaskVector.elem<T>(DstIdx));
34563371
});
34573372
}
3373+
34583374
auto [SrcVecIdx, SrcIdx] = GetSourceIndex(DstIdx, ShuffleMask);
34593375

34603376
if (SrcIdx < 0) {
@@ -4555,22 +4471,58 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
45554471
case X86::BI__builtin_ia32_pshufb128:
45564472
case X86::BI__builtin_ia32_pshufb256:
45574473
case X86::BI__builtin_ia32_pshufb512:
4558-
return interp__builtin_ia32_pshufb(S, OpPC, Call);
4474+
return interp__builtin_ia32_shuffle_generic(
4475+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4476+
uint8_t Ctlb = static_cast<uint8_t>(ShuffleMask);
4477+
if (Ctlb & 0x80)
4478+
return std::make_pair(0, -1);
4479+
4480+
unsigned LaneBase = (DstIdx / 16) * 16;
4481+
unsigned SrcOffset = Ctlb & 0x0F;
4482+
unsigned SrcIdx = LaneBase + SrcOffset;
4483+
return std::make_pair(0, static_cast<int>(SrcIdx));
4484+
});
45594485

45604486
case X86::BI__builtin_ia32_pshuflw:
45614487
case X86::BI__builtin_ia32_pshuflw256:
45624488
case X86::BI__builtin_ia32_pshuflw512:
4563-
return interp__builtin_ia32_pshuf(S, OpPC, Call, false);
4489+
return interp__builtin_ia32_shuffle_generic(
4490+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4491+
unsigned LaneBase = (DstIdx / 8) * 8;
4492+
unsigned LaneIdx = DstIdx % 8;
4493+
if (LaneIdx < 4) {
4494+
unsigned Sel = (ShuffleMask >> (2 * LaneIdx)) & 0x3;
4495+
return std::make_pair(0, static_cast<int>(LaneBase + Sel));
4496+
}
4497+
4498+
return std::make_pair(0, static_cast<int>(DstIdx));
4499+
});
45644500

45654501
case X86::BI__builtin_ia32_pshufhw:
45664502
case X86::BI__builtin_ia32_pshufhw256:
45674503
case X86::BI__builtin_ia32_pshufhw512:
4568-
return interp__builtin_ia32_pshuf(S, OpPC, Call, true);
4504+
return interp__builtin_ia32_shuffle_generic(
4505+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4506+
unsigned LaneBase = (DstIdx / 8) * 8;
4507+
unsigned LaneIdx = DstIdx % 8;
4508+
if (LaneIdx >= 4) {
4509+
unsigned Sel = (ShuffleMask >> (2 * (LaneIdx - 4))) & 0x3;
4510+
return std::make_pair(0, static_cast<int>(LaneBase + 4 + Sel));
4511+
}
4512+
4513+
return std::make_pair(0, static_cast<int>(DstIdx));
4514+
});
45694515

45704516
case X86::BI__builtin_ia32_pshufd:
45714517
case X86::BI__builtin_ia32_pshufd256:
45724518
case X86::BI__builtin_ia32_pshufd512:
4573-
return interp__builtin_ia32_pshuf(S, OpPC, Call, false);
4519+
return interp__builtin_ia32_shuffle_generic(
4520+
S, OpPC, Call, [](unsigned DstIdx, unsigned ShuffleMask) {
4521+
unsigned LaneBase = (DstIdx / 4) * 4;
4522+
unsigned LaneIdx = DstIdx % 4;
4523+
unsigned Sel = (ShuffleMask >> (2 * LaneIdx)) & 0x3;
4524+
return std::make_pair(0, static_cast<int>(LaneBase + Sel));
4525+
});
45744526

45754527
case X86::BI__builtin_ia32_kandqi:
45764528
case X86::BI__builtin_ia32_kandhi:
@@ -4728,13 +4680,16 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
47284680
// The lane width is hardcoded to 16 to match the SIMD register size,
47294681
// but the algorithm processes one byte per iteration,
47304682
// so APInt(8, ...) is correct and intentional.
4731-
return interp__builtin_x86_byteshift(
4732-
S, OpPC, Call, BuiltinID,
4733-
[](const Pointer &Src, unsigned Lane, unsigned I, unsigned Shift) {
4734-
if (I < Shift) {
4735-
return APInt(8, 0);
4736-
}
4737-
return APInt(8, Src.elem<uint8_t>(Lane + I - Shift));
4683+
return interp__builtin_ia32_shuffle_generic(
4684+
S, OpPC, Call,
4685+
[](unsigned DstIdx, unsigned Shift) -> std::pair<unsigned, int> {
4686+
unsigned LaneBase = (DstIdx / 16) * 16;
4687+
unsigned LaneIdx = DstIdx % 16;
4688+
if (LaneIdx < Shift)
4689+
return std::make_pair(0, -1);
4690+
4691+
return std::make_pair(0,
4692+
static_cast<int>(LaneBase + LaneIdx - Shift));
47384693
});
47394694

47404695
case X86::BI__builtin_ia32_psrldqi128_byteshift:
@@ -4744,14 +4699,16 @@ bool InterpretBuiltin(InterpState &S, CodePtr OpPC, const CallExpr *Call,
47444699
// The lane width is hardcoded to 16 to match the SIMD register size,
47454700
// but the algorithm processes one byte per iteration,
47464701
// so APInt(8, ...) is correct and intentional.
4747-
return interp__builtin_x86_byteshift(
4748-
S, OpPC, Call, BuiltinID,
4749-
[](const Pointer &Src, unsigned Lane, unsigned I, unsigned Shift) {
4750-
if (I + Shift < 16) {
4751-
return APInt(8, Src.elem<uint8_t>(Lane + I + Shift));
4752-
}
4753-
4754-
return APInt(8, 0);
4702+
return interp__builtin_ia32_shuffle_generic(
4703+
S, OpPC, Call,
4704+
[](unsigned DstIdx, unsigned Shift) -> std::pair<unsigned, int> {
4705+
unsigned LaneBase = (DstIdx / 16) * 16;
4706+
unsigned LaneIdx = DstIdx % 16;
4707+
if (LaneIdx + Shift < 16)
4708+
return std::make_pair(0,
4709+
static_cast<int>(LaneBase + LaneIdx + Shift));
4710+
4711+
return std::make_pair(0, -1);
47554712
});
47564713

47574714
case X86::BI__builtin_ia32_palignr128:

0 commit comments

Comments
 (0)