Skip to content

Commit 01dbe52

Browse files
author
gefeili
committed
Refactor for Mayo
1 parent e356f4a commit 01dbe52

File tree

7 files changed

+94
-228
lines changed

7 files changed

+94
-228
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/GF16Utils.java

Lines changed: 18 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -16,26 +16,25 @@ public class GF16Utils
1616
public static long gf16vMulU64(long a, int b)
1717
{
1818
long maskMsb = 0x8888888888888888L;
19-
long a64 = a;
2019
// In the original code there is a conditional XOR with unsigned_char_blocker;
2120
// here we simply use b directly.
2221
long b32 = b & 0x00000000FFFFFFFFL;
23-
long r64 = a64 * (b32 & 1);
22+
long r64 = a * (b32 & 1);
2423

25-
long a_msb = a64 & maskMsb;
26-
a64 ^= a_msb;
27-
a64 = (a64 << 1) ^ ((a_msb >>> 3) * 3);
28-
r64 ^= a64 * ((b32 >> 1) & 1);
24+
long a_msb = a & maskMsb;
25+
a ^= a_msb;
26+
a = (a << 1) ^ ((a_msb >>> 3) * 3);
27+
r64 ^= a * ((b32 >> 1) & 1);
2928

30-
a_msb = a64 & maskMsb;
31-
a64 ^= a_msb;
32-
a64 = (a64 << 1) ^ ((a_msb >>> 3) * 3);
33-
r64 ^= a64 * ((b32 >>> 2) & 1);
29+
a_msb = a & maskMsb;
30+
a ^= a_msb;
31+
a = (a << 1) ^ ((a_msb >>> 3) * 3);
32+
r64 ^= a * ((b32 >>> 2) & 1);
3433

35-
a_msb = a64 & maskMsb;
36-
a64 ^= a_msb;
37-
a64 = (a64 << 1) ^ ((a_msb >>> 3) * 3);
38-
r64 ^= a64 * ((b32 >> 3) & 1);
34+
a_msb = a & maskMsb;
35+
a ^= a_msb;
36+
a = (a << 1) ^ ((a_msb >>> 3) * 3);
37+
r64 ^= a * ((b32 >> 3) & 1);
3938

4039
return r64;
4140
}
@@ -61,18 +60,6 @@ public static void mVecMulAdd(int mVecLimbs, long[] in, int inOffset, int a, lon
6160
}
6261
}
6362

64-
/**
65-
* Convenience overload of mVecMulAdd that assumes zero offsets.
66-
*
67-
* @param mVecLimbs the number of limbs
68-
* @param in the input vector
69-
* @param a the GF(16) element to multiply by
70-
* @param acc the accumulator vector
71-
*/
72-
public static void mVecMulAdd(int mVecLimbs, long[] in, int a, long[] acc)
73-
{
74-
mVecMulAdd(mVecLimbs, in, 0, a, acc, 0);
75-
}
7663

