Skip to content

Commit 0ade260

Browse files
authored
[DAG] visitBITCAST - fold (bitcast (freeze (load x))) -> (freeze (load (bitcast*)x)) (#164618)
Tweak the existing (bitcast (load x)) -> (load (bitcast*)x) fold to handle oneuse freeze as well Inspired by #163070 - this tries to avoid in place replacement of frozen nodes which has caused infinite loops in the past
1 parent 9b114c5 commit 0ade260

File tree

5 files changed

+128
-138
lines changed

5 files changed

+128
-138
lines changed

llvm/lib/CodeGen/SelectionDAG/DAGCombiner.cpp

Lines changed: 41 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16736,38 +16736,51 @@ SDValue DAGCombiner::visitBITCAST(SDNode *N) {
1673616736
}
1673716737

1673816738
// fold (conv (load x)) -> (load (conv*)x)
16739+
// fold (conv (freeze (load x))) -> (freeze (load (conv*)x))
1673916740
// If the resultant load doesn't need a higher alignment than the original!
16740-
if (ISD::isNormalLoad(N0.getNode()) && N0.hasOneUse() &&
16741-
// Do not remove the cast if the types differ in endian layout.
16742-
TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) ==
16743-
TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()) &&
16744-
// If the load is volatile, we only want to change the load type if the
16745-
// resulting load is legal. Otherwise we might increase the number of
16746-
// memory accesses. We don't care if the original type was legal or not
16747-
// as we assume software couldn't rely on the number of accesses of an
16748-
// illegal type.
16749-
((!LegalOperations && cast<LoadSDNode>(N0)->isSimple()) ||
16750-
TLI.isOperationLegal(ISD::LOAD, VT))) {
16751-
LoadSDNode *LN0 = cast<LoadSDNode>(N0);
16741+
auto CastLoad = [this, &VT](SDValue N0, const SDLoc &DL) {
16742+
if (!ISD::isNormalLoad(N0.getNode()) || !N0.hasOneUse())
16743+
return SDValue();
1675216744

16753-
if (TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
16754-
*LN0->getMemOperand())) {
16755-
// If the range metadata type does not match the new memory
16756-
// operation type, remove the range metadata.
16757-
if (const MDNode *MD = LN0->getRanges()) {
16758-
ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
16759-
if (Lower->getBitWidth() != VT.getScalarSizeInBits() ||
16760-
!VT.isInteger()) {
16761-
LN0->getMemOperand()->clearRanges();
16762-
}
16745+
// Do not remove the cast if the types differ in endian layout.
16746+
if (TLI.hasBigEndianPartOrdering(N0.getValueType(), DAG.getDataLayout()) !=
16747+
TLI.hasBigEndianPartOrdering(VT, DAG.getDataLayout()))
16748+
return SDValue();
16749+
16750+
// If the load is volatile, we only want to change the load type if the
16751+
// resulting load is legal. Otherwise we might increase the number of
16752+
// memory accesses. We don't care if the original type was legal or not
16753+
// as we assume software couldn't rely on the number of accesses of an
16754+
// illegal type.
16755+
auto *LN0 = cast<LoadSDNode>(N0);
16756+
if ((LegalOperations || !LN0->isSimple()) &&
16757+
!TLI.isOperationLegal(ISD::LOAD, VT))
16758+
return SDValue();
16759+
16760+
if (!TLI.isLoadBitCastBeneficial(N0.getValueType(), VT, DAG,
16761+
*LN0->getMemOperand()))
16762+
return SDValue();
16763+
16764+
// If the range metadata type does not match the new memory
16765+
// operation type, remove the range metadata.
16766+
if (const MDNode *MD = LN0->getRanges()) {
16767+
ConstantInt *Lower = mdconst::extract<ConstantInt>(MD->getOperand(0));
16768+
if (Lower->getBitWidth() != VT.getScalarSizeInBits() || !VT.isInteger()) {
16769+
LN0->getMemOperand()->clearRanges();
1676316770
}
16764-
SDValue Load =
16765-
DAG.getLoad(VT, SDLoc(N), LN0->getChain(), LN0->getBasePtr(),
16766-
LN0->getMemOperand());
16767-
DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
16768-
return Load;
1676916771
}
16770-
}
16772+
SDValue Load = DAG.getLoad(VT, DL, LN0->getChain(), LN0->getBasePtr(),
16773+
LN0->getMemOperand());
16774+
DAG.ReplaceAllUsesOfValueWith(N0.getValue(1), Load.getValue(1));
16775+
return Load;
16776+
};
16777+
16778+
if (SDValue NewLd = CastLoad(N0, SDLoc(N)))
16779+
return NewLd;
16780+
16781+
if (N0.getOpcode() == ISD::FREEZE && N0.hasOneUse())
16782+
if (SDValue NewLd = CastLoad(N0.getOperand(0), SDLoc(N)))
16783+
return DAG.getFreeze(NewLd);
1677116784

1677216785
if (SDValue V = foldBitcastedFPLogic(N, DAG, TLI))
1677316786
return V;

llvm/lib/Target/X86/X86ISelLowering.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3454,6 +3454,12 @@ bool X86TargetLowering::isLoadBitCastBeneficial(EVT LoadVT, EVT BitcastVT,
34543454
isTypeLegal(LoadVT) && isTypeLegal(BitcastVT))
34553455
return true;
34563456

3457+
// If we have a large vector type (even if illegal), don't bitcast to large
3458+
// (illegal) scalar types. Better to load fewer vectors and extract.
3459+
if (LoadVT.isVector() && !BitcastVT.isVector() && LoadVT.isInteger() &&
3460+
BitcastVT.isInteger() && (LoadVT.getSizeInBits() % 128) == 0)
3461+
return false;
3462+
34573463
return TargetLowering::isLoadBitCastBeneficial(LoadVT, BitcastVT, DAG, MMO);
34583464
}
34593465

