Skip to content

Commit 2434c1c

Browse files
gefeilidghgit
authored andcommitted
TODO fix the bugs in SAKKEKEMExtractor
1 parent d53e3ab commit 2434c1c

File tree

7 files changed

+597
-98
lines changed

7 files changed

+597
-98
lines changed
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
package org.bouncycastle.crypto.kems;
2+
3+
import java.math.BigInteger;
4+
5+
import org.bouncycastle.crypto.Digest;
6+
import org.bouncycastle.crypto.EncapsulatedSecretExtractor;
7+
import org.bouncycastle.crypto.digests.SHA256Digest;
8+
import org.bouncycastle.crypto.params.SAKKEPrivateKeyParameters;
9+
import org.bouncycastle.crypto.params.SAKKEPublicKeyParameters;
10+
import org.bouncycastle.math.ec.ECCurve;
11+
import org.bouncycastle.math.ec.ECPoint;
12+
import org.bouncycastle.util.Arrays;
13+
import org.bouncycastle.util.BigIntegers;
14+
15+
import static org.bouncycastle.crypto.kems.SAKKEKEMSGenerator.pairing;
16+
17+
public class SAKKEKEMExtractor implements EncapsulatedSecretExtractor
18+
{
19+
private final ECCurve curve;
20+
private final BigInteger p;
21+
private final BigInteger q;
22+
private final ECPoint P;
23+
private final ECPoint Z_S;
24+
private final ECPoint K_bS; // Receiver's RSK
25+
private final int n; // Security parameter
26+
private final SAKKEPrivateKeyParameters privateKey;
27+
28+
public SAKKEKEMExtractor(SAKKEPrivateKeyParameters privateKey) {
29+
this.privateKey = privateKey;
30+
SAKKEPublicKeyParameters publicKey = privateKey.getPublicParams();
31+
this.curve = publicKey.getCurve();
32+
this.q = publicKey.getQ();
33+
this.P = publicKey.getP();
34+
this.p = publicKey.getp();
35+
this.Z_S = publicKey.getZ();
36+
this.K_bS = privateKey.getPrivatePoint();
37+
this.n = publicKey.getN();
38+
}
39+
40+
@Override
41+
public byte[] extractSecret(byte[] encapsulation) {
42+
try {
43+
// Step 1: Parse Encapsulated Data (R_bS, H)
44+
ECPoint R_bS = parseECPoint(encapsulation);
45+
BigInteger H = parseH(encapsulation);
46+
47+
// Step 2: Compute w = <R_bS, K_bS> using pairing
48+
BigInteger w = computePairing(R_bS, K_bS);
49+
50+
// Step 3: Compute SSV = H XOR HashToIntegerRange(w, 2^n)
51+
BigInteger ssv = computeSSV(H, w);
52+
53+
// Step 4: Compute r = HashToIntegerRange(SSV || b)
54+
// BigInteger r = computeR(ssv, privateKey.getPrivatePoint());
55+
//
56+
// // Step 5: Validate R_bS
57+
// if (!validateR_bS(r, privateKey.getPrivatePoint(), R_bS)) {
58+
// throw new IllegalStateException("Validation of R_bS failed");
59+
// }
60+
61+
return BigIntegers.asUnsignedByteArray(n/8, ssv);
62+
} catch (Exception e) {
63+
throw new IllegalStateException("SAKKE extraction failed: " + e.getMessage());
64+
}
65+
}
66+
67+
@Override
68+
public int getEncapsulationLength()
69+
{
70+
return 0;
71+
}
72+
73+
private ECPoint parseECPoint(byte[] encapsulation) {
74+
int coordLen = (p.bitLength() + 7) / 8;
75+
byte[] xBytes = Arrays.copyOfRange(encapsulation, 0, coordLen);
76+
byte[] yBytes = Arrays.copyOfRange(encapsulation, coordLen, 2*coordLen);
77+
78+
BigInteger x = new BigInteger(1, xBytes);
79+
BigInteger y = new BigInteger(1, yBytes);
80+
81+
return curve.createPoint(x, y).normalize();
82+
}
83+
84+
private BigInteger parseH(byte[] encapsulation) {
85+
int coordLen = (p.bitLength() + 7) / 8;
86+
byte[] hBytes = Arrays.copyOfRange(encapsulation, 2*coordLen, encapsulation.length);
87+
return new BigInteger(1, hBytes);
88+
}
89+
90+
private BigInteger computePairing(ECPoint R, ECPoint K) {
91+
// Use your existing pairing implementation
92+
return pairing(R, K, p, q);
93+
}
94+
95+
private BigInteger computeSSV(BigInteger H, BigInteger w) {
96+
BigInteger twoToN = BigInteger.ONE.shiftLeft(n);
97+
BigInteger mask = SAKKEUtils.hashToIntegerRange(w.toByteArray(), twoToN);
98+
return H.xor(mask);
99+
}
100+
101+
private BigInteger computeR(BigInteger ssv, byte[] userId) {
102+
byte[] ssvBytes = BigIntegers.asUnsignedByteArray(ssv);
103+
byte[] ssvConcatB = Arrays.concatenate(ssvBytes, userId);
104+
return SAKKEUtils.hashToIntegerRange(ssvConcatB, q);
105+
}
106+
107+
private boolean validateR_bS(BigInteger r, byte[] b, ECPoint receivedR) {
108+
try {
109+
// Compute [b]P
110+
ECPoint bP = P.multiply(new BigInteger(1, b)).normalize();
111+
112+
// Compute [b]P + Z_S
113+
ECPoint bP_plus_Z = bP.add(Z_S).normalize();
114+
115+
// Compute [r]([b]P + Z_S)
116+
ECPoint computedR = bP_plus_Z.multiply(r).normalize();
117+
118+
return pointsEqual(computedR, receivedR);
119+
} catch (Exception e) {
120+
return false;
121+
}
122+
}
123+
124+
private boolean pointsEqual(ECPoint p1, ECPoint p2) {
125+
return p1.normalize().getXCoord().equals(p2.normalize().getXCoord())
126+
&& p1.normalize().getYCoord().equals(p2.normalize().getYCoord());
127+
}
128+
}

0 commit comments

Comments
 (0)