Skip to content

Commit eafbf9f

Browse files
committed
SAKKE perf. opts.
1 parent b6f76e8 commit eafbf9f

File tree

2 files changed

+74
-58
lines changed

2 files changed

+74
-58
lines changed

core/src/main/java/org/bouncycastle/crypto/kems/SAKKEKEMExtractor.java

Lines changed: 45 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import org.bouncycastle.crypto.EncapsulatedSecretExtractor;
77
import org.bouncycastle.crypto.params.SAKKEPrivateKeyParameters;
88
import org.bouncycastle.crypto.params.SAKKEPublicKeyParameters;
9+
import org.bouncycastle.math.ec.ECAlgorithms;
910
import org.bouncycastle.math.ec.ECCurve;
1011
import org.bouncycastle.math.ec.ECPoint;
1112
import org.bouncycastle.util.Arrays;
@@ -93,9 +94,22 @@ public byte[] extractSecret(byte[] encapsulation)
9394
BigInteger r = SAKKEKEMSGenerator.hashToIntegerRange(Arrays.concatenate(ssv.toByteArray(), b.toByteArray()), q, digest);
9495

9596
// Step 5: Validate R_bS
96-
ECPoint bP = P.multiply(b).normalize();
97-
ECPoint Test = bP.add(Z_S).multiply(r).normalize();
98-
if (!R_bS.equals(Test))
97+
ECPoint Test;
98+
99+
BigInteger order = curve.getOrder();
100+
if (order == null)
101+
{
102+
Test = P.multiply(b).add(Z_S).multiply(r);
103+
}
104+
else
105+
{
106+
BigInteger a = b.multiply(r).mod(order);
107+
Test = ECAlgorithms.sumOfTwoMultiplies(P, a, Z_S, r);
108+
}
109+
110+
Test = Test.subtract(R_bS);
111+
112+
if (!Test.isInfinity())
99113
{
100114
throw new IllegalStateException("Validation of R_bS failed");
101115
}
@@ -122,7 +136,7 @@ public int getEncapsulationLength()
122136
static BigInteger computePairing(ECPoint R, ECPoint Q, BigInteger p, BigInteger q)
123137
{
124138
// v = (1,0) in F_p^2
125-
BigInteger[] v = new BigInteger[]{BigInteger.ONE, BigInteger.ZERO};
139+
BigInteger[] v = new BigInteger[]{ BigInteger.ONE, BigInteger.ZERO };
126140
ECPoint C = R;
127141

128142
BigInteger qMinusOne = q.subtract(BigInteger.ONE);
@@ -131,23 +145,21 @@ static BigInteger computePairing(ECPoint R, ECPoint Q, BigInteger p, BigInteger
131145
BigInteger Qy = Q.getAffineYCoord().toBigInteger();
132146
BigInteger Rx = R.getAffineXCoord().toBigInteger();
133147
BigInteger Ry = R.getAffineYCoord().toBigInteger();
134-
BigInteger l, Cx, Cy;
135148
final BigInteger three = BigInteger.valueOf(3);
136-
final BigInteger two = BigInteger.valueOf(2);
137149

138150
// Miller loop
139151
for (int i = numBits - 2; i >= 0; i--)
140152
{
141-
Cx = C.getAffineXCoord().toBigInteger();
142-
Cy = C.getAffineYCoord().toBigInteger();
153+
BigInteger Cx = C.getAffineXCoord().toBigInteger();
154+
BigInteger Cy = C.getAffineYCoord().toBigInteger();
143155

144156
// Compute l = (3 * (Cx^2 - 1)) / (2 * Cy) mod p
145-
l = three.multiply(Cx.multiply(Cx).subtract(BigInteger.ONE))
146-
.multiply(Cy.multiply(two).modInverse(p)).mod(p);
157+
BigInteger l = Cx.multiply(Cx).mod(p).subtract(BigInteger.ONE).multiply(three)
158+
.multiply(BigIntegers.modOddInverse(p, Cy.shiftLeft(1))).mod(p);
147159

148160
// Compute v = v^2 * ( l*( Q_x + C_x ) + ( i*Q_y - C_y ) )
149161
v = fp2PointSquare(v[0], v[1], p);
150-
v = fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)).subtract(Cy), Qy, p);
162+
v = fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)).subtract(Cy).mod(p), Qy, p);
151163

152164
C = C.twice().normalize(); // C = [2]C
153165

@@ -157,55 +169,56 @@ static BigInteger computePairing(ECPoint R, ECPoint Q, BigInteger p, BigInteger
157169
Cy = C.getAffineYCoord().toBigInteger();
158170

159171
// Compute l = (Cy - Ry) / (Cx - Rx) mod p
160-
l = Cy.subtract(Ry).multiply(Cx.subtract(Rx).modInverse(p)).mod(p);
172+
l = Cy.subtract(Ry).multiply(BigIntegers.modOddInverse(p, Cx.subtract(Rx))).mod(p);
161173

162174
// Compute v = v * ( l*( Q_x + C_x ) + ( i*Q_y - C_y ) )
163-
v = fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)).subtract(Cy), Qy, p);
175+
v = fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)).subtract(Cy).mod(p), Qy, p);
164176

