1414
1515import static org .bouncycastle .crypto .kems .SAKKEKEMSGenerator .pairing ;
1616
17- public class SAKKEKEMExtractor implements EncapsulatedSecretExtractor
17+ public class SAKKEKEMExtractor
18+ implements EncapsulatedSecretExtractor
1819{
1920 private final ECCurve curve ;
2021 private final BigInteger p ;
@@ -25,7 +26,8 @@ public class SAKKEKEMExtractor implements EncapsulatedSecretExtractor
2526 private final int n ; // Security parameter
2627 private final SAKKEPrivateKeyParameters privateKey ;
2728
28- public SAKKEKEMExtractor (SAKKEPrivateKeyParameters privateKey ) {
29+ public SAKKEKEMExtractor (SAKKEPrivateKeyParameters privateKey )
30+ {
2931 this .privateKey = privateKey ;
3032 SAKKEPublicKeyParameters publicKey = privateKey .getPublicParams ();
3133 this .curve = publicKey .getCurve ();
@@ -38,14 +40,18 @@ public SAKKEKEMExtractor(SAKKEPrivateKeyParameters privateKey) {
3840 }
3941
4042 @ Override
41- public byte [] extractSecret (byte [] encapsulation ) {
42- try {
43+ public byte [] extractSecret (byte [] encapsulation )
44+ {
45+ try
46+ {
4347 // Step 1: Parse Encapsulated Data (R_bS, H)
44- ECPoint R_bS = parseECPoint ( encapsulation );
45- BigInteger H = parseH ( encapsulation );
48+ ECPoint R_bS = curve . decodePoint ( Arrays . copyOfRange ( encapsulation , 0 , 257 ) );
49+ BigInteger H = new BigInteger ( Arrays . copyOfRange ( encapsulation , 257 , 274 ) );
4650
4751 // Step 2: Compute w = <R_bS, K_bS> using pairing
48- BigInteger w = computePairing (R_bS , K_bS );
52+ // BigInteger w = computeTLPairing(new BigInteger[] {R_bS.getXCoord().toBigInteger(), R_bS.getYCoord().toBigInteger()},
53+ // new BigInteger[] {K_bS.getXCoord().toBigInteger(), K_bS.getYCoord().toBigInteger()}, this.p, this.q);
54+ BigInteger w = computePairing (R_bS , K_bS , p , q );
4955
5056 // Step 3: Compute SSV = H XOR HashToIntegerRange(w, 2^n)
5157 BigInteger ssv = computeSSV (H , w );
@@ -58,8 +64,10 @@ public byte[] extractSecret(byte[] encapsulation) {
5864// throw new IllegalStateException("Validation of R_bS failed");
5965// }
6066
61- return BigIntegers .asUnsignedByteArray (n /8 , ssv );
62- } catch (Exception e ) {
67+ return BigIntegers .asUnsignedByteArray (n / 8 , ssv );
68+ }
69+ catch (Exception e )
70+ {
6371 throw new IllegalStateException ("SAKKE extraction failed: " + e .getMessage ());
6472 }
6573 }
@@ -70,59 +78,263 @@ public int getEncapsulationLength()
7078 return 0 ;
7179 }
7280
73- private ECPoint parseECPoint (byte [] encapsulation ) {
74- int coordLen = (p .bitLength () + 7 ) / 8 ;
75- byte [] xBytes = Arrays .copyOfRange (encapsulation , 0 , coordLen );
76- byte [] yBytes = Arrays .copyOfRange (encapsulation , coordLen , 2 *coordLen );
81+ private BigInteger computePairing (ECPoint R , ECPoint K )
82+ {
83+ // Use your existing pairing implementation
84+ return pairing (R , K , p , q );
85+ }
86+
87+ private BigInteger computeSSV (BigInteger H , BigInteger w )
88+ {
89+ BigInteger twoToN = BigInteger .ONE .shiftLeft (n );
90+ BigInteger mask = SAKKEUtils .hashToIntegerRange (w .toByteArray (), twoToN );
91+ return H .xor (mask );
92+ }
93+
94+ public static BigInteger computeTLPairing (
95+ BigInteger [] R , // C = (Rx, Ry)
96+ BigInteger [] Q , // Q = (Qx, Qy)
97+ BigInteger p ,
98+ BigInteger q
99+ )
100+ {
101+ BigInteger qMinus1 = q .subtract (BigInteger .ONE );
102+ int N = qMinus1 .bitLength () - 1 ;
103+
104+ // Initialize V = (1, 0)
105+ BigInteger [] V = {BigInteger .ONE , BigInteger .ZERO };
106+ // Initialize C = R
107+ BigInteger [] C = {R [0 ], R [1 ]};
108+
109+ for (; N > 0 ; N --)
110+ {
111+ // V = V^2
112+ pointSquare (V , p );
113+
114+ // Compute line function T
115+ BigInteger [] T = computeLineFunctionT (C , Q , p );
116+
117+ // V = V * T
118+ pointMultiply (V , T , p );
119+
120+ // C = 2*C (point doubling)
121+ pointDouble (C , p );
77122
78- BigInteger x = new BigInteger (1 , xBytes );
79- BigInteger y = new BigInteger (1 , yBytes );
123+ if (qMinus1 .testBit (N - 1 ))
124+ {
125+ // Compute addition line function
126+ BigInteger [] TAdd = computeLineFunctionAdd (C , R , Q , p );
80127
81- return curve .createPoint (x , y ).normalize ();
128+ // V = V * TAdd
129+ pointMultiply (V , TAdd , p );
130+
131+ // C = C + R (point addition)
132+ pointAdd (C , R , p );
133+ }
134+ }
135+
136+ // Final squaring
137+ pointSquare (V , p );
138+ pointSquare (V , p );
139+
140+ // Compute w = (Vy * Vx^{-1}) mod p
141+ BigInteger VxInv = V [0 ].modInverse (p );
142+ return V [1 ].multiply (VxInv ).mod (p );
82143 }
83144
84- private BigInteger parseH (byte [] encapsulation ) {
85- int coordLen = (p .bitLength () + 7 ) / 8 ;
86- byte [] hBytes = Arrays .copyOfRange (encapsulation , 2 *coordLen , encapsulation .length );
87- return new BigInteger (1 , hBytes );
145+ private static void pointSquare (BigInteger [] point , BigInteger p )
146+ {
147+ BigInteger x = point [0 ];
148+ BigInteger y = point [1 ];
149+
150+ // x = (x + y)(x - y) mod p
151+ BigInteger xPlusY = x .add (y ).mod (p );
152+ BigInteger xMinusY = x .subtract (y ).mod (p );
153+ BigInteger newX = xPlusY .multiply (xMinusY ).mod (p );
154+
155+ // y = 2xy mod p
156+ BigInteger newY = x .multiply (y ).multiply (BigInteger .valueOf (2 )).mod (p );
157+
158+ point [0 ] = newX ;
159+ point [1 ] = newY ;
88160 }
89161
90- private BigInteger computePairing (ECPoint R , ECPoint K ) {
91- // Use your existing pairing implementation
92- return pairing (R , K , p , q );
162+ private static void pointMultiply (BigInteger [] a , BigInteger [] b , BigInteger p )
163+ {
164+ // Complex multiplication (a + bi)*(c + di) = (ac - bd) + (ad + bc)i
165+ BigInteger real = a [0 ].multiply (b [0 ]).subtract (a [1 ].multiply (b [1 ])).mod (p );
166+ BigInteger imag = a [0 ].multiply (b [1 ]).add (a [1 ].multiply (b [0 ])).mod (p );
167+
168+ a [0 ] = real ;
169+ a [1 ] = imag ;
93170 }
94171
95- private BigInteger computeSSV (BigInteger H , BigInteger w ) {
96- BigInteger twoToN = BigInteger .ONE .shiftLeft (n );
97- BigInteger mask = SAKKEUtils .hashToIntegerRange (w .toByteArray (), twoToN );
98- return H .xor (mask );
172+ private static void pointDouble (BigInteger [] point , BigInteger p )
173+ {
174+ // Elliptic curve point doubling formulas
175+ BigInteger x = point [0 ];
176+ BigInteger y = point [1 ];
177+
178+ BigInteger slope = x .pow (2 ).multiply (BigInteger .valueOf (3 ))
179+ .mod (p )
180+ .multiply (y .multiply (BigInteger .valueOf (2 )).modInverse (p ))
181+ .mod (p );
182+
183+ BigInteger newX = slope .pow (2 ).subtract (x .multiply (BigInteger .valueOf (2 ))).mod (p );
184+ BigInteger newY = slope .multiply (x .subtract (newX )).subtract (y ).mod (p );
185+
186+ point [0 ] = newX ;
187+ point [1 ] = newY ;
99188 }
100189
101- private BigInteger computeR (BigInteger ssv , byte [] userId ) {
102- byte [] ssvBytes = BigIntegers .asUnsignedByteArray (ssv );
103- byte [] ssvConcatB = Arrays .concatenate (ssvBytes , userId );
104- return SAKKEUtils .hashToIntegerRange (ssvConcatB , q );
190+ private static void pointAdd (BigInteger [] a , BigInteger [] b , BigInteger p )
191+ {
192+ // Elliptic curve point addition
193+ BigInteger x1 = a [0 ], y1 = a [1 ];
194+ BigInteger x2 = b [0 ], y2 = b [1 ];
195+
196+ BigInteger slope = y2 .subtract (y1 )
197+ .multiply (x2 .subtract (x1 ).modInverse (p ))
198+ .mod (p );
199+
200+ BigInteger newX = slope .pow (2 ).subtract (x1 ).subtract (x2 ).mod (p );
201+ BigInteger newY = slope .multiply (x1 .subtract (newX )).subtract (y1 ).mod (p );
202+
203+ a [0 ] = newX ;
204+ a [1 ] = newY ;
105205 }
106206
107- private boolean validateR_bS (BigInteger r , byte [] b , ECPoint receivedR ) {
108- try {
109- // Compute [b]P
110- ECPoint bP = P .multiply (new BigInteger (1 , b )).normalize ();
207+ private static BigInteger [] computeLineFunctionT (
208+ BigInteger [] C ,
209+ BigInteger [] Q ,
210+ BigInteger p
211+ )
212+ {
213+ // Line function evaluation for doubling
214+ BigInteger Cx = C [0 ], Cy = C [1 ];
215+ BigInteger Qx = Q [0 ], Qy = Q [1 ];
111216
112- // Compute [b]P + Z_S
113- ECPoint bP_plus_Z = bP .add (Z_S ).normalize ();
217+ // l = (3Cx² + a)/(2Cy) but a=0 for many curves
218+ BigInteger numerator = Cx .pow (2 ).multiply (BigInteger .valueOf (3 )).mod (p );
219+ BigInteger denominator = Cy .multiply (BigInteger .valueOf (2 )).mod (p );
220+ BigInteger l = numerator .multiply (denominator .modInverse (p )).mod (p );
114221
115- // Compute [r]([b]P + Z_S)
116- ECPoint computedR = bP_plus_Z .multiply (r ).normalize ();
222+ // T = l*(Qx + Cx) - 2Qy
223+ BigInteger tReal = l .multiply (Qx .add (Cx ).mod (p )).mod (p );
224+ BigInteger tImag = l .multiply (Qy ).negate ().mod (p );
117225
118- return pointsEqual (computedR , receivedR );
119- } catch (Exception e ) {
120- return false ;
121- }
226+ return new BigInteger []{tReal , tImag };
122227 }
123228
124- private boolean pointsEqual (ECPoint p1 , ECPoint p2 ) {
229+ private static BigInteger [] computeLineFunctionAdd (
230+ BigInteger [] C ,
231+ BigInteger [] R ,
232+ BigInteger [] Q ,
233+ BigInteger p
234+ )
235+ {
236+ // Line function evaluation for addition
237+ BigInteger Cx = C [0 ], Cy = C [1 ];
238+ BigInteger Rx = R [0 ], Ry = R [1 ];
239+ BigInteger Qx = Q [0 ], Qy = Q [1 ];
240+
241+ // l = (Cy - Ry)/(Cx - Rx)
242+ BigInteger numerator = Cy .subtract (Ry ).mod (p );
243+ BigInteger denominator = Cx .subtract (Rx ).mod (p );
244+ BigInteger l = numerator .multiply (denominator .modInverse (p )).mod (p );
245+
246+ // T = l*(Qx + Cx) - Qy
247+ BigInteger tReal = l .multiply (Qx .add (Cx ).mod (p )).mod (p );
248+ BigInteger tImag = l .multiply (Qy ).negate ().mod (p );
249+
250+ return new BigInteger []{tReal , tImag };
251+ }
252+
253+ private boolean pointsEqual (ECPoint p1 , ECPoint p2 )
254+ {
125255 return p1 .normalize ().getXCoord ().equals (p2 .normalize ().getXCoord ())
126256 && p1 .normalize ().getYCoord ().equals (p2 .normalize ().getYCoord ());
127257 }
258+
259+ public static BigInteger computePairing (ECPoint R , ECPoint Q , BigInteger p , BigInteger q )
260+ {
261+ BigInteger c = p .add (BigInteger .ONE ).divide (q ); // Compute c = (p+1)/q
262+ BigInteger [] v = new BigInteger []{BigInteger .ONE , BigInteger .ZERO }; // v = (1,0) in F_p^2
263+ ECPoint C = R ;
264+
265+ BigInteger qMinusOne = q .subtract (BigInteger .ONE );
266+ int numBits = qMinusOne .bitLength ();
267+
268+ // Miller loop
269+ for (int i = numBits - 2 ; i >= 0 ; i --)
270+ {
271+ v = fp2SquareAndAccumulate (v , C , Q , p );
272+ C = C .twice ().normalize (); // C = [2]C
273+
274+ if (qMinusOne .testBit (i ))
275+ {
276+ v = fp2MultiplyAndAccumulate (v , C , R , Q , p );
277+ C = C .add (R ).normalize ();
278+ }
279+ }
280+
281+ // Final exponentiation: t = v^c
282+ return fp2FinalExponentiation (v , p , c );
283+ }
284+
285+ private static BigInteger [] fp2SquareAndAccumulate (BigInteger [] v , ECPoint C , ECPoint Q , BigInteger p )
286+ {
287+ BigInteger Cx = C .getAffineXCoord ().toBigInteger ();
288+ BigInteger Cy = C .getAffineYCoord ().toBigInteger ();
289+ BigInteger Qx = Q .getAffineXCoord ().toBigInteger ();
290+ BigInteger Qy = Q .getAffineYCoord ().toBigInteger ();
291+
292+ // Compute l = (3 * (Cx^2 - 1)) / (2 * Cy) mod p
293+ BigInteger l = Cx .multiply (Cx ).mod (p ).subtract (BigInteger .ONE ).multiply (BigInteger .valueOf (3 )).mod (p )
294+ .multiply (Cy .multiply (BigInteger .valueOf (2 )).modInverse (p ))
295+ .mod (p );
296+
297+ // Compute v = v^2 * ( l*( Q_x + C_x ) + ( i*Q_y - C_y ) )
298+ v = fp2Multiply (v [0 ], v [1 ], v [0 ], v [1 ], p );
299+ return fp2Multiply (v [0 ], v [1 ], l .multiply (Qx .add (Cx )), (Qy .subtract (Cy )), p );
300+ }
301+
302+ private static BigInteger [] fp2MultiplyAndAccumulate (BigInteger [] v , ECPoint C , ECPoint R , ECPoint Q , BigInteger p )
303+ {
304+ BigInteger Cx = C .getAffineXCoord ().toBigInteger ();
305+ BigInteger Cy = C .getAffineYCoord ().toBigInteger ();
306+ BigInteger Rx = R .getAffineXCoord ().toBigInteger ();
307+ BigInteger Ry = R .getAffineYCoord ().toBigInteger ();
308+ BigInteger Qx = Q .getAffineXCoord ().toBigInteger ();
309+ BigInteger Qy = Q .getAffineYCoord ().toBigInteger ();
310+
311+ // Compute l = (Cy - Ry) / (Cx - Rx) mod p
312+ BigInteger l = Cy .subtract (Ry )
313+ .multiply (Cx .subtract (Rx ).modInverse (p ))
314+ .mod (p );
315+
316+ // Compute v = v * ( l*( Q_x + C_x ) + ( i*Q_y - C_y ) )
317+ return fp2Multiply (v [0 ], v [1 ], l .multiply (Qx .add (Cx )), Qy .subtract (Cy ), p );
318+ }
319+
320+
321+ private static BigInteger [] fp2Multiply (BigInteger x_real , BigInteger x_imag , BigInteger y_real , BigInteger y_imag , BigInteger p )
322+ {
323+ // Multiply v = (a + i*b) * scalar
324+ return new BigInteger []{
325+ x_real .multiply (y_real ).subtract (x_imag .multiply (y_imag )).mod (p ),
326+ x_real .multiply (y_imag ).add (x_imag .multiply (y_real )).mod (p )
327+ };
328+ }
329+
330+ private static BigInteger fp2FinalExponentiation (BigInteger [] v , BigInteger p , BigInteger c )
331+ {
332+ // Compute representative in F_p: return b/a (mod p)
333+ // BigInteger v0 = v[0].modPow(c, p);
334+ // BigInteger v1 = v[1].modPow(c, p);
335+ // return v1.multiply(v0.modInverse(p)).mod(p);
336+ v = fp2Multiply (v [0 ], v [1 ], v [0 ], v [1 ], p );
337+ v = fp2Multiply (v [0 ], v [1 ], v [0 ], v [1 ], p );
338+ return v [1 ].multiply (v [0 ].modInverse (p )).mod (p );
339+ }
128340}
0 commit comments