7764
/**
7865
* Performs the multiplication and accumulation of a block of an upper‐triangular matrix
@@ -156,33 +143,18 @@ public static void mulAddMatTransXMMat(int mVecLimbs, byte[] mat, long[] bsMat,
156143
{
157144
for (int c = 0; c < matRows; c++)
158145
{
146+
byte matVal = mat[c * matCols + r];
159147
for (int k = 0; k < bsMatCols; k++)
160148
{
161-
// For bsMat: the m-vector at index (c * bsMatCols + k)
162149
int bsMatOffset = (c * bsMatCols + k) * mVecLimbs;
163-
// For mat: element at row c, column r.
164-
int a = mat[c * matCols + r] & 0xFF;
165150
// For acc: add into the m-vector at index (r * bsMatCols + k)
166151
int accOffset = (r * bsMatCols + k) * mVecLimbs;
167-
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, a, acc, accOffset);
152+
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, matVal, acc, accOffset);
168153
}
169154
}
170155
}
171156
}
172157

173-
174-
/**
175-
* Adds (bitwise XOR) mVecLimbs elements from the source array (starting at srcOffset)
176-
* into the destination array (starting at destOffset).
177-
*/
178-
public static void mVecAdd(int mVecLimbs, long[] src, int srcOffset, long[] dest, int destOffset)
179-
{
180-
for (int i = 0; i < mVecLimbs; i++)
181-
{
182-
dest[destOffset + i] ^= src[srcOffset + i];
183-
}
184-
}
185-
186158
/**
187159
* Multiplies a matrix (given as a byte array) with a bit‐sliced matrix (given as a long array)
188160
* and accumulates the result into the acc array.
@@ -288,30 +260,6 @@ public static void mulAddMUpperTriangularMatXMatTrans(int mVecLimbs, long[] bsMa
288260
}
289261
}
290262

291-
/**
292-
* Multiplies a vector (from bsMat) by an unsigned scalar (from mat) and adds the result
293-
* to the corresponding vector in acc.
294-
*
295-
* <p>
296-
* This method corresponds to the C function <code>m_vec_mul_add</code>.
297-
* It processes {@code mVecLimbs} elements starting from the given offsets in the source and accumulator arrays.
298-
* </p>
299-
*
300-
* @param mVecLimbs the number of limbs (elements) in the vector
301-
* @param bsMat the source array (bit-sliced matrix) of long values
302-
* @param bsMatOffset the starting index in bsMat for the vector
303-
* @param scalar the scalar value (from mat), as a byte
304-
* @param acc the accumulator array where the result is added
305-
* @param accOffset the starting index in the accumulator array for the current vector
306-
*/
307-
public static void mVecMulAdd(int mVecLimbs, long[] bsMat, int bsMatOffset, byte scalar, long[] acc, int accOffset)
308-
{
309-
for (int i = 0; i < mVecLimbs; i++)
310-
{
311-
acc[accOffset + i] ^= gf16vMulU64(bsMat[bsMatOffset + i], scalar);
312-
}
313-
}
314-
315263
/**
316264
* GF(16) multiplication mod x^4 + x + 1.
317265
* <p>
@@ -339,8 +287,7 @@ public static int mulF(int a, int b)
339287
// Extract the upper nibble (bits 4 to 7).
340288
int topP = p & 0xF0;
341289
// The reduction: XOR p with (topP shifted right by 4 and by 3) and mask to 4 bits.
342-
int out = (p ^ (topP >> 4) ^ (topP >> 3)) & 0x0F;
343-
return out;
290+
return (p ^ (topP >> 4) ^ (topP >> 3)) & 0x0F;
344291
}
345292

346293
/**
@@ -364,8 +311,7 @@ public static long mulFx8(byte a, long b)
364311

365312
// Reduction mod (x^4 + x + 1): process each byte in parallel.
366313
long topP = p & 0xf0f0f0f0f0f0f0f0L;
367-
long out = (p ^ (topP >> 4) ^ (topP >> 3)) & 0x0f0f0f0f0f0f0f0fL;
368-
return out;
314+
return (p ^ (topP >> 4) ^ (topP >> 3)) & 0x0f0f0f0f0f0f0f0fL;
369315
}
370316

371317
public static void matMul(byte[] a, byte[] b, byte[] c,
@@ -420,9 +366,6 @@ public static void matAdd(byte[] a, int aOff, byte[] b, int bOff, byte[] c, int
420366
}
421367
}
422368

423-
// Define the blocker constant as needed (set to 0 if not used).
424-
private static final byte UNSIGNED_CHAR_BLOCKER = 0;
425-
426369
/**
427370
* Returns 0x00 if a equals b, otherwise returns 0xFF.
428371
* This operation is performed in constant time.
@@ -442,9 +385,7 @@ public static byte ctCompare8(byte a, byte b)
442385
// If diff is 0, then -diff is 0, and shifting yields 0.
443386
// If diff is nonzero, -diff is negative, so the arithmetic shift yields -1 (0xFFFFFFFF),
444387
// which when cast to a byte becomes 0xFF.
445-
int result = negDiff >> 31;
446-
// XOR with UNSIGNED_CHAR_BLOCKER (assumed 0 here) and cast to byte.
447-
return (byte)(result ^ UNSIGNED_CHAR_BLOCKER);
388+
return (byte) (negDiff >> 31);
448389
}
449390

450391
public static void efUnpackMVector(int legs, long[] packedRow, int packedRowOff, byte[] out)

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/MayoKeyPairGenerator.java

Lines changed: 47 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -5,173 +5,120 @@
55
import org.bouncycastle.crypto.AsymmetricCipherKeyPair;
66
import org.bouncycastle.crypto.AsymmetricCipherKeyPairGenerator;
77
import org.bouncycastle.crypto.KeyGenerationParameters;
8-
import org.bouncycastle.util.Pack;
8+
import org.bouncycastle.util.Arrays;
9+
import org.bouncycastle.util.Longs;
910

1011
public class MayoKeyPairGenerator
1112
implements AsymmetricCipherKeyPairGenerator
1213
{
1314
private MayoParameters p;
1415
private SecureRandom random;
1516

16-
1717
public void init(KeyGenerationParameters param)
1818
{
1919
this.p = ((MayoKeyGenerationParameters)param).getParameters();
2020
this.random = param.getRandom();
2121
}
2222

23-
2423
@Override
2524
public AsymmetricCipherKeyPair generateKeyPair()
2625
{
27-
int ret = MayoEngine.MAYO_OK;
26+
// Retrieve parameters from p.
27+
int mVecLimbs = p.getMVecLimbs();
28+
int m = p.getM();
29+
int v = p.getV();
30+
int o = p.getO();
31+
int oBytes = p.getOBytes();
32+
int p1Limbs = p.getP1Limbs();
33+
int p3Limbs = p.getP3Limbs();
34+
int pkSeedBytes = p.getPkSeedBytes();
35+
int skSeedBytes = p.getSkSeedBytes();
36+
2837
byte[] cpk = new byte[p.getCpkBytes()];
2938
// seed_sk points to csk.
3039
byte[] seed_sk = new byte[p.getCskBytes()];
3140

3241
// Allocate S = new byte[PK_SEED_BYTES_MAX + O_BYTES_MAX]
33-
byte[] S = new byte[p.getPkSeedBytes() + p.getOBytes()];
42+
byte[] seed_pk = new byte[pkSeedBytes + oBytes];
3443

3544
// Allocate P as a long array of size (P1_LIMBS_MAX + P2_LIMBS_MAX)
36-
long[] P = new long[p.getP1Limbs() + p.getP2Limbs()];
45+
long[] P = new long[p1Limbs + p.getP2Limbs()];
3746

3847
// Allocate P3 as a long array of size (O_MAX * O_MAX * M_VEC_LIMBS_MAX), zero-initialized.
39-
long[] P3 = new long[p.getO() * p.getO() * p.getMVecLimbs()];
40-
41-
// seed_pk will be a reference into S.
42-
byte[] seed_pk;
48+
long[] P3 = new long[o * o * mVecLimbs];
4349

4450
// Allocate O as a byte array of size (V_MAX * O_MAX).
4551
// Here we assume V_MAX is given by p.getV() (or replace with a constant if needed).
46-
byte[] O = new byte[p.getV() * p.getO()];
47-
48-
// Retrieve parameters from p.
49-
int m_vec_limbs = p.getMVecLimbs();
50-
int param_m = p.getM();
51-
int param_v = p.getV();
52-
int param_o = p.getO();
53-
int param_O_bytes = p.getOBytes();
54-
int param_P1_limbs = p.getP1Limbs();
55-
int param_P3_limbs = p.getP3Limbs();
56-
int param_pk_seed_bytes = p.getPkSeedBytes();
57-
int param_sk_seed_bytes = p.getSkSeedBytes();
58-
59-
// In the C code, P1 is P and P2 is P offset by param_P1_limbs.
60-
// In Java, we will have functions (like expandP1P2) work on the full array P.
52+
byte[] O = new byte[v * o];
6153

6254
// Generate secret key seed (seed_sk) using a secure random generator.
6355
random.nextBytes(seed_sk);
6456

6557
// S ← shake256(seed_sk, pk_seed_bytes + O_bytes)
66-
Utils.shake256(S, param_pk_seed_bytes + param_O_bytes, seed_sk, param_sk_seed_bytes);
67-
68-
// seed_pk is the beginning of S.
69-
seed_pk = S;
58+
Utils.shake256(seed_pk, pkSeedBytes + oBytes, seed_sk, skSeedBytes);
7059

7160
// o ← Decode_o(S[ param_pk_seed_bytes : param_pk_seed_bytes + O_bytes ])
7261
// Decode nibbles from S starting at offset param_pk_seed_bytes into O,
7362
// with expected output length = param_v * param_o.
74-
Utils.decode(S, param_pk_seed_bytes, O, param_v * param_o);
63+
Utils.decode(seed_pk, pkSeedBytes, O, v * o);
7564

7665
// Expand P1 and P2 into the array P using seed_pk.
7766
MayoEngine.expandP1P2(p, P, seed_pk);
7867

7968
// For compute_P3, we need to separate P1 and P2.
8069
// Here, we treat P1 as the first param_P1_limbs elements of P,
8170
// and P2 as the remaining elements.
82-
long[] P1 = P;
83-
long[] P2 = new long[P.length - param_P1_limbs];
84-
System.arraycopy(P, param_P1_limbs, P2, 0, P2.length);
85-
86-
// Compute P3, which (in the process) modifies P2.
87-
computeP3(p, P1, P2, O, P3);
88-
89-
// Store seed_pk into the public key cpk.
90-
System.arraycopy(seed_pk, 0, cpk, 0, param_pk_seed_bytes);
91-
92-
// Allocate an array for the "upper" part of P3.
93-
long[] P3_upper = new long[p.getP3Limbs()];
94-
95-
// Compute Upper(P3) and store the result in P3_upper.
96-
mUpper(p, P3, P3_upper, param_o);
97-
98-
// Pack the m-vectors in P3_upper into cpk (after the seed_pk).
99-
// The number of m-vectors to pack is (param_P3_limbs / m_vec_limbs),
100-
// and param_m is used as the m value.
101-
Utils.packMVecs(P3_upper, cpk, param_pk_seed_bytes, param_P3_limbs / m_vec_limbs, param_m);
102-
// Securely clear sensitive data.
103-
// secureClear(O);
104-
// secureClear(P2);
105-
// secureClear(P3);
106-
107-
return new AsymmetricCipherKeyPair(new MayoPublicKeyParameter(p, cpk), new MayoPrivateKeyParameter(p, seed_sk));
108-
}
109-
110-
/**
111-
* Computes P3 from P1, P2, and O.
112-
* <p>
113-
* In C, compute_P3 does:
114-
* 1. Compute P1*O + P2, storing result in P2.
115-
* 2. Compute P3 = O^T * (P1*O + P2).
116-
*
117-
* @param p the parameter object.
118-
* @param P1 the P1 matrix as a long[] array.
119-
* @param P2 the P2 matrix as a long[] array; on output, P1*O is added to it.
120-
* @param O the O matrix as a byte[] array.
121-
* @param P3 the output matrix (as a long[] array) which will receive O^T*(P1*O + P2).
122-
*/
123-
public static void computeP3(MayoParameters p, long[] P1, long[] P2, byte[] O, long[] P3)
124-
{
125-
int mVecLimbs = p.getMVecLimbs();
126-
int paramV = p.getV();
127-
int paramO = p.getO();
71+
long[] P2 = new long[P.length - p1Limbs];
72+
System.arraycopy(P, p1Limbs, P2, 0, P2.length);
12873

12974
// Compute P1 * O + P2 and store the result in P2.
130-
GF16Utils.P1TimesO(p, P1, O, P2);
75+
GF16Utils.P1TimesO(p, P, O, P2);
13176

13277
// Compute P3 = O^T * (P1*O + P2).
13378
// Here, treat P2 as the bsMat for the multiplication.
13479
// Dimensions: mat = O (size: paramV x paramO), bsMat = P2 (size: paramV x paramO),
13580
// and acc (P3) will have dimensions: (paramO x paramO), each entry being an m-vector.
136-
GF16Utils.mulAddMatTransXMMat(mVecLimbs, O, P2, P3, paramV, paramO, paramO);
137-
}
81+
GF16Utils.mulAddMatTransXMMat(mVecLimbs, O, P2, P3, v, o, o);
13882

