Skip to content

Commit cc77eaf

Browse files
gefeilidghgit
authored andcommitted
Pass the test vector of SAKKE
1 parent a9c7179 commit cc77eaf

File tree

3 files changed

+42
-193
lines changed

3 files changed

+42
-193
lines changed

core/src/main/java/org/bouncycastle/crypto/kems/SAKKEKEMExtractor.java

Lines changed: 38 additions & 178 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
package org.bouncycastle.crypto.kems;
22

33
import java.math.BigInteger;
4+
import java.security.SecureRandom;
45

56
import org.bouncycastle.crypto.Digest;
67
import org.bouncycastle.crypto.EncapsulatedSecretExtractor;
78
import org.bouncycastle.crypto.digests.SHA256Digest;
89
import org.bouncycastle.crypto.params.SAKKEPrivateKeyParameters;
910
import org.bouncycastle.crypto.params.SAKKEPublicKeyParameters;
1011
import org.bouncycastle.math.ec.ECCurve;
12+
import org.bouncycastle.math.ec.ECFieldElement;
1113
import org.bouncycastle.math.ec.ECPoint;
1214
import org.bouncycastle.util.Arrays;
1315
import org.bouncycastle.util.BigIntegers;
16+
import org.bouncycastle.util.encoders.Hex;
1417

1518
import static org.bouncycastle.crypto.kems.SAKKEKEMSGenerator.pairing;
1619

@@ -48,18 +51,28 @@ public byte[] extractSecret(byte[] encapsulation)
4851
ECPoint R_bS = curve.decodePoint(Arrays.copyOfRange(encapsulation, 0, 257));
4952
BigInteger H = new BigInteger(Arrays.copyOfRange(encapsulation, 257, 274));
5053

54+
//ECCurveWithTatePairing pairing = new ECCurveWithTatePairing(q, BigInteger.ONE, BigInteger.ZERO, p);
55+
//BigInteger w = pairing.TatePairing(R_bS, K_bS).toBigInteger();
5156
// Step 2: Compute w = <R_bS, K_bS> using pairing
5257
// BigInteger w = computeTLPairing(new BigInteger[] {R_bS.getXCoord().toBigInteger(), R_bS.getYCoord().toBigInteger()},
5358
// new BigInteger[] {K_bS.getXCoord().toBigInteger(), K_bS.getYCoord().toBigInteger()}, this.p, this.q);
5459
BigInteger w = computePairing(R_bS, K_bS, p, q);
55-
60+
System.out.println(new String(Hex.encode(w.toByteArray())));
61+
//BigInteger w = tatePairing(R_bS.getXCoord().toBigInteger(), R_bS.getYCoord().toBigInteger(), K_bS.getXCoord().toBigInteger(), K_bS.getYCoord().toBigInteger(), q, p);
5662
// Step 3: Compute SSV = H XOR HashToIntegerRange(w, 2^n)
5763
BigInteger ssv = computeSSV(H, w);
5864

5965
// Step 4: Compute r = HashToIntegerRange(SSV || b)
60-
// BigInteger r = computeR(ssv, privateKey.getPrivatePoint());
66+
BigInteger b = privateKey.getB();
67+
BigInteger r = SAKKEUtils.hashToIntegerRange(Arrays.concatenate(ssv.toByteArray(), b.toByteArray()), q);
6168
//
6269
// // Step 5: Validate R_bS
70+
ECPoint bP = P.multiply(b).normalize();
71+
ECPoint Test = bP.add(Z_S).multiply(r).normalize();
72+
if(!R_bS.equals(Test))
73+
{
74+
throw new IllegalStateException("Validation of R_bS failed");
75+
}
6376
// if (!validateR_bS(r, privateKey.getPrivatePoint(), R_bS)) {
6477
// throw new IllegalStateException("Validation of R_bS failed");
6578
// }
@@ -78,11 +91,6 @@ public int getEncapsulationLength()
7891
return 0;
7992
}
8093

81-
private BigInteger computePairing(ECPoint R, ECPoint K)
82-
{
83-
// Use your existing pairing implementation
84-
return pairing(R, K, p, q);
85-
}
8694

8795
private BigInteger computeSSV(BigInteger H, BigInteger w)
8896
{
@@ -91,175 +99,11 @@ private BigInteger computeSSV(BigInteger H, BigInteger w)
9199
return H.xor(mask);
92100
}
93101

