Skip to content

Commit 9ef9e78

Browse files
author
gefeili
committed
Refactor of Mayo
1 parent 50f0914 commit 9ef9e78

File tree

4 files changed

+60
-103
lines changed

4 files changed

+60
-103
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/GF16Utils.java

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ static void mulAddMUpperTriangularMatXMat(int mVecLimbs, long[] bsMat, byte[] ma
8282
for (int k = 0, kmVecLimbs = 0; k < matCols; k++, kmVecLimbs += mVecLimbs)
8383
{
8484
// For acc: add into the m-vector at row r, column k.
85-
mVecMulAdd(mVecLimbs, bsMat, bsMatEntriesUsed, mat[cmatCols + k] & 0xFF, acc, accOff + rmatColsmVecLimbs + kmVecLimbs);
85+
mVecMulAdd(mVecLimbs, bsMat, bsMatEntriesUsed, mat[cmatCols + k], acc, accOff + rmatColsmVecLimbs + kmVecLimbs);
8686
}
8787
bsMatEntriesUsed += mVecLimbs;
8888
}
@@ -256,7 +256,7 @@ static long mulFx8(byte a, long b)
256256

257257
// Reduction mod (x^4 + x + 1): process each byte in parallel.
258258
long topP = p & 0xf0f0f0f0f0f0f0f0L;
259-
return (p ^ (topP >> 4) ^ (topP >> 3)) & 0x0f0f0f0f0f0f0f0fL;
259+
return (p ^ (topP >>> 4) ^ (topP >>> 3)) & 0x0f0f0f0f0f0f0f0fL;
260260
}
261261

262262
static void matMul(byte[] a, byte[] b, int bOff, byte[] c, int colrowAB, int rowA)

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/MayoSigner.java

