Skip to content

Commit 19e1011

Browse files
SamTebbs33 and MacDue authored
[SelectionDAG] Fix unsafe cases for loop.dependence.{war/raw}.mask (#168565)
Both `LOOP_DEPENDENCE_WAR_MASK` and `LOOP_DEPENDENCE_RAW_MASK` are currently hard to split correctly, and there are a number of incorrect cases.

The difficulty comes from how the intrinsics are defined. For example, take `LOOP_DEPENDENCE_WAR_MASK`. It is defined as the OR of:

* `(ptrB - ptrA) <= 0`
* `elementSize * lane < (ptrB - ptrA)`

Now, if we want to split a loop dependence mask, then for the high half of the mask we want to compute the OR of:

* `(ptrB - ptrA) <= 0`
* `elementSize * (lane + LoVT.getElementCount()) < (ptrB - ptrA)`

However, with the current opcode definitions, we can only modify ptrA or ptrB, which may change the result of the first case, which should be invariant to the lane.

This patch resolves these cases by adding a "lane offset" to the ISD opcodes. The lane offset is always a constant. For scalable masks, it is implicitly multiplied by vscale. This makes splitting trivial, as we now simply increment the lane offset by `LoVT.getElementCount()`.

Note: In the AArch64 backend, we only support zero lane offsets (as other cases are tricky to lower to whilewr/rw).

---------

Co-authored-by: Benjamin Maxwell <[email protected]>
1 parent fbde1dc commit 19e1011

File tree

12 files changed

+267
-1134
lines changed

12 files changed

+267
-1134
lines changed

llvm/include/llvm/CodeGen/ISDOpcodes.h

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1569,8 +1569,21 @@ enum NodeType {
15691569
GET_ACTIVE_LANE_MASK,
15701570

15711571
// The `llvm.loop.dependence.{war, raw}.mask` intrinsics
1572-
// Operands: Load pointer, Store pointer, Element size
1572+
// Operands: Load pointer, Store pointer, Element size, Lane offset
15731573
// Output: Mask
1574+
//
1575+
// Note: The semantics of these opcodes differ slightly from the intrinsics.
1576+
// Wherever "lane" (meaning lane index) occurs in the intrinsic definition, it
1577+
// is replaced with (lane + lane_offset) for the ISD opcode.
1578+
//
1579+
// E.g., for LOOP_DEPENDENCE_WAR_MASK:
1580+
// `elementSize * lane < (ptrB - ptrA)`
1581+
// Becomes:
1582+
// `elementSize * (lane + lane_offset) < (ptrB - ptrA)`
1583+
//
1584+
// This is done to allow for trivial splitting of the operation. Note: The
1585+
// lane offset is always a constant; for scalable masks, it is implicitly
1586+
// multiplied by vscale.
15741587
LOOP_DEPENDENCE_WAR_MASK,
15751588
LOOP_DEPENDENCE_RAW_MASK,
15761589

llvm/include/llvm/Target/TargetSelectionDAG.td

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -347,6 +347,11 @@ def SDTAtomicLoad : SDTypeProfile<1, 1, [
347347
SDTCisPtrTy<1>
348348
]>;
349349

350+
def SDTLoopDepMask : SDTypeProfile<1, 4,
351+
[/*Result=*/SDTCisVec<0>, /*PtrA=*/SDTCisInt<1>, /*PtrB=*/SDTCisInt<2>,
352+
/*EltSizeInBytes=*/SDTCisInt<3>, /*LaneOffset=*/SDTCisInt<4>,
353+
SDTCisSameAs<2, 1>]>;
354+
350355
class SDCallSeqStart<list<SDTypeConstraint> constraints> :
351356
SDTypeProfile<0, 2, constraints>;
352357
class SDCallSeqEnd<list<SDTypeConstraint> constraints> :
@@ -839,10 +844,6 @@ def step_vector : SDNode<"ISD::STEP_VECTOR", SDTypeProfile<1, 1,
839844
[SDTCisVec<0>, SDTCisInt<1>]>, []>;
840845
def scalar_to_vector : SDNode<"ISD::SCALAR_TO_VECTOR", SDTypeProfile<1, 1, []>,
841846
[]>;
842-
843-
def SDTLoopDepMask : SDTypeProfile<1, 3, [SDTCisVec<0>, SDTCisInt<1>,
844-
SDTCisSameAs<2, 1>, SDTCisInt<3>,
845-
SDTCVecEltisVT<0,i1>]>;
846847
def loop_dependence_war_mask : SDNode<"ISD::LOOP_DEPENDENCE_WAR_MASK",
847848
SDTLoopDepMask, []>;
848849
def loop_dependence_raw_mask : SDNode<"ISD::LOOP_DEPENDENCE_RAW_MASK",

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorOps.cpp

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1807,46 +1807,41 @@ SDValue VectorLegalizer::ExpandVP_FCOPYSIGN(SDNode *Node) {
18071807

18081808
SDValue VectorLegalizer::ExpandLOOP_DEPENDENCE_MASK(SDNode *N) {
18091809
SDLoc DL(N);
1810+
EVT VT = N->getValueType(0);
18101811
SDValue SourceValue = N->getOperand(0);
18111812
SDValue SinkValue = N->getOperand(1);
1812-
SDValue EltSize = N->getOperand(2);
1813+
SDValue EltSizeInBytes = N->getOperand(2);
1814+
1815+
// Note: The lane offset is scalable if the mask is scalable.
1816+
ElementCount LaneOffsetEC =
1817+
ElementCount::get(N->getConstantOperandVal(3), VT.isScalableVT());
18131818

1814-
bool IsReadAfterWrite = N->getOpcode() == ISD::LOOP_DEPENDENCE_RAW_MASK;
1815-
EVT VT = N->getValueType(0);
18161819
EVT PtrVT = SourceValue->getValueType(0);
1820+
bool IsReadAfterWrite = N->getOpcode() == ISD::LOOP_DEPENDENCE_RAW_MASK;
18171821

1822+
// Take the difference between the pointers and divide by the element size,
1823+
// to see how many lanes separate them.
18181824
SDValue Diff = DAG.getNode(ISD::SUB, DL, PtrVT, SinkValue, SourceValue);
18191825
if (IsReadAfterWrite)
18201826
Diff = DAG.getNode(ISD::ABS, DL, PtrVT, Diff);
1827+
Diff = DAG.getNode(ISD::SDIV, DL, PtrVT, Diff, EltSizeInBytes);
18211828

1822-
Diff = DAG.getNode(ISD::SDIV, DL, PtrVT, Diff, EltSize);
1823-
1824-
// If the difference is positive then some elements may alias
1825-
EVT CmpVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
1826-
Diff.getValueType());
1829+
// The pointers do not alias if:
1830+
// * Diff <= 0 (WAR_MASK)
1831+
// * Diff == 0 (RAW_MASK)
1832+
EVT CmpVT =
1833+
TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(), PtrVT);
18271834
SDValue Zero = DAG.getConstant(0, DL, PtrVT);
18281835
SDValue Cmp = DAG.getSetCC(DL, CmpVT, Diff, Zero,
18291836
IsReadAfterWrite ? ISD::SETEQ : ISD::SETLE);
18301837

1831-
// Create the lane mask
1832-
EVT SplatVT = VT.changeElementType(PtrVT);
1833-
SDValue DiffSplat = DAG.getSplat(SplatVT, DL, Diff);
1834-
SDValue VectorStep = DAG.getStepVector(DL, SplatVT);
1835-
EVT MaskVT = VT.changeElementType(MVT::i1);
1836-
SDValue DiffMask =
1837-
DAG.getSetCC(DL, MaskVT, VectorStep, DiffSplat, ISD::CondCode::SETULT);
1838+
// The pointers do not alias if:
1839+
// Lane + LaneOffset < Diff (WAR/RAW_MASK)
1840+
SDValue LaneOffset = DAG.getElementCount(DL, PtrVT, LaneOffsetEC);
1841+
SDValue MaskN =
1842+
DAG.getSelect(DL, PtrVT, Cmp, DAG.getConstant(-1, DL, PtrVT), Diff);
18381843

1839-
EVT EltVT = VT.getVectorElementType();
1840-
// Extend the diff setcc in case the intrinsic has been promoted to a vector
1841-
// type with elements larger than i1
1842-
if (EltVT.getScalarSizeInBits() > MaskVT.getScalarSizeInBits())
1843-
DiffMask = DAG.getNode(ISD::ANY_EXTEND, DL, VT, DiffMask);
1844-
1845-
// Splat the compare result then OR it with the lane mask
1846-
if (CmpVT.getScalarSizeInBits() < EltVT.getScalarSizeInBits())
1847-
Cmp = DAG.getNode(ISD::ZERO_EXTEND, DL, EltVT, Cmp);
1848-
SDValue Splat = DAG.getSplat(VT, DL, Cmp);
1849-
return DAG.getNode(ISD::OR, DL, VT, DiffMask, Splat);
1844+
return DAG.getNode(ISD::GET_ACTIVE_LANE_MASK, DL, VT, LaneOffset, MaskN);
18501845
}
18511846

18521847
void VectorLegalizer::ExpandFP_TO_UINT(SDNode *Node,

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -404,19 +404,33 @@ SDValue DAGTypeLegalizer::ScalarizeVecRes_MERGE_VALUES(SDNode *N,
404404
}
405405

406406
SDValue DAGTypeLegalizer::ScalarizeVecRes_LOOP_DEPENDENCE_MASK(SDNode *N) {
407+
SDLoc DL(N);
407408
SDValue SourceValue = N->getOperand(0);
408409
SDValue SinkValue = N->getOperand(1);
409-
SDValue EltSize = N->getOperand(2);
410+
SDValue EltSizeInBytes = N->getOperand(2);
411+
SDValue LaneOffset = N->getOperand(3);
412+
410413
EVT PtrVT = SourceValue->getValueType(0);
411-
SDLoc DL(N);
414+
bool IsReadAfterWrite = N->getOpcode() == ISD::LOOP_DEPENDENCE_RAW_MASK;
412415

416+
// Take the difference between the pointers and divide by the element size,
417+
// to see how many lanes separate them.
413418
SDValue Diff = DAG.getNode(ISD::SUB, DL, PtrVT, SinkValue, SourceValue);
419+
if (IsReadAfterWrite)
420+
Diff = DAG.getNode(ISD::ABS, DL, PtrVT, Diff);
421+
Diff = DAG.getNode(ISD::SDIV, DL, PtrVT, Diff, EltSizeInBytes);
422+
423+
// The pointers do not alias if:
424+
// * Diff <= 0 || LaneOffset < Diff (WAR_MASK)
425+
// * Diff == 0 || LaneOffset < abs(Diff) (RAW_MASK)
426+
// Note: If LaneOffset is zero, both cases will fold to "true".
414427
EVT CmpVT = TLI.getSetCCResultType(DAG.getDataLayout(), *DAG.getContext(),
415428
Diff.getValueType());
416429
SDValue Zero = DAG.getConstant(0, DL, PtrVT);
417-
return DAG.getNode(ISD::OR, DL, CmpVT,
418-
DAG.getSetCC(DL, CmpVT, Diff, EltSize, ISD::SETGE),
419-
DAG.getSetCC(DL, CmpVT, Diff, Zero, ISD::SETEQ));
430+
SDValue Cmp = DAG.getSetCC(DL, CmpVT, Diff, Zero,
431+
IsReadAfterWrite ? ISD::SETEQ : ISD::SETLE);
432+
return DAG.getNode(ISD::OR, DL, CmpVT, Cmp,
433+
DAG.getSetCC(DL, CmpVT, LaneOffset, Diff, ISD::SETULT));
420434
}
421435

422436
SDValue DAGTypeLegalizer::ScalarizeVecRes_BITCAST(SDNode *N) {
@@ -1695,17 +1709,22 @@ void DAGTypeLegalizer::SplitVecRes_LOOP_DEPENDENCE_MASK(SDNode *N, SDValue &Lo,
16951709
SDValue &Hi) {
16961710
SDLoc DL(N);
16971711
EVT LoVT, HiVT;
1698-
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));
16991712
SDValue PtrA = N->getOperand(0);
17001713
SDValue PtrB = N->getOperand(1);
1701-
Lo = DAG.getNode(N->getOpcode(), DL, LoVT, PtrA, PtrB, N->getOperand(2));
1702-
1703-
unsigned EltSize = N->getConstantOperandVal(2);
1704-
ElementCount Offset = HiVT.getVectorElementCount() * EltSize;
1705-
SDValue Addend = DAG.getElementCount(DL, MVT::i64, Offset);
1714+
std::tie(LoVT, HiVT) = DAG.GetSplitDestVTs(N->getValueType(0));
17061715

1707-
PtrA = DAG.getNode(ISD::ADD, DL, MVT::i64, PtrA, Addend);
1708-
Hi = DAG.getNode(N->getOpcode(), DL, HiVT, PtrA, PtrB, N->getOperand(2));
1716+
// The lane offset for the "Lo" half of the mask is unchanged.
1717+
Lo = DAG.getNode(N->getOpcode(), DL, LoVT, PtrA, PtrB,
1718+
/*ElementSizeInBytes=*/N->getOperand(2),
1719+
/*LaneOffset=*/N->getOperand(3));
1720+
// The lane offset for the "Hi" half of the mask is incremented by the number
1721+
// of elements in the "Lo" half.
1722+
unsigned LaneOffset =
1723+
N->getConstantOperandVal(3) + LoVT.getVectorMinNumElements();
1724+
// Note: The lane offset is implicitly scalable for scalable masks.
1725+
Hi = DAG.getNode(N->getOpcode(), DL, HiVT, PtrA, PtrB,
1726+
/*ElementSizeInBytes=*/N->getOperand(2),
1727+
/*LaneOffset=*/DAG.getConstant(LaneOffset, DL, MVT::i64));
17091728
}
17101729

17111730
void DAGTypeLegalizer::SplitVecRes_BUILD_VECTOR(SDNode *N, SDValue &Lo,
@@ -6050,7 +6069,7 @@ SDValue DAGTypeLegalizer::WidenVecRes_LOOP_DEPENDENCE_MASK(SDNode *N) {
60506069
return DAG.getNode(
60516070
N->getOpcode(), SDLoc(N),
60526071
TLI.getTypeToTransformTo(*DAG.getContext(), N->getValueType(0)),
6053-
N->getOperand(0), N->getOperand(1), N->getOperand(2));
6072+
N->getOperand(0), N->getOperand(1), N->getOperand(2), N->getOperand(3));
60546073
}
60556074

60566075
SDValue DAGTypeLegalizer::WidenVecRes_BUILD_VECTOR(SDNode *N) {

llvm/lib/CodeGen/SelectionDAG/SelectionDAGBuilder.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8427,13 +8427,15 @@ void SelectionDAGBuilder::visitIntrinsicCall(const CallInst &I,
84278427
setValue(&I,
84288428
DAG.getNode(ISD::LOOP_DEPENDENCE_WAR_MASK, sdl,
84298429
EVT::getEVT(I.getType()), getValue(I.getOperand(0)),
8430-
getValue(I.getOperand(1)), getValue(I.getOperand(2))));
8430+
getValue(I.getOperand(1)), getValue(I.getOperand(2)),
8431+
DAG.getConstant(0, sdl, MVT::i64)));
84318432
return;
84328433
case Intrinsic::loop_dependence_raw_mask:
84338434
setValue(&I,
84348435
DAG.getNode(ISD::LOOP_DEPENDENCE_RAW_MASK, sdl,
84358436
EVT::getEVT(I.getType()), getValue(I.getOperand(0)),
8436-
getValue(I.getOperand(1)), getValue(I.getOperand(2))));
8437+
getValue(I.getOperand(1)), getValue(I.getOperand(2)),
8438+
DAG.getConstant(0, sdl, MVT::i64)));
84378439
return;
84388440
}
84398441
}

