Skip to content

Commit 50f0914

Browse files
author
gefeili
committed
Refactor of Mayo
1 parent 025f99d commit 50f0914

File tree

4 files changed

+87
-107
lines changed

4 files changed

+87
-107
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/GF16Utils.java

Lines changed: 29 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -100,23 +100,19 @@ static void mulAddMUpperTriangularMatXMat(int mVecLimbs, long[] bsMat, byte[] ma
100100
* each entry is an m-vector.
101101
* @param matRows number of rows in the matrix “mat.”
102102
* @param matCols number of columns in “mat.”
103-
* @param bsMatCols number of columns in the bsMat matrix.
104103
*/
105104
static void mulAddMatTransXMMat(int mVecLimbs, byte[] mat, long[] bsMat, int bsMatOff, long[] acc,
106-
int matRows, int matCols, int bsMatCols)
105+
int matRows, int matCols)
107106
{
108-
// Loop over each column r of mat (which becomes row of mat^T)
109-
for (int r = 0; r < matCols; r++)
107+
int multiply = matCols * mVecLimbs;
108+
for (int r = 0, rmultiply = 0; r < matCols; r++, rmultiply += multiply)
110109
{
111-
for (int c = 0, cmatCols = 0; c < matRows; c++, cmatCols += matCols)
110+
for (int c = 0, cmatCols = 0, cmultiply = 0; c < matRows; c++, cmatCols += matCols, cmultiply += multiply)
112111
{
113112
byte matVal = mat[cmatCols + r];
114-
for (int k = 0; k < bsMatCols; k++)
113+
for (int k = 0, kmVecLimbs = 0; k < matCols; k++, kmVecLimbs += mVecLimbs)
115114
{
116-
int bsMatOffset = bsMatOff + (c * bsMatCols + k) * mVecLimbs;
117-
// For acc: add into the m-vector at index (r * bsMatCols + k)
118-
int accOffset = (r * bsMatCols + k) * mVecLimbs;
119-
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, matVal, acc, accOffset);
115+
mVecMulAdd(mVecLimbs, bsMat, bsMatOff + cmultiply + kmVecLimbs, matVal, acc, rmultiply + kmVecLimbs);
120116
}
121117
}
122118
}
@@ -138,25 +134,19 @@ static void mulAddMatTransXMMat(int mVecLimbs, byte[] mat, long[] bsMat, int bsM
138134
* @param acc the accumulator array (long[]) where results are accumulated
139135
* @param matRows the number of rows in the matrix
140136
* @param matCols the number of columns in the matrix
141-
* @param bsMatCols the number of columns in the bit‐sliced matrix (per block)
142137
*/
143-
static void mulAddMatXMMat(int mVecLimbs, byte[] mat, long[] bsMat, long[] acc,
144-
int matRows, int matCols, int bsMatCols)
138+
static void mulAddMatXMMat(int mVecLimbs, byte[] mat, long[] bsMat, long[] acc, int matRows, int matCols)
145139
{
146-
for (int r = 0; r < matRows; r++)
140+
int multiply = mVecLimbs * matRows;
141+
for (int r = 0, rmatCols = 0, rmultiply = 0; r < matRows; r++, rmatCols += matCols, rmultiply += multiply)
147142
{
148-
for (int c = 0; c < matCols; c++)
143+
for (int c = 0, cmultiply = 0; c < matCols; c++, cmultiply += multiply)
149144
{
150145
// Retrieve the scalar from the matrix for row r and column c.
151-
byte matVal = mat[r * matCols + c];
152-
for (int k = 0; k < bsMatCols; k++)
146+
byte matVal = mat[rmatCols + c];
147+
for (int k = 0, kmVecLimbs = 0; k < matRows; k++, kmVecLimbs += mVecLimbs)
153148
{
154-
// Compute the starting index for the vector in bsMat.
155-
int bsMatOffset = mVecLimbs * (c * bsMatCols + k);
156-
// Compute the starting index for the accumulator vector in acc.
157-
int accOffset = mVecLimbs * (r * bsMatCols + k);
158-
// Multiply the vector by the scalar and add the result to the accumulator.
159-
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, matVal, acc, accOffset);
149+
mVecMulAdd(mVecLimbs, bsMat, cmultiply + kmVecLimbs, matVal, acc, rmultiply + kmVecLimbs);
160150
}
161151
}
162152
}
@@ -165,20 +155,16 @@ static void mulAddMatXMMat(int mVecLimbs, byte[] mat, long[] bsMat, long[] acc,
165155
static void mulAddMatXMMat(int mVecLimbs, byte[] mat, long[] bsMat, int bsMatOff, long[] acc,
166156
int matRows, int matCols, int bsMatCols)
167157
{
168-
for (int r = 0; r < matRows; r++)
158+
int multiply = mVecLimbs * bsMatCols;
159+
for (int r = 0, rmultiply = 0, rmatCols = 0; r < matRows; r++, rmultiply += multiply, rmatCols += matCols)
169160
{
170-
for (int c = 0; c < matCols; c++)
161+
for (int c = 0, cmultiply = 0; c < matCols; c++, cmultiply += multiply)
171162
{
172163
// Retrieve the scalar from the matrix for row r and column c.
173-
byte matVal = mat[r * matCols + c];
174-
for (int k = 0; k < bsMatCols; k++)
164+
byte matVal = mat[rmatCols + c];
165+
for (int k = 0, kmVecLimbs = 0; k < bsMatCols; k++, kmVecLimbs += mVecLimbs)
175166
{
176-
// Compute the starting index for the vector in bsMat.
177-
int bsMatOffset = mVecLimbs * (c * bsMatCols + k) + bsMatOff;
178-
// Compute the starting index for the accumulator vector in acc.
179-
int accOffset = mVecLimbs * (r * bsMatCols + k);
180-
// Multiply the vector by the scalar and add the result to the accumulator.
181-
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, matVal, acc, accOffset);
167+
mVecMulAdd(mVecLimbs, bsMat, cmultiply + kmVecLimbs + bsMatOff, matVal, acc, rmultiply + kmVecLimbs);
182168
}
183169
}
184170
}
@@ -200,27 +186,22 @@ static void mulAddMatXMMat(int mVecLimbs, byte[] mat, long[] bsMat, int bsMatOff
200186
* @param mat the matrix stored as a byte array.
201187
* @param acc the accumulator array where the results are added.
202188
* @param bsMatRows the number of rows in the bit‑sliced matrix.
203-
* @param bsMatCols the number of columns in the bit‑sliced matrix.
204189
* @param matRows the number of rows in the matrix.
205190
*/
206-
static void mulAddMUpperTriangularMatXMatTrans(int mVecLimbs, long[] bsMat, byte[] mat, long[] acc,
207-
int bsMatRows, int bsMatCols, int matRows)
191+
static void mulAddMUpperTriangularMatXMatTrans(int mVecLimbs, long[] bsMat, byte[] mat, long[] acc, int bsMatRows, int matRows)
208192
{
209193
int bsMatEntriesUsed = 0;
210-
for (int r = 0; r < bsMatRows; r++)
194+
int multiply = mVecLimbs * matRows;
195+
for (int r = 0, rmultiply = 0; r < bsMatRows; r++, rmultiply += multiply)
211196
{
212197
// For upper triangular, start c at triangular * r; otherwise, triangular is zero.
213-
for (int c = r; c < bsMatCols; c++)
198+
for (int c = r; c < bsMatRows; c++)
214199
{
215-
for (int k = 0; k < matRows; k++)
200+
for (int k = 0, kbsMatRows = 0, kmVecLimbs = 0; k < matRows; k++, kbsMatRows += bsMatRows, kmVecLimbs += mVecLimbs)
216201
{
217-
int bsMatOffset = mVecLimbs * bsMatEntriesUsed;
218-
int accOffset = mVecLimbs * (r * matRows + k);
219-
// Get the matrix element at row k and column c
220-
byte matVal = mat[k * bsMatCols + c];
221-
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, matVal, acc, accOffset);
202+
mVecMulAdd(mVecLimbs, bsMat, bsMatEntriesUsed, mat[kbsMatRows + c], acc, rmultiply + kmVecLimbs);
222203
}
223-
bsMatEntriesUsed++;
204+
bsMatEntriesUsed += mVecLimbs;
224205
}
225206
}
226207
}
@@ -254,7 +235,7 @@ static byte inverseF(int a)
254235
int a4 = mulF(a2, a2);
255236
int a8 = mulF(a4, a4);
256237
int a6 = mulF(a2, a4);
257-
return (byte) mulF(a8, a6);
238+
return (byte)mulF(a8, a6);
258239
}
259240

260241
/**
@@ -280,12 +261,12 @@ static long mulFx8(byte a, long b)
280261

281262
static void matMul(byte[] a, byte[] b, int bOff, byte[] c, int colrowAB, int rowA)
282263
{
283-
for (int i = 0, aRowStart = 0, cOff = 0; i < rowA; i++, aRowStart += colrowAB)
264+
for (int i = 0, aRowStart = 0, cOff = 0; i < rowA; i++)
284265
{
285266
byte result = 0;
286267
for (int k = 0; k < colrowAB; k++)
287268
{
288-
result ^= mulF(a[aRowStart + k], b[bOff + k]);
269+
result ^= mulF(a[aRowStart++], b[bOff + k]);
289270
}
290271
c[cOff++] = result;
291272
}

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/MayoKeyPairGenerator.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ public AsymmetricCipherKeyPair generateKeyPair()
7575
// Here, treat P2 as the bsMat for the multiplication.
7676
// Dimensions: mat = O (size: paramV x paramO), bsMat = P2 (size: paramV x paramO),
7777
// and acc (P3) will have dimensions: (paramO x paramO), each entry being an m-vector.
78-
GF16Utils.mulAddMatTransXMMat(mVecLimbs, O, P, p1Limbs, P3, v, o, o);
78+
GF16Utils.mulAddMatTransXMMat(mVecLimbs, O, P, p1Limbs, P3, v, o);
7979

8080
// Store seed_pk into the public key cpk.
8181
System.arraycopy(seed_pk, 0, cpk, 0, pkSeedBytes);

0 commit comments

Comments
 (0)