Lines changed: 33 additions & 38 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@ public class MayoSigner
1616
implements MessageSigner
1717
{
1818
private SecureRandom random;
19-
MayoParameters params;
19+
private MayoParameters params;
2020
private MayoPublicKeyParameters pubKey;
2121
private MayoPrivateKeyParameters privKey;
2222

@@ -57,6 +57,7 @@ public byte[] generateSignature(byte[] message)
5757
int v = params.getV();
5858
int o = params.getO();
5959
int n = params.getN();
60+
int m = params.getM();
6061
int vbytes = params.getVBytes();
6162
int oBytes = params.getOBytes();
6263
int saltBytes = params.getSaltBytes();
@@ -66,16 +67,17 @@ public byte[] generateSignature(byte[] message)
6667
int digestBytes = params.getDigestBytes();
6768
int skSeedBytes = params.getSkSeedBytes();
6869
byte[] tenc = new byte[params.getMBytes()];
69-
byte[] t = new byte[params.getM()];
70-
byte[] y = new byte[params.getM()];
70+
byte[] t = new byte[m];
71+
byte[] y = new byte[m];
7172
byte[] salt = new byte[saltBytes];
7273
byte[] V = new byte[k * vbytes + params.getRBytes()];
7374
byte[] Vdec = new byte[v * k];
7475
int ok = k * o;
75-
byte[] A = new byte[((params.getM() + 7) / 8 * 8) * (ok + 1)];
76-
byte[] x = new byte[k * n];
76+
int nk = k * n;
77+
byte[] A = new byte[((m + 7) / 8 * 8) * (ok + 1)];
78+
byte[] x = new byte[nk];
7779
byte[] r = new byte[ok + 1];
78-
byte[] s = new byte[k * n];
80+
byte[] s = new byte[nk];
7981
byte[] tmp = new byte[digestBytes + saltBytes + skSeedBytes + 1];
8082
byte[] sig = new byte[params.getSigBytes()];
8183
long[] P = new long[p1Limbs + params.getP2Limbs()];
@@ -121,9 +123,9 @@ public byte[] generateSignature(byte[] message)
121123
{
122124
// Multiply the m-vector at P1 for the current matrix entry,
123125
// and accumulate into acc for row r.
124-
GF16Utils.mVecMulAdd(mVecLimbs, P, bsMatEntriesUsed, O[co + j] & 0xFF, P, iomVecLimbs + jmVecLimbs);
126+
GF16Utils.mVecMulAdd(mVecLimbs, P, bsMatEntriesUsed, O[co + j], P, iomVecLimbs + jmVecLimbs);
125127
// Similarly, accumulate into acc for row c.
126-
GF16Utils.mVecMulAdd(mVecLimbs, P, bsMatEntriesUsed, O[io + j] & 0xFF, P, comVecLimbs + jmVecLimbs);
128+
GF16Utils.mVecMulAdd(mVecLimbs, P, bsMatEntriesUsed, O[io + j], P, comVecLimbs + jmVecLimbs);
127129
}
128130
bsMatEntriesUsed += mVecLimbs;
129131
}
@@ -150,9 +152,10 @@ public byte[] generateSignature(byte[] message)
150152
System.arraycopy(salt, 0, tmp, digestBytes, saltBytes);
151153
shake.update(tmp, 0, digestBytes + saltBytes);
152154
shake.doFinal(tenc, 0, params.getMBytes());
153-
Utils.decode(tenc, t, params.getM());
155+
Utils.decode(tenc, t, m);
154156
int size = v * k * mVecLimbs;
155157
long[] Pv = new long[size];
158+
byte[] Ox = new byte[v];
156159
for (int ctr = 0; ctr <= 255; ctr++)
157160
{
158161
tmp[tmp.length - 1] = (byte)ctr;
@@ -182,12 +185,12 @@ public byte[] generateSignature(byte[] message)
182185
computeA(Mtmp, A);
183186

184187
// Clear trailing bytes
185-
// for (int i = 0; i < params.getM(); ++i)
188+
// for (int i = 0; i < m; ++i)
186189
// {
187190
// A[(i + 1) * (ok + 1) - 1] = 0;
188191
// }
189192

190-
Utils.decode(V, k * vbytes, r, 0, ok);
193+
Utils.decode(V, k * vbytes, r, ok);
191194

192195
if (sampleSolution(params, A, y, r, x))
193196
{
@@ -201,16 +204,16 @@ public byte[] generateSignature(byte[] message)
201204
}
202205

203206
// Compute final signature components
204-
byte[] Ox = new byte[v];
205-
for (int i = 0; i < k; i++)
207+
208+
for (int i = 0, io = 0, in = 0, iv = 0; i < k; i++, io += o, in+= n, iv += v)
206209
{
207-
GF16Utils.matMul(O, x, i * o, Ox, o, n - o);
208-
Bytes.xor(v, Vdec, i * v, Ox, s, i * n);
209-
System.arraycopy(x, i * o, s, i * n + n - o, o);
210+
GF16Utils.matMul(O, x, io, Ox, o, v);
211+
Bytes.xor(v, Vdec, iv, Ox, s, in);
212+
System.arraycopy(x, io, s, in + v, o);
210213
}
211214

212215
// Encode and add salt
213-
Utils.encode(s, sig, n * k);
216+
Utils.encode(s, sig, nk);
214217
System.arraycopy(salt, 0, sig, sig.length - saltBytes, saltBytes);
215218

216219
return Arrays.concatenate(sig, message);
@@ -294,13 +297,12 @@ void computeRHS(long[] vPv, byte[] t, byte[] y)
294297
final int k = params.getK();
295298
final int[] fTail = params.getFTail();
296299

297-
final int topPos = ((m - 1) & 15) * 4;
300+
final int topPos = ((m - 1) & 15) << 2;
298301

299302
// Zero out tails of m_vecs if necessary
300303
if ((m & 15) != 0)
301304
{
302-
long mask = 1L << ((m & 15) << 2);
303-
mask -= 1;
305+
long mask = (1L << ((m & 15) << 2)) - 1;
304306
final int kSquared = k * k;
305307

306308
for (int i = 0, index = mVecLimbs - 1; i < kSquared; i++, index += mVecLimbs)
@@ -409,7 +411,7 @@ void computeA(long[] Mtmp, byte[] AOut)
409411
}
410412
}
411413

