Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions llvm/lib/Target/X86/X86ISelLowering.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -54634,6 +54634,7 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
const X86Subtarget &Subtarget) {
EVT VT = N->getValueType(0);
SDValue Src = N->getOperand(0);
EVT SrcVT = Src.getValueType();
SDLoc DL(N);

// Attempt to pre-truncate inputs to arithmetic ops instead.
Expand All @@ -54652,6 +54653,39 @@ static SDValue combineTruncate(SDNode *N, SelectionDAG &DAG,
if (SDValue V = combinePMULH(Src, VT, DL, DAG, Subtarget))
return V;

// Fold trunc(srl(load(p),amt) -> load(p+amt/8)
// If we're shifting down whole byte+pow2 aligned bit chunks from a larger
// load for truncation, see if we can convert the shift into a pointer
// offset instead. Limit this to normal (non-ext) scalar integer loads.
if (SrcVT.isScalarInteger() && Src.getOpcode() == ISD::SRL &&
Src.hasOneUse() && Src.getOperand(0).hasOneUse() &&
ISD::isNormalLoad(Src.getOperand(0).getNode())) {
auto *Ld = cast<LoadSDNode>(Src.getOperand(0));
if (Ld->isSimple() && VT.isByteSized() &&
isPowerOf2_64(VT.getSizeInBits())) {
SDValue ShAmt = Src.getOperand(1);
Comment on lines +54663 to +54666
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we check ShAmt is octuple?

Copy link
Collaborator Author

@RKSimon RKSimon Oct 28, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What did you have in mind? We check that that VT is byte sized (multiple of 8 bits) and that its pow2 - then check ShAmt is zero in the lowest bits matching the alignment of VT.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could relax this to just check that ShAmt byte aligned:

KnownAmt.countMinTrailingZeros() >= 3

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I get it now. It's less clear than >= 3 :)

KnownBits KnownAmt = DAG.computeKnownBits(ShAmt);
// Check the shift amount is aligned to the truncated size.
// Check the truncation doesn't use any shifted in (zero) top bits.
if (KnownAmt.countMinTrailingZeros() >= Log2_64(VT.getSizeInBits()) &&
KnownAmt.getMaxValue().ule(SrcVT.getSizeInBits() -
VT.getSizeInBits())) {
EVT PtrVT = Ld->getBasePtr().getValueType();
SDValue PtrBitOfs = DAG.getZExtOrTrunc(ShAmt, DL, PtrVT);
SDValue PtrByteOfs =
DAG.getNode(ISD::SRL, DL, PtrVT, PtrBitOfs,
DAG.getShiftAmountConstant(3, PtrVT, DL));
SDValue NewPtr = DAG.getMemBasePlusOffset(
Ld->getBasePtr(), PtrByteOfs, DL, SDNodeFlags::NoUnsignedWrap);
SDValue NewLoad =
DAG.getLoad(VT, DL, Ld->getChain(), NewPtr, Ld->getMemOperand());
DAG.ReplaceAllUsesOfValueWith(Src.getOperand(0).getValue(1),
NewLoad.getValue(1));
return NewLoad;
}
}
}

// The bitcast source is a direct mmx result.
// Detect bitcasts between i32 to x86mmx
if (Src.getOpcode() == ISD::BITCAST && VT == MVT::i32) {
Expand Down
6 changes: 2 additions & 4 deletions llvm/test/CodeGen/X86/bfloat-calling-conv.ll
Original file line number Diff line number Diff line change
Expand Up @@ -660,8 +660,7 @@ define <3 x bfloat> @call_ret_v3bf16(ptr %ptr) #0 {
; SSE2-LABEL: call_ret_v3bf16:
; SSE2: # %bb.0:
; SSE2-NEXT: pushq %rax
; SSE2-NEXT: movl 4(%rdi), %eax
; SSE2-NEXT: pinsrw $0, %eax, %xmm1
; SSE2-NEXT: pinsrw $0, 4(%rdi), %xmm1
; SSE2-NEXT: movd {{.*#+}} xmm0 = mem[0],zero,zero,zero
; SSE2-NEXT: punpckldq {{.*#+}} xmm0 = xmm0[0],xmm1[0],xmm0[1],xmm1[1]
; SSE2-NEXT: callq returns_v3bf16@PLT
Expand Down Expand Up @@ -725,8 +724,7 @@ define <3 x bfloat> @call_ret_v3bf16(ptr %ptr) #0 {
; AVXNECONVERT-LABEL: call_ret_v3bf16:
; AVXNECONVERT: # %bb.0:
; AVXNECONVERT-NEXT: pushq %rax
; AVXNECONVERT-NEXT: movl 4(%rdi), %eax
; AVXNECONVERT-NEXT: vpinsrw $0, %eax, %xmm0, %xmm0
; AVXNECONVERT-NEXT: vpinsrw $0, 4(%rdi), %xmm0, %xmm0
; AVXNECONVERT-NEXT: vmovss {{.*#+}} xmm1 = mem[0],zero,zero,zero
; AVXNECONVERT-NEXT: vinsertps {{.*#+}} xmm0 = xmm1[0],xmm0[0],zero,zero
; AVXNECONVERT-NEXT: callq returns_v3bf16@PLT
Expand Down
Loading