-
Notifications
You must be signed in to change notification settings - Fork 15.2k
[ConstantTime][LLVM] Add llvm.ct.select intrinsic with generic SelectionDAG lowering #166702
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -435,6 +435,9 @@ struct SDNodeFlags { | |
| NonNeg | NoNaNs | NoInfs | SameSign | InBounds, | ||
| FastMathFlags = NoNaNs | NoInfs | NoSignedZeros | AllowReciprocal | | ||
| AllowContract | ApproximateFuncs | AllowReassociation, | ||
|
|
||
| // Flag for disabling optimization | ||
| NoMerge = 1 << 15, | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This shouldn't be here; I'll make sure to remove it. |
||
| }; | ||
|
|
||
| /// Default constructor turns off all optimization flags. | ||
|
|
@@ -486,7 +489,6 @@ struct SDNodeFlags { | |
| bool hasNoFPExcept() const { return Flags & NoFPExcept; } | ||
| bool hasUnpredictable() const { return Flags & Unpredictable; } | ||
| bool hasInBounds() const { return Flags & InBounds; } | ||
|
|
||
| bool operator==(const SDNodeFlags &Other) const { | ||
| return Flags == Other.Flags; | ||
| } | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -242,11 +242,15 @@ class LLVM_ABI TargetLoweringBase { | |
|
|
||
| /// Enum that describes what type of support for selects the target has. | ||
| enum SelectSupportKind { | ||
| ScalarValSelect, // The target supports scalar selects (ex: cmov). | ||
| ScalarCondVectorVal, // The target supports selects with a scalar condition | ||
| // and vector values (ex: cmov). | ||
| VectorMaskSelect // The target supports vector selects with a vector | ||
| // mask (ex: x86 blends). | ||
| ScalarValSelect, // The target supports scalar selects (ex: cmov). | ||
| ScalarCondVectorVal, // The target supports selects with a scalar condition | ||
| // and vector values (ex: cmov). | ||
| VectorMaskSelect, // The target supports vector selects with a vector | ||
| // mask (ex: x86 blends). | ||
| CtSelect, // The target implements a custom constant-time select. | ||
| ScalarCondVectorValCtSelect, // The target supports selects with a scalar | ||
| // condition and vector values. | ||
| VectorMaskValCtSelect, // The target supports vector selects with a vector | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we don't use
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: two of these three don't mention "constant-time" in their comments. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We only rely on |
||
| }; | ||
|
|
||
| /// Enum that specifies what an atomic load/AtomicRMWInst is expanded | ||
|
|
@@ -476,8 +480,8 @@ class LLVM_ABI TargetLoweringBase { | |
| MachineMemOperand::Flags | ||
| getVPIntrinsicMemOperandFlags(const VPIntrinsic &VPIntrin) const; | ||
|
|
||
| virtual bool isSelectSupported(SelectSupportKind /*kind*/) const { | ||
| return true; | ||
| virtual bool isSelectSupported(SelectSupportKind kind) const { | ||
| return kind != CtSelect; | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why is this not checking for all three of the new values you added? It looks as if it will assume by default that every target supports |
||
| } | ||
|
|
||
| /// Return true if the @llvm.get.active.lane.mask intrinsic should be expanded | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1825,6 +1825,15 @@ def int_coro_subfn_addr : DefaultAttrsIntrinsic< | |
| [IntrReadMem, IntrArgMemOnly, ReadOnly<ArgIndex<0>>, | ||
| NoCapture<ArgIndex<0>>]>; | ||
|
|
||
| ///===--------------------- Constant Time Intrinsics ---------------------===// | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nit: your ASCII art has been accidentally word-wrapped 🙂 |
||
| // | ||
| // Intrinsic to support constant time select | ||
| def int_ct_select | ||
| : DefaultAttrsIntrinsic<[llvm_any_ty], | ||
| [llvm_i1_ty, LLVMMatchType<0>, LLVMMatchType<0>], | ||
| [IntrWriteMem, IntrWillReturn, NoUndef<RetIndex>]>; | ||
|
|
||
| ///===-------------------------- Other Intrinsics --------------------------===// | ||
| // | ||
| // TODO: We should introduce a new memory kind for traps (and other side effects | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -484,6 +484,7 @@ namespace { | |
| SDValue visitCTTZ_ZERO_UNDEF(SDNode *N); | ||
| SDValue visitCTPOP(SDNode *N); | ||
| SDValue visitSELECT(SDNode *N); | ||
| SDValue visitCTSELECT(SDNode *N); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I hate to say it, but is it possible that the naming of this node will cause confusion? Two lines above this |
||
| SDValue visitVSELECT(SDNode *N); | ||
| SDValue visitVP_SELECT(SDNode *N); | ||
| SDValue visitSELECT_CC(SDNode *N); | ||
|
|
@@ -1898,6 +1899,7 @@ void DAGCombiner::Run(CombineLevel AtLevel) { | |
| } | ||
|
|
||
| SDValue DAGCombiner::visit(SDNode *N) { | ||
|
|
||
| // clang-format off | ||
| switch (N->getOpcode()) { | ||
| default: break; | ||
|
|
@@ -1968,6 +1970,7 @@ SDValue DAGCombiner::visit(SDNode *N) { | |
| case ISD::CTTZ_ZERO_UNDEF: return visitCTTZ_ZERO_UNDEF(N); | ||
| case ISD::CTPOP: return visitCTPOP(N); | ||
| case ISD::SELECT: return visitSELECT(N); | ||
| case ISD::CTSELECT: return visitCTSELECT(N); | ||
| case ISD::VSELECT: return visitVSELECT(N); | ||
| case ISD::SELECT_CC: return visitSELECT_CC(N); | ||
| case ISD::SETCC: return visitSETCC(N); | ||
|
|
@@ -6032,6 +6035,7 @@ static SDValue isSaturatingMinMax(SDValue N0, SDValue N1, SDValue N2, | |
| N0CC = cast<CondCodeSDNode>(N0.getOperand(4))->get(); | ||
| break; | ||
| case ISD::SELECT: | ||
| case ISD::CTSELECT: | ||
| case ISD::VSELECT: | ||
| if (N0.getOperand(0).getOpcode() != ISD::SETCC) | ||
| return SDValue(); | ||
|
|
@@ -12184,8 +12188,9 @@ template <class MatchContextClass> | |
| static SDValue foldBoolSelectToLogic(SDNode *N, const SDLoc &DL, | ||
| SelectionDAG &DAG) { | ||
| assert((N->getOpcode() == ISD::SELECT || N->getOpcode() == ISD::VSELECT || | ||
| N->getOpcode() == ISD::VP_SELECT) && | ||
| "Expected a (v)(vp.)select"); | ||
| N->getOpcode() == ISD::VP_SELECT || | ||
| N->getOpcode() == ISD::CTSELECT) && | ||
| "Expected a (v)(vp.)(ct) select"); | ||
| SDValue Cond = N->getOperand(0); | ||
| SDValue T = N->getOperand(1), F = N->getOperand(2); | ||
| EVT VT = N->getValueType(0); | ||
|
|
@@ -12547,6 +12552,109 @@ SDValue DAGCombiner::visitSELECT(SDNode *N) { | |
| return SDValue(); | ||
| } | ||
|
|
||
| SDValue DAGCombiner::visitCTSELECT(SDNode *N) { | ||
| SDValue N0 = N->getOperand(0); | ||
| SDValue N1 = N->getOperand(1); | ||
| SDValue N2 = N->getOperand(2); | ||
| EVT VT = N->getValueType(0); | ||
| EVT VT0 = N0.getValueType(); | ||
| SDLoc DL(N); | ||
| SDNodeFlags Flags = N->getFlags(); | ||
|
|
||
| if (SDValue V = foldBoolSelectToLogic<EmptyMatchContext>(N, DL, DAG)) | ||
| return V; | ||
|
|
||
| // ctselect (not Cond), N1, N2 -> ctselect Cond, N2, N1 | ||
| if (SDValue F = extractBooleanFlip(N0, DAG, TLI, false)) { | ||
| SDValue SelectOp = DAG.getNode(ISD::CTSELECT, DL, VT, F, N2, N1); | ||
| SelectOp->setFlags(Flags); | ||
| return SelectOp; | ||
| } | ||
|
|
||
| if (VT0 == MVT::i1) { | ||
| // The code in this block deals with the following 2 equivalences: | ||
| // ctselect(C0|C1, x, y) <=> ctselect(C0, x, ctselect(C1, x, y)) | ||
| // ctselect(C0&C1, x, y) <=> ctselect(C0, ctselect(C1, x, y), y) | ||
| // The target can specify its preferred form with the | ||
| // shouldNormalizeToSelectSequence() callback. However, we always transform | ||
| // to the right-hand form if the inner ctselect already exists in the DAG, | ||
| // and we always transform to the left-hand form if we know that we can | ||
| // further optimize the combination of the conditions. | ||
| bool normalizeToSequence = | ||
| TLI.shouldNormalizeToSelectSequence(*DAG.getContext(), VT); | ||
| // ctselect (and Cond0, Cond1), X, Y | ||
| // -> ctselect Cond0, (ctselect Cond1, X, Y), Y | ||
| if (N0->getOpcode() == ISD::AND && N0->hasOneUse()) { | ||
| SDValue Cond0 = N0->getOperand(0); | ||
| SDValue Cond1 = N0->getOperand(1); | ||
| SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), | ||
| Cond1, N1, N2, Flags); | ||
| if (normalizeToSequence || !InnerSelect.use_empty()) | ||
| return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0, | ||
| InnerSelect, N2, Flags); | ||
| // Cleanup on failure. | ||
| if (InnerSelect.use_empty()) | ||
| recursivelyDeleteUnusedNodes(InnerSelect.getNode()); | ||
| } | ||
| // ctselect (or Cond0, Cond1), X, Y -> ctselect Cond0, X, (ctselect Cond1, | ||
| // X, Y) | ||
| if (N0->getOpcode() == ISD::OR && N0->hasOneUse()) { | ||
| SDValue Cond0 = N0->getOperand(0); | ||
| SDValue Cond1 = N0->getOperand(1); | ||
| SDValue InnerSelect = DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), | ||
| Cond1, N1, N2, Flags); | ||
| if (normalizeToSequence || !InnerSelect.use_empty()) | ||
| return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Cond0, N1, | ||
| InnerSelect, Flags); | ||
| // Cleanup on failure. | ||
| if (InnerSelect.use_empty()) | ||
| recursivelyDeleteUnusedNodes(InnerSelect.getNode()); | ||
| } | ||
|
|
||
| // ctselect Cond0, (ctselect Cond1, X, Y), Y -> ctselect (and Cond0, Cond1), | ||
| // X, Y | ||
| if (N1->getOpcode() == ISD::CTSELECT && N1->hasOneUse()) { | ||
| SDValue N1_0 = N1->getOperand(0); | ||
| SDValue N1_1 = N1->getOperand(1); | ||
| SDValue N1_2 = N1->getOperand(2); | ||
| if (N1_2 == N2 && N0.getValueType() == N1_0.getValueType()) { | ||
| // Create the actual and node if we can generate good code for it. | ||
| if (!normalizeToSequence) { | ||
| SDValue And = DAG.getNode(ISD::AND, DL, N0.getValueType(), N0, N1_0); | ||
| return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), And, N1_1, | ||
| N2, Flags); | ||
| } | ||
| // Otherwise see if we can optimize the "and" to a better pattern. | ||
| if (SDValue Combined = visitANDLike(N0, N1_0, N)) { | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Are you confident that this It would be a shame to take as input a perfectly safe double-CTSELECT and spit out a thing which had "helpfully" optimized the condition into something that wasn't constant-time any more. (Same goes for the |
||
| return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined, | ||
| N1_1, N2, Flags); | ||
| } | ||
| } | ||
| } | ||
| // ctselect Cond0, X, (ctselect Cond1, X, Y) -> ctselect (or Cond0, Cond1), | ||
| // X, Y | ||
| if (N2->getOpcode() == ISD::CTSELECT && N2->hasOneUse()) { | ||
| SDValue N2_0 = N2->getOperand(0); | ||
| SDValue N2_1 = N2->getOperand(1); | ||
| SDValue N2_2 = N2->getOperand(2); | ||
| if (N2_1 == N1 && N0.getValueType() == N2_0.getValueType()) { | ||
| // Create the actual or node if we can generate good code for it. | ||
| if (!normalizeToSequence) { | ||
| SDValue Or = DAG.getNode(ISD::OR, DL, N0.getValueType(), N0, N2_0); | ||
| return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Or, N1, N2_2, | ||
| Flags); | ||
| } | ||
| // Otherwise see if we can optimize to a better pattern. | ||
| if (SDValue Combined = visitORLike(N0, N2_0, DL)) | ||
| return DAG.getNode(ISD::CTSELECT, DL, N1.getValueType(), Combined, N1, | ||
| N2_2, Flags); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return SDValue(); | ||
| } | ||
|
|
||
| // This function assumes all the vselect's arguments are CONCAT_VECTOR | ||
| // nodes and that the condition is a BV of ConstantSDNodes (or undefs). | ||
| static SDValue ConvertSelectToConcatVector(SDNode *N, SelectionDAG &DAG) { | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -4136,6 +4136,46 @@ bool SelectionDAGLegalize::ExpandNode(SDNode *Node) { | |
| } | ||
| Results.push_back(Tmp1); | ||
| break; | ||
| case ISD::CTSELECT: { | ||
| Tmp1 = Node->getOperand(0); | ||
| Tmp2 = Node->getOperand(1); | ||
| Tmp3 = Node->getOperand(2); | ||
| EVT VT = Tmp2.getValueType(); | ||
| if (VT.isVector()) { | ||
| SmallVector<SDValue> Elements; | ||
| unsigned NumElements = VT.getVectorNumElements(); | ||
| EVT ScalarVT = VT.getScalarType(); | ||
| for (unsigned Idx = 0; Idx < NumElements; ++Idx) { | ||
| SDValue IdxVal = DAG.getConstant(Idx, dl, MVT::i64); | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should be using |
||
| SDValue TVal = | ||
| DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp2, IdxVal); | ||
| SDValue FVal = | ||
| DAG.getNode(ISD::EXTRACT_VECTOR_ELT, dl, ScalarVT, Tmp3, IdxVal); | ||
| Elements.push_back( | ||
| DAG.getCTSelect(dl, ScalarVT, Tmp1, TVal, FVal, Node->getFlags())); | ||
| } | ||
| Tmp1 = DAG.getBuildVector(VT, dl, Elements); | ||
| } else if (VT.isFloatingPoint()) { | ||
| EVT IntegerVT = EVT::getIntegerVT(*DAG.getContext(), VT.getSizeInBits()); | ||
| Tmp2 = DAG.getBitcast(IntegerVT, Tmp2); | ||
| Tmp3 = DAG.getBitcast(IntegerVT, Tmp3); | ||
| Tmp1 = DAG.getBitcast(VT, DAG.getCTSelect(dl, IntegerVT, Tmp1, Tmp2, Tmp3, | ||
| Node->getFlags())); | ||
| } else { | ||
| assert(VT.isInteger()); | ||
| EVT HalfVT = VT.getHalfSizedIntegerVT(*DAG.getContext()); | ||
| auto [Tmp2Lo, Tmp2Hi] = DAG.SplitScalar(Tmp2, dl, HalfVT, HalfVT); | ||
| auto [Tmp3Lo, Tmp3Hi] = DAG.SplitScalar(Tmp3, dl, HalfVT, HalfVT); | ||
| SDValue ResLo = | ||
| DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Lo, Tmp3Lo, Node->getFlags()); | ||
| SDValue ResHi = | ||
| DAG.getCTSelect(dl, HalfVT, Tmp1, Tmp2Hi, Tmp3Hi, Node->getFlags()); | ||
| Tmp1 = DAG.getNode(ISD::BUILD_PAIR, dl, VT, ResLo, ResHi); | ||
| Tmp1->setFlags(Node->getFlags()); | ||
| } | ||
| Results.push_back(Tmp1); | ||
| break; | ||
| } | ||
| case ISD::BR_JT: { | ||
| SDValue Chain = Node->getOperand(0); | ||
| SDValue Table = Node->getOperand(1); | ||
|
|
@@ -5474,7 +5514,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) { | |
| Results.push_back(DAG.getNode(ISD::TRUNCATE, dl, OVT, Tmp2)); | ||
| break; | ||
| } | ||
| case ISD::SELECT: { | ||
| case ISD::SELECT: | ||
| case ISD::CTSELECT: { | ||
| unsigned ExtOp, TruncOp; | ||
| if (Node->getValueType(0).isVector() || | ||
| Node->getValueType(0).getSizeInBits() == NVT.getSizeInBits()) { | ||
|
|
@@ -5492,7 +5533,8 @@ void SelectionDAGLegalize::PromoteNode(SDNode *Node) { | |
| Tmp2 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(1)); | ||
| Tmp3 = DAG.getNode(ExtOp, dl, NVT, Node->getOperand(2)); | ||
| // Perform the larger operation, then round down. | ||
| Tmp1 = DAG.getSelect(dl, NVT, Tmp1, Tmp2, Tmp3); | ||
| Tmp1 = DAG.getNode(Node->getOpcode(), dl, NVT, Tmp1, Tmp2, Tmp3); | ||
| Tmp1->setFlags(Node->getFlags()); | ||
| if (TruncOp != ISD::FP_ROUND) | ||
| Tmp1 = DAG.getNode(TruncOp, dl, Node->getValueType(0), Tmp1); | ||
| else | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
implemented with CMOV instruction? Your description says these will be expanded to bitselect patterns? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also, "Constant-time select [...] This is used to implement constant-time select" isn't really adding any value by repeating the same phrase again 🙂
Perhaps better to restate the order of parameters (it's obvious to you that it's the same as SELECT immediately above, but perhaps not to the next reader), and also, what conditions apply to the boolean – if it's not an i1, is it still expected to be an integer containing 0 or 1, or is it a bitmask containing 0 or ~0, or what?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch, that's not supposed to be there. Originally, when we were constructing the constant-time code, we added implementations for specific architectures (like x86 in this case); the fallback implementation was added later. I think that's why the comment mentions CMOV. I'll make sure to fix it, thanks! :)