94-
public static BigInteger computeTLPairing(
95-
BigInteger[] R, // C = (Rx, Ry)
96-
BigInteger[] Q, // Q = (Qx, Qy)
97-
BigInteger p,
98-
BigInteger q
99-
)
100-
{
101-
BigInteger qMinus1 = q.subtract(BigInteger.ONE);
102-
int N = qMinus1.bitLength() - 1;
103-
104-
// Initialize V = (1, 0)
105-
BigInteger[] V = {BigInteger.ONE, BigInteger.ZERO};
106-
// Initialize C = R
107-
BigInteger[] C = {R[0], R[1]};
108-
109-
for (; N > 0; N--)
110-
{
111-
// V = V^2
112-
pointSquare(V, p);
113-
114-
// Compute line function T
115-
BigInteger[] T = computeLineFunctionT(C, Q, p);
116-
117-
// V = V * T
118-
pointMultiply(V, T, p);
119-
120-
// C = 2*C (point doubling)
121-
pointDouble(C, p);
122-
123-
if (qMinus1.testBit(N - 1))
124-
{
125-
// Compute addition line function
126-
BigInteger[] TAdd = computeLineFunctionAdd(C, R, Q, p);
127-
128-
// V = V * TAdd
129-
pointMultiply(V, TAdd, p);
130-
131-
// C = C + R (point addition)
132-
pointAdd(C, R, p);
133-
}
134-
}
135-
136-
// Final squaring
137-
pointSquare(V, p);
138-
pointSquare(V, p);
139-
140-
// Compute w = (Vy * Vx^{-1}) mod p
141-
BigInteger VxInv = V[0].modInverse(p);
142-
return V[1].multiply(VxInv).mod(p);
143-
}
144-
145-
private static void pointSquare(BigInteger[] point, BigInteger p)
146-
{
147-
BigInteger x = point[0];
148-
BigInteger y = point[1];
149-
150-
// x = (x + y)(x - y) mod p
151-
BigInteger xPlusY = x.add(y).mod(p);
152-
BigInteger xMinusY = x.subtract(y).mod(p);
153-
BigInteger newX = xPlusY.multiply(xMinusY).mod(p);
154-
155-
// y = 2xy mod p
156-
BigInteger newY = x.multiply(y).multiply(BigInteger.valueOf(2)).mod(p);
157-
158-
point[0] = newX;
159-
point[1] = newY;
160-
}
161-
162-
private static void pointMultiply(BigInteger[] a, BigInteger[] b, BigInteger p)
163-
{
164-
// Complex multiplication (a + bi)*(c + di) = (ac - bd) + (ad + bc)i
165-
BigInteger real = a[0].multiply(b[0]).subtract(a[1].multiply(b[1])).mod(p);
166-
BigInteger imag = a[0].multiply(b[1]).add(a[1].multiply(b[0])).mod(p);
167-
168-
a[0] = real;
169-
a[1] = imag;
170-
}
171-
172-
private static void pointDouble(BigInteger[] point, BigInteger p)
173-
{
174-
// Elliptic curve point doubling formulas
175-
BigInteger x = point[0];
176-
BigInteger y = point[1];
177-
178-
BigInteger slope = x.pow(2).multiply(BigInteger.valueOf(3))
179-
.mod(p)
180-
.multiply(y.multiply(BigInteger.valueOf(2)).modInverse(p))
181-
.mod(p);
182-
183-
BigInteger newX = slope.pow(2).subtract(x.multiply(BigInteger.valueOf(2))).mod(p);
184-
BigInteger newY = slope.multiply(x.subtract(newX)).subtract(y).mod(p);
185-
186-
point[0] = newX;
187-
point[1] = newY;
188-
}
189-
190-
private static void pointAdd(BigInteger[] a, BigInteger[] b, BigInteger p)
191-
{
192-
// Elliptic curve point addition
193-
BigInteger x1 = a[0], y1 = a[1];
194-
BigInteger x2 = b[0], y2 = b[1];
195-
196-
BigInteger slope = y2.subtract(y1)
197-
.multiply(x2.subtract(x1).modInverse(p))
198-
.mod(p);
199-
200-
BigInteger newX = slope.pow(2).subtract(x1).subtract(x2).mod(p);
201-
BigInteger newY = slope.multiply(x1.subtract(newX)).subtract(y1).mod(p);
202-
203-
a[0] = newX;
204-
a[1] = newY;
205-
}
206-
207-
private static BigInteger[] computeLineFunctionT(
208-
BigInteger[] C,
209-
BigInteger[] Q,
210-
BigInteger p
211-
)
212-
{
213-
// Line function evaluation for doubling
214-
BigInteger Cx = C[0], Cy = C[1];
215-
BigInteger Qx = Q[0], Qy = Q[1];
216-
217-
// l = (3Cx² + a)/(2Cy) but a=0 for many curves
218-
BigInteger numerator = Cx.pow(2).multiply(BigInteger.valueOf(3)).mod(p);
219-
BigInteger denominator = Cy.multiply(BigInteger.valueOf(2)).mod(p);
220-
BigInteger l = numerator.multiply(denominator.modInverse(p)).mod(p);
221-
222-
// T = l*(Qx + Cx) - 2Qy
223-
BigInteger tReal = l.multiply(Qx.add(Cx).mod(p)).mod(p);
224-
BigInteger tImag = l.multiply(Qy).negate().mod(p);
225-
226-
return new BigInteger[]{tReal, tImag};
227-
}
228-
229-
private static BigInteger[] computeLineFunctionAdd(
230-
BigInteger[] C,
231-
BigInteger[] R,
232-
BigInteger[] Q,
233-
BigInteger p
234-
)
235-
{
236-
// Line function evaluation for addition
237-
BigInteger Cx = C[0], Cy = C[1];
238-
BigInteger Rx = R[0], Ry = R[1];
239-
BigInteger Qx = Q[0], Qy = Q[1];
240-
241-
// l = (Cy - Ry)/(Cx - Rx)
242-
BigInteger numerator = Cy.subtract(Ry).mod(p);
243-
BigInteger denominator = Cx.subtract(Rx).mod(p);
244-
BigInteger l = numerator.multiply(denominator.modInverse(p)).mod(p);
245-
246-
// T = l*(Qx + Cx) - Qy
247-
BigInteger tReal = l.multiply(Qx.add(Cx).mod(p)).mod(p);
248-
BigInteger tImag = l.multiply(Qy).negate().mod(p);
249-
250-
return new BigInteger[]{tReal, tImag};
251-
}
252-
253-
private boolean pointsEqual(ECPoint p1, ECPoint p2)
254-
{
255-
return p1.normalize().getXCoord().equals(p2.normalize().getXCoord())
256-
&& p1.normalize().getYCoord().equals(p2.normalize().getYCoord());
257-
}
258-
259102
public static BigInteger computePairing(ECPoint R, ECPoint Q, BigInteger p, BigInteger q)
260103
{
261104
BigInteger c = p.add(BigInteger.ONE).divide(q); // Compute c = (p+1)/q
262105
BigInteger[] v = new BigInteger[]{BigInteger.ONE, BigInteger.ZERO}; // v = (1,0) in F_p^2
106+
//BigInteger v = BigInteger.ONE;
263107
ECPoint C = R;
264108

265109
BigInteger qMinusOne = q.subtract(BigInteger.ONE);
@@ -269,6 +113,7 @@ public static BigInteger computePairing(ECPoint R, ECPoint Q, BigInteger p, BigI
269113
for (int i = numBits - 2; i >= 0; i--)
270114
{
271115
v = fp2SquareAndAccumulate(v, C, Q, p);
116+
272117
C = C.twice().normalize(); // C = [2]C
273118

274119
if (qMinusOne.testBit(i))
@@ -290,13 +135,24 @@ private static BigInteger[] fp2SquareAndAccumulate(BigInteger[] v, ECPoint C, EC
290135
BigInteger Qy = Q.getAffineYCoord().toBigInteger();
291136

292137
// Compute l = (3 * (Cx^2 - 1)) / (2 * Cy) mod p
293-
BigInteger l = Cx.multiply(Cx).mod(p).subtract(BigInteger.ONE).multiply(BigInteger.valueOf(3)).mod(p)
294-
.multiply(Cy.multiply(BigInteger.valueOf(2)).modInverse(p))
295-
.mod(p);
138+
BigInteger l = BigInteger.valueOf(3).multiply(Cx.multiply(Cx).subtract(BigInteger.ONE))
139+
.multiply(Cy.multiply(BigInteger.valueOf(2)).modInverse(p)).mod(p);
296140

297141
// Compute v = v^2 * ( l*( Q_x + C_x ) + ( i*Q_y - C_y ) )
298142
v = fp2Multiply(v[0], v[1], v[0], v[1], p);
299-
return fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)), (Qy.subtract(Cy)), p);
143+
// v[0] = v[0].multiply(v[0]);
144+
// v[1] = v[1].multiply(v[1]);
145+
return accumulateLine(v[0], v[1], Cx, Cy, Qx, Qy, l, p);
146+
// BigInteger t_x1_bn = Cx.multiply(Cx).subtract(BigInteger.ONE).multiply(BigInteger.valueOf(3)).multiply(Qx.add(Cx)).mod(p)
147+
// .subtract(Cy.multiply(Cy).multiply(BigInteger.valueOf(2))).mod(p);
148+
// BigInteger t_x2_bn = Cy.multiply(Qy).multiply(BigInteger.valueOf(2)).mod(p);
149+
// v = fp2Multiply(v[0], v[1], v[0], v[1], p);
150+
// return fp2Multiply(v[0], v[1], t_x1_bn, t_x2_bn, p);
151+
}
152+
153+
private static BigInteger[] accumulateLine(BigInteger v0, BigInteger v1, BigInteger Cx, BigInteger Cy, BigInteger Qx, BigInteger Qy, BigInteger l, BigInteger p)
154+
{
155+
return fp2Multiply(v0, v1, l.multiply(Qx.add(Cx)).subtract(Cy), Qy, p);
300156
}
301157