llvm/test/CodeGen/X86/avx10_2_512bf16-arith.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ define <32 x bfloat> @test_int_x86_avx10_maskz_sub_bf16_512(<32 x bfloat> %src,
9494
;
9595
; X86-LABEL: test_int_x86_avx10_maskz_sub_bf16_512:
9696
; X86: # %bb.0:
97-
; X86-NEXT: kmovd {{[0-9]+}}(%esp), %k1 # encoding: [0xc4,0xe1,0xf9,0x90,0x4c,0x24,0x04]
9897
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax # encoding: [0x8b,0x44,0x24,0x08]
98+
; X86-NEXT: kmovd {{[0-9]+}}(%esp), %k1 # encoding: [0xc4,0xe1,0xf9,0x90,0x4c,0x24,0x04]
9999
; X86-NEXT: vsubbf16 %zmm2, %zmm1, %zmm0 {%k1} {z} # encoding: [0x62,0xf5,0x75,0xc9,0x5c,0xc2]
100100
; X86-NEXT: vsubbf16 (%eax), %zmm1, %zmm1 # encoding: [0x62,0xf5,0x75,0x48,0x5c,0x08]
101101
; X86-NEXT: vsubbf16 %zmm1, %zmm0, %zmm0 {%k1} # encoding: [0x62,0xf5,0x7d,0x49,0x5c,0xc1]

llvm/test/CodeGen/X86/avx10_2bf16-arith.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -147,8 +147,8 @@ define <16 x bfloat> @test_int_x86_avx10_maskz_sub_bf16_256(<16 x bfloat> %src,
147147
;
148148
; X86-LABEL: test_int_x86_avx10_maskz_sub_bf16_256:
149149
; X86: # %bb.0:
150-
; X86-NEXT: kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
151150
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax # encoding: [0x8b,0x44,0x24,0x08]
151+
; X86-NEXT: kmovw {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf8,0x90,0x4c,0x24,0x04]
152152
; X86-NEXT: vsubbf16 %ymm2, %ymm1, %ymm0 {%k1} {z} # encoding: [0x62,0xf5,0x75,0xa9,0x5c,0xc2]
153153
; X86-NEXT: vsubbf16 (%eax), %ymm1, %ymm1 # encoding: [0x62,0xf5,0x75,0x28,0x5c,0x08]
154154
; X86-NEXT: vsubbf16 %ymm1, %ymm0, %ymm0 {%k1} # encoding: [0x62,0xf5,0x7d,0x29,0x5c,0xc1]
@@ -201,8 +201,8 @@ define <8 x bfloat> @test_int_x86_avx10_maskz_sub_bf16_128(<8 x bfloat> %src, <8
201201
;
202202
; X86-LABEL: test_int_x86_avx10_maskz_sub_bf16_128:
203203
; X86: # %bb.0:
204-
; X86-NEXT: kmovb {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf9,0x90,0x4c,0x24,0x04]
205204
; X86-NEXT: movl {{[0-9]+}}(%esp), %eax # encoding: [0x8b,0x44,0x24,0x08]
205+
; X86-NEXT: kmovb {{[0-9]+}}(%esp), %k1 # encoding: [0xc5,0xf9,0x90,0x4c,0x24,0x04]
206206
; X86-NEXT: vsubbf16 %xmm2, %xmm1, %xmm0 {%k1} {z} # encoding: [0x62,0xf5,0x75,0x89,0x5c,0xc2]
207207
; X86-NEXT: vsubbf16 (%eax), %xmm1, %xmm1 # encoding: [0x62,0xf5,0x75,0x08,0x5c,0x08]
208208
; X86-NEXT: vsubbf16 %xmm1, %xmm0, %xmm0 {%k1} # encoding: [0x62,0xf5,0x7d,0x09,0x5c,0xc1]

0 commit comments

Comments
 (0)