165-
C = C.add(R).normalize();
177+
if (i > 0)
178+
{
179+
C = C.add(R).normalize();
180+
}
166181
}
167182
}
168183

169184
// Final exponentiation: t = v^c
170185
v = fp2PointSquare(v[0], v[1], p);
171186
v = fp2PointSquare(v[0], v[1], p);
172-
return v[1].multiply(v[0].modInverse(p)).mod(p);
187+
BigInteger v0Inv = BigIntegers.modOddInverse(p, v[0]);
188+
return v[1].multiply(v0Inv).mod(p);
173189
}
174190

175191
/**
176192
* Performs multiplication in F_p^2 field.
177193
*
178-
* @param x_real Real component of first operand
179-
* @param x_imag Imaginary component of first operand
180-
* @param y_real Real component of second operand
181-
* @param y_imag Imaginary component of second operand
194+
* @param a0 Real component of first operand
195+
* @param b0 Imaginary component of first operand
196+
* @param a1 Real component of second operand
197+
* @param b1 Imaginary component of second operand
182198
* @param p Prime field characteristic
183199
* @return Result of multiplication in F_p^2 as [real, imaginary] array
184200
*/
185-
static BigInteger[] fp2Multiply(BigInteger x_real, BigInteger x_imag, BigInteger y_real, BigInteger y_imag, BigInteger p)
201+
static BigInteger[] fp2Multiply(BigInteger a0, BigInteger b0, BigInteger a1, BigInteger b1, BigInteger p)
186202
{
187203
return new BigInteger[]{
188-
x_real.multiply(y_real).subtract(x_imag.multiply(y_imag)).mod(p),
189-
x_real.multiply(y_imag).add(x_imag.multiply(y_real)).mod(p)
204+
a0.multiply(a1).subtract(b0.multiply(b1)).mod(p),
205+
a0.multiply(b1).add(b0.multiply(a1)).mod(p)
190206
};
191207
}
192208

193209
/**
194210
* Computes squaring operation in F_p^2 field.
195211
*
196-
* @param currentX Real component of input
197-
* @param currentY Imaginary component of input
212+
* @param a Real component of input
213+
* @param b Imaginary component of input
198214
* @param p Prime field characteristic
199215
* @return Squared result in F_p^2 as [newX, newY] array
200216
*/
201-
static BigInteger[] fp2PointSquare(BigInteger currentX, BigInteger currentY, BigInteger p)
217+
static BigInteger[] fp2PointSquare(BigInteger a, BigInteger b, BigInteger p)
202218
{
203-
BigInteger xPlusY = currentX.add(currentY).mod(p);
204-
BigInteger xMinusY = currentX.subtract(currentY).mod(p);
205-
BigInteger newX = xPlusY.multiply(xMinusY).mod(p);
206-
207-
// Compute newY = 2xy mod p
208-
BigInteger newY = currentX.multiply(currentY).multiply(BigInteger.valueOf(2)).mod(p);
209-
return new BigInteger[]{newX, newY};
219+
return new BigInteger[]{
220+
a.add(b).multiply(a.subtract(b)).mod(p),
221+
a.multiply(b).shiftLeft(1).mod(p)
222+
};
210223
}
211224
}

