Skip to content

Commit 7a48919

Browse files
author
gefeili
committed
Refactor for Mayo
1 parent d1b1204 commit 7a48919

File tree

3 files changed

+115
-167
lines changed

3 files changed

+115
-167
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/GF16Utils.java

Lines changed: 53 additions & 87 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,8 @@
11
package org.bouncycastle.pqc.crypto.mayo;
22

3-
import org.bouncycastle.util.Pack;
4-
53
public class GF16Utils
64
{
75

8-
/**
9-
* Multiplies a 64-bit limb by a GF(16) element (represented as an int, 0–255).
10-
* This emulates gf16v_mul_u64 from C.
11-
*
12-
* @param a a 64-bit limb
13-
* @param b an 8-bit GF(16) element (only the low 4 bits are used)
14-
* @return the product as a 64-bit limb
15-
*/
16-
public static long gf16vMulU64(long a, int b)
17-
{
18-
long maskMsb = 0x8888888888888888L;
19-
// In the original code there is a conditional XOR with unsigned_char_blocker;
20-
// here we simply use b directly.
21-
long b32 = b & 0x00000000FFFFFFFFL;
22-
long r64 = a * (b32 & 1);
23-
24-
long a_msb = a & maskMsb;
25-
a ^= a_msb;
26-
a = (a << 1) ^ ((a_msb >>> 3) * 3);
27-
r64 ^= a * ((b32 >> 1) & 1);
28-
29-
a_msb = a & maskMsb;
30-
a ^= a_msb;
31-
a = (a << 1) ^ ((a_msb >>> 3) * 3);
32-
r64 ^= a * ((b32 >>> 2) & 1);
33-
34-
a_msb = a & maskMsb;
35-
a ^= a_msb;
36-
a = (a << 1) ^ ((a_msb >>> 3) * 3);
37-
r64 ^= a * ((b32 >> 3) & 1);
38-
39-
return r64;
40-
}
41-
426
/**
437
* Multiplies each limb of a GF(16) vector (subarray of 'in') by the GF(16) element 'a'
448
* and XORs the result into the corresponding subarray of acc.
@@ -48,15 +12,40 @@ public static long gf16vMulU64(long a, int b)
4812
* @param mVecLimbs the number of limbs in the vector
4913
* @param in the input long array containing the vector; the vector starts at index inOffset
5014
* @param inOffset the starting index in 'in'
51-
* @param a the GF(16) element (0–255) to multiply by
15+
* @param b the GF(16) element to multiply by; only its low 4 bits (values 0–15) are used
5216
* @param acc the accumulator long array; the target vector starts at index accOffset
5317
* @param accOffset the starting index in 'acc'
5418
*/
55-
public static void mVecMulAdd(int mVecLimbs, long[] in, int inOffset, int a, long[] acc, int accOffset)
19+
public static void mVecMulAdd(int mVecLimbs, long[] in, int inOffset, int b, long[] acc, int accOffset)
5620
{
21+
long maskMsb = 0x8888888888888888L;
22+
long a, r64, a_msb;
23+
long b32 = b & 0x00000000FFFFFFFFL;
24+
long b32and1 = b32 & 1;
25+
long b32_1_1 = ((b32 >>> 1) & 1);
26+
long b32_2_1 = ((b32 >>> 2) & 1);
27+
long b32_3_1 = ((b32 >>> 3) & 1);
5728
for (int i = 0; i < mVecLimbs; i++)
5829
{
59-
acc[accOffset + i] ^= gf16vMulU64(in[inOffset + i], a);
30+
// In the original code there is a conditional XOR with unsigned_char_blocker;
31+
// here we simply use b directly.
32+
a = in[inOffset + i];
33+
r64 = a * b32and1;
34+
35+
a_msb = a & maskMsb;
36+
a ^= a_msb;
37+
a = (a << 1) ^ ((a_msb >>> 3) * 3);
38+
r64 ^= a * b32_1_1;
39+
40+
a_msb = a & maskMsb;
41+
a ^= a_msb;
42+
a = (a << 1) ^ ((a_msb >>> 3) * 3);
43+
r64 ^= a * b32_2_1;
44+
45+
a_msb = a & maskMsb;
46+
a ^= a_msb;
47+
a = (a << 1) ^ ((a_msb >>> 3) * 3);
48+
acc[accOffset + i] ^= r64 ^ (a * b32_3_1);
6049
}
6150
}
6251

@@ -65,38 +54,32 @@ public static void mVecMulAdd(int mVecLimbs, long[] in, int inOffset, int a, lon
6554
* Performs the multiplication and accumulation of a block of an upper‐triangular matrix
6655
* times a second matrix.
6756
*
68-
* @param mVecLimbs number of limbs per m-vector.
69-
* @param bsMat the “basis” matrix (as a flat long[] array); each entry occupies mVecLimbs elements.
70-
* @param mat the second matrix (as a flat byte[] array) stored row‐major,
71-
* with dimensions (bsMatCols x matCols).
72-
* @param acc the accumulator (as a flat long[] array) with dimensions (bsMatRows x matCols);
73-
* each “entry” is an m‐vector (length mVecLimbs).
74-
* @param bsMatRows number of rows in the bsMat (the “triangular” matrix’s row count).
75-
* @param bsMatCols number of columns in bsMat.
76-
* @param matCols number of columns in the matrix “mat.”
77-
* @param triangular if 1, start column index for each row is (r * triangular); otherwise use 0.
57+
* @param mVecLimbs number of limbs per m-vector.
58+
* @param bsMat the “basis” matrix (as a flat long[] array); each entry occupies mVecLimbs elements.
59+
* @param mat the second matrix (as a flat byte[] array) stored row‐major,
60+
* with dimensions (bsMatCols x matCols).
61+
* @param acc the accumulator (as a flat long[] array) with dimensions (bsMatRows x matCols);
62+
* each “entry” is an m‐vector (length mVecLimbs).
63+
* @param bsMatRows number of rows in the bsMat (the “triangular” matrix’s row count).
64+
* @param bsMatCols number of columns in bsMat.
65+
* @param matCols number of columns in the matrix “mat.”
7866
*/
79-
public static void mulAddMUpperTriangularMatXMat(int mVecLimbs, long[] bsMat, byte[] mat, long[] acc,
80-
int bsMatRows, int bsMatCols, int matCols, int triangular)
67+
public static void mulAddMUpperTriangularMatXMat(int mVecLimbs, long[] bsMat, byte[] mat, long[] acc, int accOff,
68+
int bsMatRows, int bsMatCols, int matCols)
8169
{
8270
int bsMatEntriesUsed = 0;
83-
for (int r = 0; r < bsMatRows; r++)
71+
int matColsmVecLimbs = matCols * mVecLimbs;
72+
for (int r = 0, rmatCols = 0, rmatColsmVecLimbs = 0; r < bsMatRows; r++, rmatCols += matCols, rmatColsmVecLimbs += matColsmVecLimbs)
8473
{
8574
// For each row r, the inner loop goes from column r (the diagonal entry) to bsMatCols-1.
86-
for (int c = triangular * r; c < bsMatCols; c++)
75+
for (int c = r, cmatCols = rmatCols; c < bsMatCols; c++, cmatCols += matCols)
8776
{
88-
for (int k = 0; k < matCols; k++)
77+
for (int k = 0, kmVecLimbs = 0; k < matCols; k++, kmVecLimbs += mVecLimbs)
8978
{
90-
// Calculate the offsets:
91-
// For bsMat: the m-vector starting at index bsMatEntriesUsed * mVecLimbs.
92-
int bsMatOffset = bsMatEntriesUsed * mVecLimbs;
93-
// For mat: element at row c, column k (row-major layout).
94-
int a = mat[c * matCols + k] & 0xFF;
9579
// For acc: add into the m-vector at row r, column k.
96-
int accOffset = (r * matCols + k) * mVecLimbs;
97-
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, a, acc, accOffset);
80+
mVecMulAdd(mVecLimbs, bsMat, bsMatEntriesUsed, mat[cmatCols + k] & 0xFF, acc, accOff + rmatColsmVecLimbs + kmVecLimbs);
9881
}
99-
bsMatEntriesUsed++;
82+
bsMatEntriesUsed += mVecLimbs;
10083
}
10184
}
10285
}
@@ -114,18 +97,18 @@ public static void mulAddMUpperTriangularMatXMat(int mVecLimbs, long[] bsMat, by
11497
* @param matCols number of columns in “mat.”
11598
* @param bsMatCols number of columns in the bsMat matrix.
11699
*/
117-
public static void mulAddMatTransXMMat(int mVecLimbs, byte[] mat, long[] bsMat, long[] acc,
100+
public static void mulAddMatTransXMMat(int mVecLimbs, byte[] mat, long[] bsMat, int bsMatOff, long[] acc,
118101
int matRows, int matCols, int bsMatCols)
119102
{
120103
// Loop over each column r of mat (which becomes row of mat^T)
121104
for (int r = 0; r < matCols; r++)
122105
{
123-
for (int c = 0; c < matRows; c++)
106+
for (int c = 0, cmatCols = 0; c < matRows; c++, cmatCols += matCols)
124107
{
125-
byte matVal = mat[c * matCols + r];
108+
byte matVal = mat[cmatCols + r];
126109
for (int k = 0; k < bsMatCols; k++)
127110
{
128-
int bsMatOffset = (c * bsMatCols + k) * mVecLimbs;
111+
int bsMatOffset = bsMatOff + (c * bsMatCols + k) * mVecLimbs;
129112
// For acc: add into the m-vector at index (r * bsMatCols + k)
130113
int accOffset = (r * bsMatCols + k) * mVecLimbs;
131114
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, matVal, acc, accOffset);
@@ -257,10 +240,7 @@ public static int mulF(int a, int b)
257240

258241
// Perform carryless multiplication:
259242
// Multiply b by each bit of a and XOR the results.
260-
int p = ((a & 1) * b) ^
261-
((a & 2) * b) ^
262-
((a & 4) * b) ^
263-
((a & 8) * b);
243+
int p = ((a & 1) * b) ^ ((a & 2) * b) ^ ((a & 4) * b) ^ ((a & 8) * b);
264244

265245
// Reduce modulo f(X) = x^4 + x + 1.
266246
// Extract the upper nibble (bits 4 to 7).
@@ -308,17 +288,15 @@ public static void matMul(byte[] a, int aOff, byte[] b, int bOff, byte[] c, int
308288
int colrowAB, int rowA, int colB)
309289
{
310290
int cIndex = 0;
311-
for (int i = 0; i < rowA; i++)
291+
for (int i = 0, aRowStart = 0; i < rowA; i++, aRowStart += colrowAB)
312292
{
313-
int aRowStart = i * colrowAB;
314293
for (int j = 0; j < colB; j++)
315294
{
316295
c[cOff + cIndex++] = lincomb(a, aOff + aRowStart, b, bOff + j, colrowAB, colB);
317296
}
318297
}
319298
}
320299

321-
322300
private static byte lincomb(byte[] a, int aStart, byte[] b, int bStart,
323301
int colrowAB, int colB)
324302
{
@@ -332,26 +310,14 @@ private static byte lincomb(byte[] a, int aStart, byte[] b, int bStart,
332310

333311
public static void matAdd(byte[] a, int aOff, byte[] b, int bOff, byte[] c, int cOff, int m, int n)
334312
{
335-
for (int i = 0; i < m; i++)
313+
for (int i = 0, in = 0; i < m; i++, in += n)
336314
{
337315
for (int j = 0; j < n; j++)
338316
{
339-
int idx = i * n + j;
317+
int idx = in + j;
340318
c[idx + cOff] = (byte)(a[idx + aOff] ^ b[idx + bOff]);
341319
}
342320
}
343321
}
344-
345-
public static void efUnpackMVector(int legs, long[] packedRow, int packedRowOff, byte[] out)
346-
{
347-
int outIndex = 0;
348-
byte[] bytes = new byte[out.length >> 1];
349-
Pack.longToLittleEndian(packedRow, packedRowOff, out.length >> 4, bytes, 0);
350-
for (int i = 0; i < legs * 16; i += 2)
351-
{
352-
out[outIndex++] = (byte)(bytes[i / 2] & 0x0F); // Lower nibble
353-
out[outIndex++] = (byte)((bytes[i / 2] >> 4) & 0x0F); // Upper nibble
354-
}
355-
}
356322
}
357323

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/MayoKeyPairGenerator.java

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,6 @@ public AsymmetricCipherKeyPair generateKeyPair()
4747
// Allocate P3 as a long array of size (O_MAX * O_MAX * M_VEC_LIMBS_MAX), zero-initialized.
4848
long[] P3 = new long[o * o * mVecLimbs];
4949

50-
// Allocate O as a byte array of size (V_MAX * O_MAX).
51-
// Here we assume V_MAX is given by p.getV() (or replace with a constant if needed).
5250
byte[] O = new byte[v * o];
5351

5452
// Generate secret key seed (seed_sk) using a secure random generator.
@@ -60,27 +58,21 @@ public AsymmetricCipherKeyPair generateKeyPair()
6058
// o ← Decode_o(S[ param_pk_seed_bytes : param_pk_seed_bytes + O_bytes ])
6159
// Decode nibbles from S starting at offset param_pk_seed_bytes into O,
6260
// with expected output length = param_v * param_o.
63-
Utils.decode(seed_pk, pkSeedBytes, O, v * o);
61+
Utils.decode(seed_pk, pkSeedBytes, O, O.length);
6462

6563
// Expand P1 and P2 into the array P using seed_pk.
6664
MayoEngine.expandP1P2(p, P, seed_pk);
6765

68-
// For compute_P3, we need to separate P1 and P2.
69-
// Here, we treat P1 as the first param_P1_limbs elements of P,
70-
// and P2 as the remaining elements.
71-
long[] P2 = new long[P.length - p1Limbs];
72-
System.arraycopy(P, p1Limbs, P2, 0, P2.length);
73-
7466
// Compute P1 * O + P2 and store the result in place in the P2 part of P (the limbs starting at index p1Limbs).
75-
// GF16Utils.P1TimesO(p, P, O, P2);
67+
// GF16Utils.P1TimesO(p, P, O, P2);
7668
// Here, bsMatRows and bsMatCols are both paramV, and matCols is paramO; the accumulator is P starting at offset p1Limbs.
77-
GF16Utils.mulAddMUpperTriangularMatXMat(mVecLimbs, P, O, P2, v, v, o, 1);
69+
GF16Utils.mulAddMUpperTriangularMatXMat(mVecLimbs, P, O, P, p1Limbs, v, v, o);
7870

7971
// Compute P3 = O^T * (P1*O + P2).
8072
// Here, treat P2 as the bsMat for the multiplication.
8173
// Dimensions: mat = O (size: paramV x paramO), bsMat = P2 (size: paramV x paramO),
8274
// and acc (P3) will have dimensions: (paramO x paramO), each entry being an m-vector.
83-
GF16Utils.mulAddMatTransXMMat(mVecLimbs, O, P2, P3, v, o, o);
75+
GF16Utils.mulAddMatTransXMMat(mVecLimbs, O, P, p1Limbs, P3, v, o, o);
8476

8577
// Store seed_pk into the public key cpk.
8678
System.arraycopy(seed_pk, 0, cpk, 0, pkSeedBytes);
@@ -90,25 +82,20 @@ public AsymmetricCipherKeyPair generateKeyPair()
9082

9183
// Compute Upper(P3) and store the result in P3_upper.
9284
int mVecsStored = 0;
93-
for (int r = 0; r < o; r++)
85+
int omVecLimbs = o * mVecLimbs;
86+
for (int r = 0, rmVecLimbs = 0, romVecLimbs = 0; r < o; r++, romVecLimbs += omVecLimbs, rmVecLimbs += mVecLimbs)
9487
{
95-
for (int c = r; c < o; c++)
88+
for (int c = r, cmVecLimbs = rmVecLimbs, comVecLimbs = romVecLimbs; c < o; c++, cmVecLimbs += mVecLimbs, comVecLimbs += omVecLimbs)
9689
{
97-
// Compute the starting index for the (r, c) vector in the input array.
98-
int srcOffset = mVecLimbs * (r * o + c);
99-
// Compute the output offset for the current stored vector.
100-
int destOffset = mVecLimbs * mVecsStored;
101-
10290
// Copy the vector at (r, c) into the output.
103-
System.arraycopy(P3, srcOffset, P3_upper, destOffset, mVecLimbs);
91+
System.arraycopy(P3, romVecLimbs + cmVecLimbs, P3_upper, mVecsStored, mVecLimbs);
10492

10593
// If off-diagonal, add (XOR) the vector at (c, r) into the same output vector.
10694
if (r != c)
10795
{
108-
int srcOffset2 = mVecLimbs * (c * o + r);
109-
Longs.xorTo(mVecLimbs, P3, srcOffset2, P3_upper, destOffset);
96+
Longs.xorTo(mVecLimbs, P3, comVecLimbs + rmVecLimbs, P3_upper, mVecsStored);
11097
}
111-
mVecsStored++;
98+
mVecsStored += mVecLimbs;
11299
}
113100
}
114101

@@ -118,7 +105,6 @@ public AsymmetricCipherKeyPair generateKeyPair()
118105
Utils.packMVecs(P3_upper, cpk, pkSeedBytes, p3Limbs / mVecLimbs, m);
119106
// Securely clear sensitive data.
120107
Arrays.clear(O);
121-
Arrays.clear(P2);
122108
Arrays.clear(P3);
123109

124110
return new AsymmetricCipherKeyPair(new MayoPublicKeyParameter(p, cpk), new MayoPrivateKeyParameter(p, seed_sk));

0 commit comments

Comments
 (0)