412-
for (int i = 0, io = 0; i < k; i++, io += o)
414+
for (int i = 0, io = 0, iomVecLimbs = 0; i < k; i++, io += o, iomVecLimbs += omVecLimbs)
413415
{
414416
for (int j = k - 1, jomVecLimbs = j * omVecLimbs, jo = j * o; j >= i; j--, jomVecLimbs -= omVecLimbs, jo -= o)
415417
{
@@ -433,13 +435,11 @@ void computeA(long[] Mtmp, byte[] AOut)
433435
if (i != j)
434436
{
435437
// Process Mi
436-
int miOffset = i * mVecLimbs * o;
437438
for (int c = 0, cmVecLimbs = 0; c < o; c++, cmVecLimbs += mVecLimbs)
438439
{
439440
for (int limb = 0, limbAWidhth = 0; limb < mVecLimbs; limb++, limbAWidhth += AWidth)
440441
{
441-
long value = Mtmp[miOffset + limb + cmVecLimbs];
442-
442+
long value = Mtmp[iomVecLimbs + limb + cmVecLimbs];
443443
int aIndex = jo + c + wordsToShift + limbAWidhth;
444444
A[aIndex] ^= value << bitsToShift;
445445

@@ -461,7 +461,7 @@ void computeA(long[] Mtmp, byte[] AOut)
461461
}
462462

463463
// Transpose blocks
464-
for (int c = 0; c < AWidth * ((m + (k + 1) * k / 2 + 15) >>> 4); c += 16)
464+
for (int c = 0; c < AWidth * ((m + (((k + 1) * k) >> 1) + 15) >>> 4); c += 16)
465465
{
466466
transpose16x16Nibbles(A, c);
467467
}
@@ -554,8 +554,7 @@ private static void transpose16x16Nibbles(long[] M, int offset)
554554
}
555555
}
556556

557-
boolean sampleSolution(MayoParameters params, byte[] A, byte[] y,
558-
byte[] r, byte[] x)
557+
boolean sampleSolution(MayoParameters params, byte[] A, byte[] y, byte[] r, byte[] x)
559558
{
560559
final int k = params.getK();
561560
final int o = params.getO();
@@ -576,19 +575,19 @@ boolean sampleSolution(MayoParameters params, byte[] A, byte[] y,
576575
GF16Utils.matMul(A, r, 0, Ar, ok + 1, m);
577576

578577
// Update last column of A with y - Ar
579-
for (int i = 0; i < m; i++)
578+
for (int i = 0, idx = ok; i < m; i++, idx += ok + 1)
580579
{
581-
A[ok + i * (ok + 1)] = (byte)(y[i] ^ Ar[i]);
580+
A[idx] = (byte)(y[i] ^ Ar[i]);
582581
}
583582

584583
// Perform row echelon form transformation
585584
ef(A, m, aCols);
586585

587586
// Check matrix rank
588587
boolean fullRank = false;
589-
for (int i = 0; i < aCols - 1; i++)
588+
for (int i = 0, idx = (m - 1) * aCols; i < aCols - 1; i++, idx++)
590589
{
591-
fullRank |= (A[(m - 1) * aCols + i] != 0);
590+
fullRank |= (A[idx] != 0);
592591
}
593592
if (!fullRank)
594593
{
@@ -609,7 +608,6 @@ boolean sampleSolution(MayoParameters params, byte[] A, byte[] y,
609608
byte u = (byte)(correctCol & ~finished & A[rowAcols + aCols - 1]);
610609
x[col] ^= u;
611610

612-
613611
// Update matrix entries
614612
for (int i = 0, iaCols_col = col, iaCols_aCols1 = aCols - 1; i < row; i += 8,
615613
iaCols_col += aCols << 3, iaCols_aCols1 += aCols << 3)
@@ -647,7 +645,7 @@ boolean sampleSolution(MayoParameters params, byte[] A, byte[] y,
647645
void ef(byte[] A, int nrows, int ncols)
648646
{
649647
// Each 64-bit long can hold 16 nibbles (16 GF(16) elements).
650-
int rowLen = (ncols + 15) / 16;
648+
int rowLen = (ncols + 15) >> 4;
651649

652650
// Allocate temporary arrays.
653651
long[] pivotRow = new long[rowLen];
@@ -686,11 +684,8 @@ void ef(byte[] A, int nrows, int ncols)
686684
int upperBound = Math.min(nrows - 1, pivotCol);
687685

688686
// Zero out pivot row buffers.
689-
for (int i = 0; i < rowLen; i++)
690-
{
691-
pivotRow[i] = 0;
692-
pivotRow2[i] = 0;
693-
}
687+
Arrays.clear(pivotRow);
688+
Arrays.clear(pivotRow2);
694689

695690
// Try to select a pivot row in constant time.
696691
int pivot = 0;

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/Utils.java

Lines changed: 24 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -22,9 +22,7 @@ public class Utils
2222
*/
2323
public static void decode(byte[] m, byte[] mdec, int mdecLen)
2424
{
25-
int i;
26-
int decIndex = 0;
27-
int blocks = mdecLen >> 1;
25+
int i, decIndex = 0, blocks = mdecLen >> 1;
2826
// Process pairs of nibbles from each byte
2927
for (i = 0; i < blocks; i++)
3028
{
@@ -34,7 +32,7 @@ public static void decode(byte[] m, byte[] mdec, int mdecLen)
3432
mdec[decIndex++] = (byte)((m[i] >> 4) & 0x0F);
3533
}
3634
// If there is an extra nibble (odd number of nibbles), decode only the lower nibble
37-
if (mdecLen % 2 == 1)
35+
if ((mdecLen & 1) == 1)
3836
{
3937
mdec[decIndex] = (byte)((m[i] & 0xFF) & 0x0F);
4038
}
@@ -52,7 +50,7 @@ public static void decode(byte[] m, int mOff, byte[] mdec, int decIndex, int mde
5250
mdec[decIndex++] = (byte)((m[mOff++] >> 4) & 0x0F);
5351
}
5452
// If there is an extra nibble (odd number of nibbles), decode only the lower nibble
55-
if (mdecLen % 2 == 1)
53+
if ((mdecLen & 1) == 1)
5654
{
5755
mdec[decIndex] = (byte)(m[mOff] & 0x0F);
5856
}
@@ -68,14 +66,13 @@ public static void decode(byte[] m, int mOff, byte[] mdec, int decIndex, int mde
6866
*/
6967
public static void decode(byte[] input, int inputOffset, byte[] output, int mdecLen)
7068
{
71-
int decIndex = 0;
72-
int blocks = mdecLen >> 1;
69+
int decIndex = 0, blocks = mdecLen >> 1;
7370
for (int i = 0; i < blocks; i++)
7471
{
7572
output[decIndex++] = (byte)(input[inputOffset] & 0x0F);
7673
output[decIndex++] = (byte)((input[inputOffset++] >> 4) & 0x0F);
7774
}
78-
if (mdecLen % 2 == 1)
75+
if ((mdecLen & 1) == 1)
7976
{
8077
output[decIndex] = (byte)(input[inputOffset] & 0x0F);
8178
}
@@ -92,8 +89,7 @@ public static void decode(byte[] input, int inputOffset, byte[] output, int mdec
9289
*/
9390
public static void encode(byte[] m, byte[] menc, int mlen)
9491
{
95-
int i;
96-
int srcIndex = 0;
92+
int i, srcIndex = 0;
9793
// Process pairs of 4-bit values
9894
for (i = 0; i < mlen / 2; i++)
9995
{
@@ -103,58 +99,28 @@ public static void encode(byte[] m, byte[] menc, int mlen)
10399
srcIndex += 2;
104100
}
105101
// If there is an extra nibble (odd number of nibbles), store it directly in lower 4 bits.
106-
if (mlen % 2 == 1)
102+
if ((mlen & 1) == 1)
107103
{
108104
menc[i] = (byte)(m[srcIndex] & 0x0F);
109105
}
110106
}
111107

112-
/**
113-
* Unpacks m-vectors from a packed byte array into an array of 64-bit limbs.
114-
*
115-
* @param in the input byte array containing packed data
116-
* @param out the output long array where unpacked limbs are stored
117-
* @param vecs the number of vectors
118-
* @param m the m parameter (used to compute m_vec_limbs and copy lengths)
119-
*/
120-
public static void unpackMVecs(byte[] in, long[] out, int vecs, int m)
121-
{
122-
int mVecLimbs = (m + 15) / 16;
123-
int bytesToCopy = m / 2; // Number of bytes to copy per vector
124-
// Temporary buffer to hold mVecLimbs longs (each long is 8 bytes)
125-
byte[] tmp = new byte[mVecLimbs << 3];
126-
127-
// Process vectors in reverse order
128-
for (int i = vecs - 1; i >= 0; i--)
129-
{
130-
// Copy m/2 bytes from the input into tmp. The rest remains zero.
131-
System.arraycopy(in, i * bytesToCopy, tmp, 0, bytesToCopy);
132-
133-
// Convert each 8-byte block in tmp into a long using Pack
134-
for (int j = 0; j < mVecLimbs; j++)
135-
{
136-
out[i * mVecLimbs + j] = Pack.littleEndianToLong(tmp, j << 3);
137-
}
138-
}
139-
}
140-
141108
public static void unpackMVecs(byte[] in, int inOff, long[] out, int outOff, int vecs, int m)
142109
{
143-
int mVecLimbs = (m + 15) / 16;
144-
int bytesToCopy = m / 2; // Number of bytes to copy per vector
110+
int mVecLimbs = (m + 15) >> 4;
111+
int bytesToCopy = m >> 1; // Number of bytes to copy per vector
145112
// Temporary buffer to hold mVecLimbs longs (each long is 8 bytes)
146-
byte[] tmp = new byte[mVecLimbs << 3];
113+
int lastblockLen = 8 - (mVecLimbs << 3) + bytesToCopy;
114+
int i, j;
147115
// Process vectors in reverse order
148-
for (int i = vecs - 1; i >= 0; i--)
116+
for (i = vecs - 1, outOff += i * mVecLimbs, inOff += i * bytesToCopy; i >= 0; i--, outOff -= mVecLimbs, inOff -= bytesToCopy)
149117
{
150-
// Copy m/2 bytes from the input into tmp. The rest remains zero.
151-
System.arraycopy(in, inOff + i * bytesToCopy, tmp, 0, bytesToCopy);
152-
153118
// Convert each 8-byte block in tmp into a long using Pack
154-
for (int j = 0; j < mVecLimbs; j++)
119+
for (j = 0; j < mVecLimbs - 1; j++)
155120
{
156-
out[outOff + i * mVecLimbs + j] = Pack.littleEndianToLong(tmp, j * 8);
121+
out[outOff + j] = Pack.littleEndianToLong(in, inOff + (j << 3));
157122
}
123+
out[outOff + j] = Pack.littleEndianToLong(in, inOff + (j << 3), lastblockLen);
158124
}
159125
}
160126

@@ -168,23 +134,19 @@ public static void unpackMVecs(byte[] in, int inOff, long[] out, int outOff, int
168134
*/
169135
public static void packMVecs(long[] in, byte[] out, int outOff, int vecs, int m)
170136
{
171-
int mVecLimbs = (m + 15) / 16;
172-
int bytesToCopy = m / 2; // Number of bytes per vector to write
173-
137+
int mVecLimbs = (m + 15) >> 4;
138+
int bytesToCopy = m >> 1; // Number of bytes per vector to write
139+
int lastBlockLen = 8 - (mVecLimbs << 3) + bytesToCopy;
140+
int j;
174141
// Process each vector in order
175-
for (int i = 0; i < vecs; i++)
142+
for (int i = 0, inOff = 0; i < vecs; i++, outOff += bytesToCopy, inOff += mVecLimbs)
176143
{
177-
// Temporary buffer to hold the bytes for this vector
178-
byte[] tmp = new byte[mVecLimbs * 8];
179-
180144
// Convert each long into 8 bytes using Pack
181-
for (int j = 0; j < mVecLimbs; j++)
145+
for (j = 0; j < mVecLimbs - 1; j++)
182146
{
183-
Pack.longToLittleEndian(in[i * mVecLimbs + j], tmp, j * 8);
147+
Pack.longToLittleEndian(in[inOff + j], out, outOff + (j << 3));
184148
}
185-
186-
// Copy the first m/2 bytes from tmp to the output array
187-
System.arraycopy(tmp, 0, out, i * bytesToCopy + outOff, bytesToCopy);
149+
Pack.longToLittleEndian(in[inOff + j], out, outOff + (j << 3), lastBlockLen);
188150
}
189151
}
190152

@@ -242,6 +204,6 @@ public static void expandP1P2(MayoParameters p, long[] P, byte[] seed_pk)
242204

243205
// Unpack the byte array 'temp' into the long array 'P'
244206
// using our previously defined unpackMVecs method.
245-
unpackMVecs(temp, P, numVectors, p.getM());
207+
unpackMVecs(temp, 0, P, 0, numVectors, p.getM());
246208
}
247209
}

0 commit comments

Comments
 (0)