Skip to content

Commit a9c7179

Browse files
gefeilidghgit
authored andcommitted
TODO pairing
1 parent 2434c1c commit a9c7179

File tree

2 files changed

+490
-206
lines changed

2 files changed

+490
-206
lines changed

core/src/main/java/org/bouncycastle/crypto/kems/SAKKEKEMExtractor.java

Lines changed: 256 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
import static org.bouncycastle.crypto.kems.SAKKEKEMSGenerator.pairing;
1616

17-
public class SAKKEKEMExtractor implements EncapsulatedSecretExtractor
17+
public class SAKKEKEMExtractor
18+
implements EncapsulatedSecretExtractor
1819
{
1920
private final ECCurve curve;
2021
private final BigInteger p;
@@ -25,7 +26,8 @@ public class SAKKEKEMExtractor implements EncapsulatedSecretExtractor
2526
private final int n; // Security parameter
2627
private final SAKKEPrivateKeyParameters privateKey;
2728

28-
public SAKKEKEMExtractor(SAKKEPrivateKeyParameters privateKey) {
29+
public SAKKEKEMExtractor(SAKKEPrivateKeyParameters privateKey)
30+
{
2931
this.privateKey = privateKey;
3032
SAKKEPublicKeyParameters publicKey = privateKey.getPublicParams();
3133
this.curve = publicKey.getCurve();
@@ -38,14 +40,18 @@ public SAKKEKEMExtractor(SAKKEPrivateKeyParameters privateKey) {
3840
}
3941

4042
@Override
41-
public byte[] extractSecret(byte[] encapsulation) {
42-
try {
43+
public byte[] extractSecret(byte[] encapsulation)
44+
{
45+
try
46+
{
4347
// Step 1: Parse Encapsulated Data (R_bS, H)
44-
ECPoint R_bS = parseECPoint(encapsulation);
45-
BigInteger H = parseH(encapsulation);
48+
ECPoint R_bS = curve.decodePoint(Arrays.copyOfRange(encapsulation, 0, 257));
49+
BigInteger H = new BigInteger(Arrays.copyOfRange(encapsulation, 257, 274));
4650

4751
// Step 2: Compute w = <R_bS, K_bS> using pairing
48-
BigInteger w = computePairing(R_bS, K_bS);
52+
// BigInteger w = computeTLPairing(new BigInteger[] {R_bS.getXCoord().toBigInteger(), R_bS.getYCoord().toBigInteger()},
53+
// new BigInteger[] {K_bS.getXCoord().toBigInteger(), K_bS.getYCoord().toBigInteger()}, this.p, this.q);
54+
BigInteger w = computePairing(R_bS, K_bS, p, q);
4955

5056
// Step 3: Compute SSV = H XOR HashToIntegerRange(w, 2^n)
5157
BigInteger ssv = computeSSV(H, w);
@@ -58,8 +64,10 @@ public byte[] extractSecret(byte[] encapsulation) {
5864
// throw new IllegalStateException("Validation of R_bS failed");
5965
// }
6066

61-
return BigIntegers.asUnsignedByteArray(n/8, ssv);
62-
} catch (Exception e) {
67+
return BigIntegers.asUnsignedByteArray(n / 8, ssv);
68+
}
69+
catch (Exception e)
70+
{
6371
throw new IllegalStateException("SAKKE extraction failed: " + e.getMessage());
6472
}
6573
}
@@ -70,59 +78,263 @@ public int getEncapsulationLength()
7078
return 0;
7179
}
7280

73-
private ECPoint parseECPoint(byte[] encapsulation) {
74-
int coordLen = (p.bitLength() + 7) / 8;
75-
byte[] xBytes = Arrays.copyOfRange(encapsulation, 0, coordLen);
76-
byte[] yBytes = Arrays.copyOfRange(encapsulation, coordLen, 2*coordLen);
81+
private BigInteger computePairing(ECPoint R, ECPoint K)
82+
{
83+
// Use your existing pairing implementation
84+
return pairing(R, K, p, q);
85+
}
86+
87+
private BigInteger computeSSV(BigInteger H, BigInteger w)
88+
{
89+
BigInteger twoToN = BigInteger.ONE.shiftLeft(n);
90+
BigInteger mask = SAKKEUtils.hashToIntegerRange(w.toByteArray(), twoToN);
91+
return H.xor(mask);
92+
}
93+
94+
public static BigInteger computeTLPairing(
95+
BigInteger[] R, // C = (Rx, Ry)
96+
BigInteger[] Q, // Q = (Qx, Qy)
97+
BigInteger p,
98+
BigInteger q
99+
)
100+
{
101+
BigInteger qMinus1 = q.subtract(BigInteger.ONE);
102+
int N = qMinus1.bitLength() - 1;
103+
104+
// Initialize V = (1, 0)
105+
BigInteger[] V = {BigInteger.ONE, BigInteger.ZERO};
106+
// Initialize C = R
107+
BigInteger[] C = {R[0], R[1]};
108+
109+
for (; N > 0; N--)
110+
{
111+
// V = V^2
112+
pointSquare(V, p);
113+
114+
// Compute line function T
115+
BigInteger[] T = computeLineFunctionT(C, Q, p);
116+
117+
// V = V * T
118+
pointMultiply(V, T, p);
119+
120+
// C = 2*C (point doubling)
121+
pointDouble(C, p);
77122

78-
BigInteger x = new BigInteger(1, xBytes);
79-
BigInteger y = new BigInteger(1, yBytes);
123+
if (qMinus1.testBit(N - 1))
124+
{
125+
// Compute addition line function
126+
BigInteger[] TAdd = computeLineFunctionAdd(C, R, Q, p);
80127

81-
return curve.createPoint(x, y).normalize();
128+
// V = V * TAdd
129+
pointMultiply(V, TAdd, p);
130+
131+
// C = C + R (point addition)
132+
pointAdd(C, R, p);
133+
}
134+
}
135+
136+
// Final squaring
137+
pointSquare(V, p);
138+
pointSquare(V, p);
139+
140+
// Compute w = (Vy * Vx^{-1}) mod p
141+
BigInteger VxInv = V[0].modInverse(p);
142+
return V[1].multiply(VxInv).mod(p);
82143
}
83144

84-
private BigInteger parseH(byte[] encapsulation) {
85-
int coordLen = (p.bitLength() + 7) / 8;
86-
byte[] hBytes = Arrays.copyOfRange(encapsulation, 2*coordLen, encapsulation.length);
87-
return new BigInteger(1, hBytes);
145+
private static void pointSquare(BigInteger[] point, BigInteger p)
146+
{
147+
BigInteger x = point[0];
148+
BigInteger y = point[1];
149+
150+
// x = (x + y)(x - y) mod p
151+
BigInteger xPlusY = x.add(y).mod(p);
152+
BigInteger xMinusY = x.subtract(y).mod(p);
153+
BigInteger newX = xPlusY.multiply(xMinusY).mod(p);
154+
155+
// y = 2xy mod p
156+
BigInteger newY = x.multiply(y).multiply(BigInteger.valueOf(2)).mod(p);
157+
158+
point[0] = newX;
159+
point[1] = newY;
88160
}
89161

90-
private BigInteger computePairing(ECPoint R, ECPoint K) {
91-
// Use your existing pairing implementation
92-
return pairing(R, K, p, q);
162+
private static void pointMultiply(BigInteger[] a, BigInteger[] b, BigInteger p)
163+
{
164+
// Complex multiplication (a + bi)*(c + di) = (ac - bd) + (ad + bc)i
165+
BigInteger real = a[0].multiply(b[0]).subtract(a[1].multiply(b[1])).mod(p);
166+
BigInteger imag = a[0].multiply(b[1]).add(a[1].multiply(b[0])).mod(p);
167+
168+
a[0] = real;
169+
a[1] = imag;
93170
}
94171

95-
private BigInteger computeSSV(BigInteger H, BigInteger w) {
96-
BigInteger twoToN = BigInteger.ONE.shiftLeft(n);
97-
BigInteger mask = SAKKEUtils.hashToIntegerRange(w.toByteArray(), twoToN);
98-
return H.xor(mask);
172+
private static void pointDouble(BigInteger[] point, BigInteger p)
173+
{
174+
// Elliptic curve point doubling formulas
175+
BigInteger x = point[0];
176+
BigInteger y = point[1];
177+
178+
BigInteger slope = x.pow(2).multiply(BigInteger.valueOf(3))
179+
.mod(p)
180+
.multiply(y.multiply(BigInteger.valueOf(2)).modInverse(p))
181+
.mod(p);
182+
183+
BigInteger newX = slope.pow(2).subtract(x.multiply(BigInteger.valueOf(2))).mod(p);
184+
BigInteger newY = slope.multiply(x.subtract(newX)).subtract(y).mod(p);
185+
186+
point[0] = newX;
187+
point[1] = newY;
99188
}
100189

101-
private BigInteger computeR(BigInteger ssv, byte[] userId) {
102-
byte[] ssvBytes = BigIntegers.asUnsignedByteArray(ssv);
103-
byte[] ssvConcatB = Arrays.concatenate(ssvBytes, userId);
104-
return SAKKEUtils.hashToIntegerRange(ssvConcatB, q);
190+
private static void pointAdd(BigInteger[] a, BigInteger[] b, BigInteger p)
191+
{
192+
// Elliptic curve point addition
193+
BigInteger x1 = a[0], y1 = a[1];
194+
BigInteger x2 = b[0], y2 = b[1];
195+
196+
BigInteger slope = y2.subtract(y1)
197+
.multiply(x2.subtract(x1).modInverse(p))
198+
.mod(p);
199+
200+
BigInteger newX = slope.pow(2).subtract(x1).subtract(x2).mod(p);
201+
BigInteger newY = slope.multiply(x1.subtract(newX)).subtract(y1).mod(p);
202+
203+
a[0] = newX;
204+
a[1] = newY;
105205
}
106206

107-
private boolean validateR_bS(BigInteger r, byte[] b, ECPoint receivedR) {
108-
try {
109-
// Compute [b]P
110-
ECPoint bP = P.multiply(new BigInteger(1, b)).normalize();
207+
private static BigInteger[] computeLineFunctionT(
208+
BigInteger[] C,
209+
BigInteger[] Q,
210+
BigInteger p
211+
)
212+
{
213+
// Line function evaluation for doubling
214+
BigInteger Cx = C[0], Cy = C[1];
215+
BigInteger Qx = Q[0], Qy = Q[1];
111216

112-
// Compute [b]P + Z_S
113-
ECPoint bP_plus_Z = bP.add(Z_S).normalize();
217+
// l = (3Cx² + a)/(2Cy) but a=0 for many curves
218+
BigInteger numerator = Cx.pow(2).multiply(BigInteger.valueOf(3)).mod(p);
219+
BigInteger denominator = Cy.multiply(BigInteger.valueOf(2)).mod(p);
220+
BigInteger l = numerator.multiply(denominator.modInverse(p)).mod(p);
114221

115-
// Compute [r]([b]P + Z_S)
116-
ECPoint computedR = bP_plus_Z.multiply(r).normalize();
222+
// T = l*(Qx + Cx) - 2Qy
223+
BigInteger tReal = l.multiply(Qx.add(Cx).mod(p)).mod(p);
224+
BigInteger tImag = l.multiply(Qy).negate().mod(p);
117225

118-
return pointsEqual(computedR, receivedR);
119-
} catch (Exception e) {
120-
return false;
121-
}
226+
return new BigInteger[]{tReal, tImag};
122227
}
123228

124-
private boolean pointsEqual(ECPoint p1, ECPoint p2) {
229+
private static BigInteger[] computeLineFunctionAdd(
230+
BigInteger[] C,
231+
BigInteger[] R,
232+
BigInteger[] Q,
233+
BigInteger p
234+
)
235+
{
236+
// Line function evaluation for addition
237+
BigInteger Cx = C[0], Cy = C[1];
238+
BigInteger Rx = R[0], Ry = R[1];
239+
BigInteger Qx = Q[0], Qy = Q[1];
240+
241+
// l = (Cy - Ry)/(Cx - Rx)
242+
BigInteger numerator = Cy.subtract(Ry).mod(p);
243+
BigInteger denominator = Cx.subtract(Rx).mod(p);
244+
BigInteger l = numerator.multiply(denominator.modInverse(p)).mod(p);
245+
246+
// T = l*(Qx + Cx) - Qy
247+
BigInteger tReal = l.multiply(Qx.add(Cx).mod(p)).mod(p);
248+
BigInteger tImag = l.multiply(Qy).negate().mod(p);
249+
250+
return new BigInteger[]{tReal, tImag};
251+
}
252+
253+
private boolean pointsEqual(ECPoint p1, ECPoint p2)
254+
{
125255
return p1.normalize().getXCoord().equals(p2.normalize().getXCoord())
126256
&& p1.normalize().getYCoord().equals(p2.normalize().getYCoord());
127257
}
258+
259+
public static BigInteger computePairing(ECPoint R, ECPoint Q, BigInteger p, BigInteger q)
260+
{
261+
BigInteger c = p.add(BigInteger.ONE).divide(q); // Compute c = (p+1)/q
262+
BigInteger[] v = new BigInteger[]{BigInteger.ONE, BigInteger.ZERO}; // v = (1,0) in F_p^2
263+
ECPoint C = R;
264+
265+
BigInteger qMinusOne = q.subtract(BigInteger.ONE);
266+
int numBits = qMinusOne.bitLength();
267+
268+
// Miller loop
269+
for (int i = numBits - 2; i >= 0; i--)
270+
{
271+
v = fp2SquareAndAccumulate(v, C, Q, p);
272+
C = C.twice().normalize(); // C = [2]C
273+
274+
if (qMinusOne.testBit(i))
275+
{
276+
v = fp2MultiplyAndAccumulate(v, C, R, Q, p);
277+
C = C.add(R).normalize();
278+
}
279+
}
280+
281+
// Final exponentiation: t = v^c
282+
return fp2FinalExponentiation(v, p, c);
283+
}
284+
285+
private static BigInteger[] fp2SquareAndAccumulate(BigInteger[] v, ECPoint C, ECPoint Q, BigInteger p)
286+
{
287+
BigInteger Cx = C.getAffineXCoord().toBigInteger();
288+
BigInteger Cy = C.getAffineYCoord().toBigInteger();
289+
BigInteger Qx = Q.getAffineXCoord().toBigInteger();
290+
BigInteger Qy = Q.getAffineYCoord().toBigInteger();
291+
292+
// Compute l = (3 * (Cx^2 - 1)) / (2 * Cy) mod p
293+
BigInteger l = Cx.multiply(Cx).mod(p).subtract(BigInteger.ONE).multiply(BigInteger.valueOf(3)).mod(p)
294+
.multiply(Cy.multiply(BigInteger.valueOf(2)).modInverse(p))
295+
.mod(p);
296+
297+
// Compute v = v^2 * ( l*( Q_x + C_x ) + ( i*Q_y - C_y ) )
298+
v = fp2Multiply(v[0], v[1], v[0], v[1], p);
299+
return fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)), (Qy.subtract(Cy)), p);
300+
}
301+
302+
private static BigInteger[] fp2MultiplyAndAccumulate(BigInteger[] v, ECPoint C, ECPoint R, ECPoint Q, BigInteger p)
303+
{
304+
BigInteger Cx = C.getAffineXCoord().toBigInteger();
305+
BigInteger Cy = C.getAffineYCoord().toBigInteger();
306+
BigInteger Rx = R.getAffineXCoord().toBigInteger();
307+
BigInteger Ry = R.getAffineYCoord().toBigInteger();
308+
BigInteger Qx = Q.getAffineXCoord().toBigInteger();
309+
BigInteger Qy = Q.getAffineYCoord().toBigInteger();
310+
311+
// Compute l = (Cy - Ry) / (Cx - Rx) mod p
312+
BigInteger l = Cy.subtract(Ry)
313+
.multiply(Cx.subtract(Rx).modInverse(p))
314+
.mod(p);
315+
316+
// Compute v = v * ( l*( Q_x + C_x ) + ( i*Q_y - C_y ) )
317+
return fp2Multiply(v[0], v[1], l.multiply(Qx.add(Cx)), Qy.subtract(Cy), p);
318+
}
319+
320+
321+
private static BigInteger[] fp2Multiply(BigInteger x_real, BigInteger x_imag, BigInteger y_real, BigInteger y_imag, BigInteger p)
322+
{
323+
// Multiply v = (a + i*b) * scalar
324+
return new BigInteger[]{
325+
x_real.multiply(y_real).subtract(x_imag.multiply(y_imag)).mod(p),
326+
x_real.multiply(y_imag).add(x_imag.multiply(y_real)).mod(p)
327+
};
328+
}
329+
330+
private static BigInteger fp2FinalExponentiation(BigInteger[] v, BigInteger p, BigInteger c)
331+
{
332+
// Compute representative in F_p: return b/a (mod p)
333+
// BigInteger v0 = v[0].modPow(c, p);
334+
// BigInteger v1 = v[1].modPow(c, p);
335+
// return v1.multiply(v0.modInverse(p)).mod(p);
336+
v = fp2Multiply(v[0], v[1], v[0], v[1], p);
337+
v = fp2Multiply(v[0], v[1], v[0], v[1], p);
338+
return v[1].multiply(v[0].modInverse(p)).mod(p);
339+
}
128340
}

0 commit comments

Comments
 (0)