llvm/lib/Target/AArch64/AArch64ISelLowering.cpp

Lines changed: 24 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -5440,9 +5440,9 @@ SDValue
54405440
AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
54415441
SelectionDAG &DAG) const {
54425442
SDLoc DL(Op);
5443-
uint64_t EltSize = Op.getConstantOperandVal(2);
54445443
EVT VT = Op.getValueType();
5445-
switch (EltSize) {
5444+
SDValue EltSize = Op.getOperand(2);
5445+
switch (EltSize->getAsZExtVal()) {
54465446
case 1:
54475447
if (VT != MVT::v16i8 && VT != MVT::nxv16i1)
54485448
return SDValue();
@@ -5464,11 +5464,15 @@ AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
54645464
return SDValue();
54655465
}
54665466

5467+
SDValue LaneOffset = Op.getOperand(3);
5468+
if (LaneOffset->getAsZExtVal())
5469+
return SDValue();
5470+
54675471
SDValue PtrA = Op.getOperand(0);
54685472
SDValue PtrB = Op.getOperand(1);
54695473

54705474
if (VT.isScalableVT())
5471-
return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, Op.getOperand(2));
5475+
return DAG.getNode(Op.getOpcode(), DL, VT, PtrA, PtrB, EltSize, LaneOffset);
54725476

54735477
// We can use the SVE whilewr/whilerw instruction to lower this
54745478
// intrinsic by creating the appropriate sequence of scalable vector
@@ -5480,7 +5484,7 @@ AArch64TargetLowering::LowerLOOP_DEPENDENCE_MASK(SDValue Op,
54805484
EVT WhileVT = ContainerVT.changeElementType(MVT::i1);
54815485

54825486
SDValue Mask =
5483-
DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, Op.getOperand(2));
5487+
DAG.getNode(Op.getOpcode(), DL, WhileVT, PtrA, PtrB, EltSize, LaneOffset);
54845488
SDValue MaskAsInt = DAG.getNode(ISD::SIGN_EXTEND, DL, ContainerVT, Mask);
54855489
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, VT, MaskAsInt,
54865490
DAG.getVectorIdxConstant(0, DL));
@@ -6251,35 +6255,43 @@ SDValue AArch64TargetLowering::LowerINTRINSIC_WO_CHAIN(SDValue Op,
62516255
case Intrinsic::aarch64_sve_whilewr_b:
62526256
return DAG.getNode(ISD::LOOP_DEPENDENCE_WAR_MASK, DL, Op.getValueType(),
62536257
Op.getOperand(1), Op.getOperand(2),
6254-
DAG.getConstant(1, DL, MVT::i64));
6258+
DAG.getConstant(1, DL, MVT::i64),
6259+
DAG.getConstant(0, DL, MVT::i64));
62556260
case Intrinsic::aarch64_sve_whilewr_h:
62566261
return DAG.getNode(ISD::LOOP_DEPENDENCE_WAR_MASK, DL, Op.getValueType(),
62576262
Op.getOperand(1), Op.getOperand(2),
6258-
DAG.getConstant(2, DL, MVT::i64));
6263+
DAG.getConstant(2, DL, MVT::i64),
6264+
DAG.getConstant(0, DL, MVT::i64));
62596265
case Intrinsic::aarch64_sve_whilewr_s:
62606266
return DAG.getNode(ISD::LOOP_DEPENDENCE_WAR_MASK, DL, Op.getValueType(),
62616267
Op.getOperand(1), Op.getOperand(2),
6262-
DAG.getConstant(4, DL, MVT::i64));
6268+
DAG.getConstant(4, DL, MVT::i64),
6269+
DAG.getConstant(0, DL, MVT::i64));
62636270
case Intrinsic::aarch64_sve_whilewr_d:
62646271
return DAG.getNode(ISD::LOOP_DEPENDENCE_WAR_MASK, DL, Op.getValueType(),
62656272
Op.getOperand(1), Op.getOperand(2),
6266-
DAG.getConstant(8, DL, MVT::i64));
6273+
DAG.getConstant(8, DL, MVT::i64),
6274+
DAG.getConstant(0, DL, MVT::i64));
62676275
case Intrinsic::aarch64_sve_whilerw_b:
62686276
return DAG.getNode(ISD::LOOP_DEPENDENCE_RAW_MASK, DL, Op.getValueType(),
62696277
Op.getOperand(1), Op.getOperand(2),
6270-
DAG.getConstant(1, DL, MVT::i64));
6278+
DAG.getConstant(1, DL, MVT::i64),
6279+
DAG.getConstant(0, DL, MVT::i64));
62716280
case Intrinsic::aarch64_sve_whilerw_h:
62726281
return DAG.getNode(ISD::LOOP_DEPENDENCE_RAW_MASK, DL, Op.getValueType(),
62736282
Op.getOperand(1), Op.getOperand(2),
6274-
DAG.getConstant(2, DL, MVT::i64));
6283+
DAG.getConstant(2, DL, MVT::i64),
6284+
DAG.getConstant(0, DL, MVT::i64));
62756285
case Intrinsic::aarch64_sve_whilerw_s:
62766286
return DAG.getNode(ISD::LOOP_DEPENDENCE_RAW_MASK, DL, Op.getValueType(),
62776287
Op.getOperand(1), Op.getOperand(2),
6278-
DAG.getConstant(4, DL, MVT::i64));
6288+
DAG.getConstant(4, DL, MVT::i64),
6289+
DAG.getConstant(0, DL, MVT::i64));
62796290
case Intrinsic::aarch64_sve_whilerw_d:
62806291
return DAG.getNode(ISD::LOOP_DEPENDENCE_RAW_MASK, DL, Op.getValueType(),
62816292
Op.getOperand(1), Op.getOperand(2),
6282-
DAG.getConstant(8, DL, MVT::i64));
6293+
DAG.getConstant(8, DL, MVT::i64),
6294+
DAG.getConstant(0, DL, MVT::i64));
62836295
case Intrinsic::aarch64_neon_abs: {
62846296
EVT Ty = Op.getValueType();
62856297
if (Ty == MVT::i64) {

llvm/lib/Target/AArch64/SVEInstrFormats.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6020,13 +6020,13 @@ multiclass sve2_int_while_rr<bits<1> rw, string asm, SDPatternOperator op> {
60206020
def _S : sve2_int_while_rr<0b10, rw, asm, PPR32>;
60216021
def _D : sve2_int_while_rr<0b11, rw, asm, PPR64>;
60226022

6023-
def : Pat<(nxv16i1 (op i64:$Op1, i64:$Op2, (i64 1))),
6023+
def : Pat<(nxv16i1 (op i64:$Op1, i64:$Op2, (i64 1), (i64 0))),
60246024
(!cast<Instruction>(NAME # _B) $Op1, $Op2)>;
6025-
def : Pat<(nxv8i1 (op i64:$Op1, i64:$Op2, (i64 2))),
6025+
def : Pat<(nxv8i1 (op i64:$Op1, i64:$Op2, (i64 2), (i64 0))),
60266026
(!cast<Instruction>(NAME # _H) $Op1, $Op2)>;
6027-
def : Pat<(nxv4i1 (op i64:$Op1, i64:$Op2, (i64 4))),
6027+
def : Pat<(nxv4i1 (op i64:$Op1, i64:$Op2, (i64 4), (i64 0))),
60286028
(!cast<Instruction>(NAME # _S) $Op1, $Op2)>;
6029-
def : Pat<(nxv2i1 (op i64:$Op1, i64:$Op2, (i64 8))),
6029+
def : Pat<(nxv2i1 (op i64:$Op1, i64:$Op2, (i64 8), (i64 0))),
60306030
(!cast<Instruction>(NAME # _D) $Op1, $Op2)>;
60316031
}
60326032

0 commit comments

Comments
 (0)