302158
private static BigInteger[] fp2MultiplyAndAccumulate(BigInteger[] v, ECPoint C, ECPoint R, ECPoint Q, BigInteger p)
@@ -314,11 +170,15 @@ private static BigInteger[] fp2MultiplyAndAccumulate(BigInteger[] v, ECPoint C,
314170
.mod(p);
315171

316172
// Compute v = v * ( l*( Q_x + C_x ) + ( i*Q_y - C_y ) )
317-
return fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)), Qy.subtract(Cy), p);
173+
return accumulateLine(v[0], v[1], Cx, Cy, Qx, Qy, l, p);
174+
// BigInteger t_x1_bn = Qx.add(Rx).multiply(Cy).subtract(Qx.add(Cx).multiply(Ry)).mod(p);
175+
// BigInteger t_x2_bn = Cx.subtract(Rx).multiply(Qy).mod(p);
176+
// return fp2Multiply(v[0], v[1], t_x1_bn, t_x2_bn, p);
177+
318178
}
319179

320180

321-
private static BigInteger[] fp2Multiply(BigInteger x_real, BigInteger x_imag, BigInteger y_real, BigInteger y_imag, BigInteger p)
181+
static BigInteger[] fp2Multiply(BigInteger x_real, BigInteger x_imag, BigInteger y_real, BigInteger y_imag, BigInteger p)
322182
{
323183
// Multiply v = (a + i*b) * scalar
324184
return new BigInteger[]{

core/src/main/java/org/bouncycastle/crypto/kems/SAKKEKEMSGenerator.java

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -156,20 +156,12 @@ public SecretWithEncapsulation generateEncapsulated(AsymmetricKeyParameter recip
156156
BigInteger g_r = result[0].mod(p);
157157
g_r = g_r.modInverse(p);
158158
g_r = g_r.multiply(result[1]).mod(p);
159-
System.out.println("g_r " + new String(Hex.encode(g_r.toByteArray())));
160-
byte[] expected_g_r = Hex.decode("7D2A8438 E6291C64 9B6579EB 3B79EAE9\n" +
161-
" 48B1DE9E 5F7D1F40 70A08F8D B6B3C515\n" +
162-
" 6F2201AF FBB5CB9D 82AA3EC0 D0398B89\n" +
163-
" ABC78A13 A760C0BF 3F77E63D 0DF3F1A3\n" +
164-
" 41A41B88 11DF197F D6CD0F00 3125606F\n" +
165-
" 4F109F40 0F7292A1 0D255E3C 0EBCCB42\n" +
166-
" 53FB182C 68F09CF6 CD9C4A53 DA6C74AD\n" +
167-
" 007AF36B 8BCA979D 5895E282 F483FCD6");
159+
168160
BigInteger mask = SAKKEUtils.hashToIntegerRange(g_r.toByteArray(), BigInteger.ONE.shiftLeft(n)); // 2^n
169161
System.out.println(new String(Hex.encode(mask.toByteArray())));
170162

171163
BigInteger H = ssv.xor(mask);
172-
164+
System.out.println(new String(Hex.encode(H.toByteArray())));
173165
// 5. Encode encapsulated data (R_bS, H)
174166
byte[] encapsulated = Arrays.concatenate(R_bS.getEncoded(false), H.toByteArray());
175167

@@ -201,10 +193,6 @@ public static boolean sakkePointExponent(
201193
BigInteger n
202194
)
203195
{
204-
if (n.equals(BigInteger.ZERO))
205-
{
206-
return false;
207-
}
208196

209197
// Initialize result with the original point
210198
BigInteger currentX = pointX;

core/src/test/java/org/bouncycastle/crypto/kems/test/SAKKEKEMSTest.java

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,8 @@ public void performTest()
155155
ECPoint K_bS = curve.createPoint(kbx, kby);
156156

157157

158-
SAKKEKEMExtractor extractor = new SAKKEKEMExtractor(new SAKKEPrivateKeyParameters(new BigInteger(b), K_bS, new SAKKEPublicKeyParameters(null)));
158+
SAKKEKEMExtractor extractor = new SAKKEKEMExtractor(new SAKKEPrivateKeyParameters(new BigInteger(b), K_bS,
159+
new SAKKEPublicKeyParameters(curve.createPoint(Zx, Zy))));
159160
byte[] test = extractor.extractSecret(rlt.getSecret());
160161

161162

0 commit comments

Comments
 (0)