core/src/main/java/org/bouncycastle/crypto/kems/SAKKEKEMSGenerator.java

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import org.bouncycastle.crypto.SecretWithEncapsulation;
99
import org.bouncycastle.crypto.params.AsymmetricKeyParameter;
1010
import org.bouncycastle.crypto.params.SAKKEPublicKeyParameters;
11+
import org.bouncycastle.math.ec.ECAlgorithms;
1112
import org.bouncycastle.math.ec.ECCurve;
1213
import org.bouncycastle.math.ec.ECPoint;
1314
import org.bouncycastle.util.Arrays;
@@ -85,51 +86,53 @@ public SecretWithEncapsulation generateEncapsulated(AsymmetricKeyParameter recip
8586

8687

8788
// 3. Compute R_(b,S) = [r]([b]P + Z_S)
88-
ECPoint bP = P.multiply(b).normalize();
89-
ECPoint R_bS = bP.add(Z).multiply(r).normalize();
89+
ECPoint R_bS;
90+
91+
BigInteger order = curve.getOrder();
92+
if (order == null)
93+
{
94+
R_bS = P.multiply(b).add(Z).multiply(r).normalize();
95+
}
96+
else
97+
{
98+
BigInteger a = b.multiply(r).mod(order);
99+
R_bS = ECAlgorithms.sumOfTwoMultiplies(P, a, Z, r).normalize();
100+
}
90101

91102
// 4. Compute H = SSV XOR HashToIntegerRange( g^r, 2^n )
92103
BigInteger pointX = BigInteger.ONE;
93104
BigInteger pointY = g;
94-
BigInteger[] v = new BigInteger[2];
95105

96106
// Initialize result with the original point
97-
BigInteger currentX = BigInteger.ONE;
98-
BigInteger currentY = g;
99-
ECPoint current = curve.createPoint(currentX, currentY);
107+
BigInteger v0 = BigInteger.ONE;
108+
BigInteger v1 = g;
109+
ECPoint current = curve.createPoint(v0, v1);
100110

101-
int numBits = r.bitLength();
102-
BigInteger[] rlt;
103111
// Process bits from MSB-1 down to 0
104-
for (int i = numBits - 2; i >= 0; i--)
112+
for (int i = r.bitLength() - 2; i >= 0; i--)
105113
{
106114
// Square the current point
107-
rlt = SAKKEKEMExtractor.fp2PointSquare(currentX, currentY, p);
115+
BigInteger[] rlt = SAKKEKEMExtractor.fp2PointSquare(v0, v1, p);
108116
current = current.timesPow2(2);
109-
currentX = rlt[0];
110-
currentY = rlt[1];
117+
v0 = rlt[0];
118+
v1 = rlt[1];
111119
// Multiply if bit is set
112120
if (r.testBit(i))
113121
{
114-
rlt = SAKKEKEMExtractor.fp2Multiply(currentX, currentY, pointX, pointY, p);
122+
rlt = SAKKEKEMExtractor.fp2Multiply(v0, v1, pointX, pointY, p);
115123

116-
currentX = rlt[0];
117-
currentY = rlt[1];
124+
v0 = rlt[0];
125+
v1 = rlt[1];
118126
}
119127
}
120128

121-
v[0] = currentX;
122-
v[1] = currentY;
123-
BigInteger g_r = v[1].multiply(v[0].modInverse(p)).mod(p);
129+
BigInteger v0Inv = BigIntegers.modOddInverse(p, v0);
130+
BigInteger g_r = v1.multiply(v0Inv).mod(p);
124131

125132
BigInteger mask = hashToIntegerRange(g_r.toByteArray(), BigInteger.ONE.shiftLeft(n), digest); // 2^n
126133

127134
BigInteger H = ssv.xor(mask);
128135
// 5. Encode encapsulated data (R_bS, H)
129-
// byte[] encapsulated = Arrays.concatenate(new byte[]{(byte)0x04},
130-
// BigIntegers.asUnsignedByteArray(n, R_bS.getXCoord().toBigInteger()),
131-
// BigIntegers.asUnsignedByteArray(n, R_bS.getYCoord().toBigInteger()),
132-
// BigIntegers.asUnsignedByteArray(16, H));
133136
byte[] encapsulated = Arrays.concatenate(R_bS.getEncoded(false), BigIntegers.asUnsignedByteArray(16, H));
134137

135138
return new SecretWithEncapsulationImpl(
@@ -141,20 +144,21 @@ public SecretWithEncapsulation generateEncapsulated(AsymmetricKeyParameter recip
141144
static BigInteger hashToIntegerRange(byte[] input, BigInteger q, Digest digest)
142145
{
143146
// RFC 6508 Section 5.1: Hashing to an Integer Range
144-
byte[] hash = new byte[digest.getDigestSize()];
147+
byte[] A = new byte[digest.getDigestSize()];
145148

146149
// Step 1: Compute A = hashfn(s)
147150
digest.update(input, 0, input.length);
148-
digest.doFinal(hash, 0);
149-
byte[] A = Arrays.clone(hash);
151+
digest.doFinal(A, 0);
150152

151153
// Step 2: Initialize h_0 to all-zero bytes of hashlen size
152154
byte[] h = new byte[digest.getDigestSize()];
153155

154156
// Step 3: Compute l = Ceiling(lg(n)/hashlen)
157+
// FIXME Seems hardcoded to 256 bit digest?
155158
int l = q.bitLength() >> 8;
156159

157160
BigInteger v = BigInteger.ZERO;
161+
byte[] v_i = new byte[digest.getDigestSize()];
158162

159163
// Step 4: Compute h_i and v_i
160164
for (int i = 0; i <= l; i++)
@@ -165,7 +169,6 @@ static BigInteger hashToIntegerRange(byte[] input, BigInteger q, Digest digest)
165169
// v_i = hashfn(h_i || A)
166170
digest.update(h, 0, h.length);
167171
digest.update(A, 0, A.length);
168-
byte[] v_i = new byte[digest.getDigestSize()];
169172
digest.doFinal(v_i, 0);
170173
// Append v_i to v'
171174
v = v.shiftLeft(v_i.length * 8).add(new BigInteger(1, v_i));

0 commit comments

Comments
 (0)