Skip to content

Commit cdab5cc

Browse files
oscardssmithOscar Smith
authored and committed
finish hooking up CLMUL to selectiondag?
1 parent dcc3580 commit cdab5cc

File tree

7 files changed

+50
-16
lines changed

7 files changed

+50
-16
lines changed

llvm/docs/LangRef.rst

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18190,7 +18190,7 @@ Semantics:
1819018190
""""""""""
1819118191

1819218192
The ‘llvm.clmul’ intrinsic computes carryless multiply of ``%a`` and ``%b``, which is the result
18193-
of applying the standard multiplication algorithm if you replace all of the aditions with exclusive ors.
18193+
of applying the standard multiplication algorithm if you replace all of the additions with exclusive ors.
1819418194
The vector intrinsics, such as llvm.clmul.v4i32, operate on a per-element basis and the element order is not affected.
1819518195

1819618196
Examples

llvm/include/llvm/IR/Intrinsics.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1428,7 +1428,7 @@ let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
14281428
def int_fshr : DefaultAttrsIntrinsic<[llvm_anyint_ty],
14291429
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
14301430
def int_clmul : DefaultAttrsIntrinsic<[llvm_anyint_ty],
1431-
[LLVMMatchType<0>, LLVMMatchType<0>, LLVMMatchType<0>]>;
1431+
[LLVMMatchType<0>, LLVMMatchType<0>]>;
14321432
}
14331433

14341434
let IntrProperties = [IntrNoMem, IntrSpeculatable,

llvm/lib/CodeGen/IntrinsicLowering.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ static Value *LowerCLMUL(LLVMContext &Context, Value *V1, Value *V2, Instruction
208208
Value *Res = ConstantInt::get(V1->getType(), 0);
209209
Value *Zero = ConstantInt::get(V1->getType(), 0);
210210
Value *One = ConstantInt::get(V1->getType(), 1);
211-
for (unsigned I = 1; I < BitSize; I ++) {
211+
for (unsigned I = 1; I < BitSize; I++) {
212212
Value *LowBit = Builder.CreateAnd(V1, One, "clmul.isodd");
213213
Value *Pred = Builder.CreateSelect(LowBit, V2, Zero, "clmul.V2_or_zero");
214214
Res = Builder.CreateXor(Res, Pred, "clmul.Res");

llvm/lib/CodeGen/SelectionDAG/LegalizeIntegerTypes.cpp

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,8 @@ void DAGTypeLegalizer::PromoteIntegerResult(SDNode *N, unsigned ResNo) {
209209
case ISD::VP_XOR:
210210
case ISD::VP_ADD:
211211
case ISD::VP_SUB:
212-
case ISD::VP_MUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;
212+
case ISD::VP_MUL:
213+
case ISD::CLMUL: Res = PromoteIntRes_SimpleIntBinOp(N); break;
213214

214215
case ISD::ABDS:
215216
case ISD::AVGCEILS:
@@ -3140,6 +3141,10 @@ void DAGTypeLegalizer::ExpandIntegerResult(SDNode *N, unsigned ResNo) {
31403141
ExpandIntRes_FunnelShift(N, Lo, Hi);
31413142
break;
31423143

3144+
case ISD::CLMUL:
3145+
ExpandIntRes_CLMUL(N, Lo, Hi);
3146+
break;
3147+
31433148
case ISD::VSCALE:
31443149
ExpandIntRes_VSCALE(N, Lo, Hi);
31453150
break;
@@ -5476,6 +5481,35 @@ void DAGTypeLegalizer::ExpandIntRes_FunnelShift(SDNode *N, SDValue &Lo,
54765481
Hi = DAG.getNode(Opc, DL, HalfVT, Select3, Select2, NewShAmt);
54775482
}
54785483

5484+
/// Expand an ISD::CLMUL node whose integer result type must be split in two.
///
/// Writing the operands as A = AHi:ALo and B = BHi:BLo (each half is HalfVT),
/// the truncated carry-less product is assembled as
///   Lo = CLMUL(ALo, BLo)
///   Hi = overflow(ALo, BLo) ^ CLMUL(BLo, AHi) ^ CLMUL(BHi, ALo)
/// where overflow(ALo, BLo) denotes the bits of the full ALo x BLo product
/// that do not fit in the low half. Because the multiplication is carry-less,
/// the partial products are combined with XOR only.
///
/// \param N  the CLMUL node being expanded.
/// \param Lo receives the low half of the expanded result.
/// \param Hi receives the high half of the expanded result.
void DAGTypeLegalizer::ExpandIntRes_CLMUL(SDNode *N, SDValue &Lo,
                                          SDValue &Hi) {
  // In1/In2 are the low/high halves of operand 1, In3/In4 are the low/high
  // halves of operand 0.
  SDValue In1, In2, In3, In4;
  GetExpandedInteger(N->getOperand(0), In3, In4);
  GetExpandedInteger(N->getOperand(1), In1, In2);
  EVT HalfVT = In1.getValueType();
  SDLoc DL(N);

  // CLMUL is carryless, so Lo is computed from the low halves alone.
  Lo = DAG.getNode(ISD::CLMUL, DL, HalfVT, In1, In3);

  // The high bits not included in CLMUL(A, B) can be computed by
  //   BITREVERSE(CLMUL(BITREVERSE(A), BITREVERSE(B))) >> 1
  // Therefore we can compute the two hi/lo cross products and the overflow of
  // the low product, and XOR them together to compute Hi.
  SDValue BitRevIn1 = DAG.getNode(ISD::BITREVERSE, DL, HalfVT, In1);
  SDValue BitRevIn3 = DAG.getNode(ISD::BITREVERSE, DL, HalfVT, In3);
  SDValue BitRevLoHi =
      DAG.getNode(ISD::CLMUL, DL, HalfVT, BitRevIn1, BitRevIn3);
  SDValue LoHi = DAG.getNode(ISD::BITREVERSE, DL, HalfVT, BitRevLoHi);
  // The shift amount must be 1: the bit-reversed product holds bits
  // [n-1, 2n-2] of the full product, and the overflow starts at bit n.
  // (This constant was previously 0, which made the shift a no-op.)
  SDValue One = DAG.getConstant(1, DL, HalfVT);
  Hi = DAG.getNode(ISD::SRL, DL, HalfVT, LoHi, One);

  // Fold in the cross product of operand 1's low half with operand 0's high
  // half.
  SDValue HITMP = DAG.getNode(ISD::CLMUL, DL, HalfVT, In1, In4);
  Hi = DAG.getNode(ISD::XOR, DL, HalfVT, Hi, HITMP);
  // Fold in the cross product of operand 1's high half with operand 0's low
  // half.
  HITMP = DAG.getNode(ISD::CLMUL, DL, HalfVT, In2, In3);
  Hi = DAG.getNode(ISD::XOR, DL, HalfVT, Hi, HITMP);
}
5512+
54795513
void DAGTypeLegalizer::ExpandIntRes_VSCALE(SDNode *N, SDValue &Lo,
54805514
SDValue &Hi) {
54815515
EVT VT = N->getValueType(0);

llvm/lib/CodeGen/SelectionDAG/LegalizeTypes.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ class LLVM_LIBRARY_VISIBILITY DAGTypeLegalizer {
511511

512512
void ExpandIntRes_Rotate (SDNode *N, SDValue &Lo, SDValue &Hi);
513513
void ExpandIntRes_FunnelShift (SDNode *N, SDValue &Lo, SDValue &Hi);
514+
void ExpandIntRes_CLMUL (SDNode *N, SDValue &Lo, SDValue &Hi);
514515

515516
void ExpandIntRes_VSCALE (SDNode *N, SDValue &Lo, SDValue &Hi);
516517
void ExpandIntRes_READ_REGISTER(SDNode *N, SDValue &Lo, SDValue &Hi);

llvm/lib/CodeGen/SelectionDAG/LegalizeVectorTypes.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ void DAGTypeLegalizer::ScalarizeVectorResult(SDNode *N, unsigned ResNo) {
166166
case ISD::SMAX:
167167
case ISD::UMIN:
168168
case ISD::UMAX:
169+
case ISD::CLMUL:
169170

170171
case ISD::SADDSAT:
171172
case ISD::UADDSAT:
@@ -1330,6 +1331,7 @@ void DAGTypeLegalizer::SplitVectorResult(SDNode *N, unsigned ResNo) {
13301331
case ISD::SMAX: case ISD::VP_SMAX:
13311332
case ISD::UMIN: case ISD::VP_UMIN:
13321333
case ISD::UMAX: case ISD::VP_UMAX:
1334+
case ISD::CLMUL:
13331335
case ISD::SADDSAT: case ISD::VP_SADDSAT:
13341336
case ISD::UADDSAT: case ISD::VP_UADDSAT:
13351337
case ISD::SSUBSAT: case ISD::VP_SSUBSAT:
@@ -4764,6 +4766,7 @@ void DAGTypeLegalizer::WidenVectorResult(SDNode *N, unsigned ResNo) {
47644766
case ISD::SSUBSAT: case ISD::VP_SSUBSAT:
47654767
case ISD::SSHLSAT:
47664768
case ISD::USHLSAT:
4769+
case ISD::CLMUL:
47674770
case ISD::ROTL:
47684771
case ISD::ROTR:
47694772
case ISD::AVGFLOORS:

llvm/lib/CodeGen/SelectionDAG/TargetLowering.cpp

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8293,26 +8293,22 @@ SDValue TargetLowering::expandCLMUL(SDNode *Node,
82938293
!isOperationLegalOrCustom(ISD::SHL, VT) ||
82948294
!isOperationLegalOrCustom(ISD::XOR, VT) ||
82958295
!isOperationLegalOrCustom(ISD::AND, VT) ||
8296-
!isOperationLegalOrCustom(ISD::SELECT, VT) ||
8297-
!isOperationLegalOrCustomOrPromote(ISD::OR, VT))))
8296+
!isOperationLegalOrCustom(ISD::SELECT, VT))))
82988297
return SDValue();
82998298

83008299
SDValue Res = DAG.getConstant(0, DL, VT);
83018300
SDValue Zero = DAG.getConstant(0, DL, VT);
83028301
SDValue One = DAG.getConstant(1, DL, VT);
8303-
for (unsigned i = 0; i < NumBitsPerElt-1; ++i) {
8302+
for (unsigned I = 0; I < NumBitsPerElt-1; ++I) {
83048303
SDValue LowBit = DAG.getNode(ISD::AND, DL, VT, V1, One);
8305-
SDValue LowBool = DAG.getSetCC(DL, SetCCType, LowBit, One, ISD::SETULT);
8304+
SDValue LowBool = DAG.getSetCC(DL, SetCCType, LowBit, Zero, ISD::SETNE);
83068305
SDValue Pred = DAG.getNode(ISD::SELECT, DL, VT, LowBool, V2, Zero);
83078306
Res = DAG.getNode(ISD::XOR, DL, VT, Res, Pred);
8308-
V1 = DAG.getNode(ISD::SRL, DL, VT, V1, One);
8309-
V2 = DAG.getNode(ISD::SHL, DL, VT, V2, One);
8310-
}
8311-
// unroll last iteration to prevent dead nodes
8312-
SDValue LowBit = DAG.getNode(ISD::AND, DL, VT, V1, One);
8313-
SDValue LowBool = DAG.getSetCC(DL, SetCCType, LowBit, One, ISD::SETULT);
8314-
SDValue Pred = DAG.getNode(ISD::SELECT, DL, VT, LowBool, V2, Zero);
8315-
Res = DAG.getNode(ISD::XOR, DL, VT, Res, Pred);
8307+
if (I != NumBitsPerElt) {
8308+
V1 = DAG.getNode(ISD::SRL, DL, VT, V1, One);
8309+
V2 = DAG.getNode(ISD::SHL, DL, VT, V2, One);
8310+
}
8311+
}
83168312
return Res;
83178313
}
83188314

0 commit comments

Comments
 (0)