139-
/**
140-
* Reproduces the behavior of the C function m_upper.
141-
* <p>
142-
* For each pair (r, c) with 0 <= r <= c < size, it copies the m-vector at
143-
* position (r, c) from 'in' to the next position in 'out' and, if r != c,
144-
* it adds (XORs) the m-vector at position (c, r) into that same output vector.
145-
*
146-
* @param p the parameter object (used to get mVecLimbs)
147-
* @param in the input long array (each vector is mVecLimbs in length)
148-
* @param out the output long array (must be large enough to store all output vectors)
149-
* @param size the size parameter defining the matrix dimensions.
150-
*/
151-
public static void mUpper(MayoParameters p, long[] in, long[] out, int size)
152-
{
153-
int mVecLimbs = p.getMVecLimbs();
83+
// Store seed_pk into the public key cpk.
84+
System.arraycopy(seed_pk, 0, cpk, 0, pkSeedBytes);
85+
86+
// Allocate an array for the "upper" part of P3.
87+
long[] P3_upper = new long[p3Limbs];
88+
89+
// Compute Upper(P3) and store the result in P3_upper.
15490
int mVecsStored = 0;
155-
for (int r = 0; r < size; r++)
91+
for (int r = 0; r < o; r++)
15692
{
157-
for (int c = r; c < size; c++)
93+
for (int c = r; c < o; c++)
15894
{
15995
// Compute the starting index for the (r, c) vector in the input array.
160-
int srcOffset = mVecLimbs * (r * size + c);
96+
int srcOffset = mVecLimbs * (r * o + c);
16197
// Compute the output offset for the current stored vector.
16298
int destOffset = mVecLimbs * mVecsStored;
16399

164100
// Copy the vector at (r, c) into the output.
165-
System.arraycopy(in, srcOffset, out, destOffset, mVecLimbs);
101+
System.arraycopy(P3, srcOffset, P3_upper, destOffset, mVecLimbs);
166102

167103
// If off-diagonal, add (XOR) the vector at (c, r) into the same output vector.
168104
if (r != c)
169105
{
170-
int srcOffset2 = mVecLimbs * (c * size + r);
171-
GF16Utils.mVecAdd(mVecLimbs, in, srcOffset2, out, destOffset);
106+
int srcOffset2 = mVecLimbs * (c * o + r);
107+
Longs.xorTo(mVecLimbs, P3, srcOffset2, P3_upper, destOffset);
172108
}
173109
mVecsStored++;
174110
}
175111
}
112+
113+
// Pack the m-vectors in P3_upper into cpk (after the seed_pk).
114+
// The number of m-vectors to pack is (param_P3_limbs / m_vec_limbs),
115+
// and param_m is used as the m value.
116+
Utils.packMVecs(P3_upper, cpk, pkSeedBytes, p3Limbs / mVecLimbs, m);
117+
// Securely clear sensitive data.
118+
Arrays.clear(O);
119+
Arrays.clear(P2);
120+
Arrays.clear(P3);
121+
122+
return new AsymmetricCipherKeyPair(new MayoPublicKeyParameter(p, cpk), new MayoPrivateKeyParameter(p, seed_sk));
176123
}
177124
}

0 commit comments

Comments
 (0)