Skip to content

Commit e356f4a

Browse files
author
gefeili
committed
Pass all test vectors of Mayo
1 parent fa78d66 commit e356f4a

File tree

7 files changed

+1545
-70
lines changed

7 files changed

+1545
-70
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mayo/GF16Utils.java

Lines changed: 279 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package org.bouncycastle.pqc.crypto.mayo;
22

3+
import org.bouncycastle.util.Pack;
4+
35
public class GF16Utils
46
{
57

@@ -105,7 +107,7 @@ public static void mulAddMUpperTriangularMatXMat(int mVecLimbs, long[] bsMat, by
105107
int a = mat[c * matCols + k] & 0xFF;
106108
// For acc: add into the m-vector at row r, column k.
107109
int accOffset = (r * matCols + k) * mVecLimbs;
108-
GF16Utils.mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, a, acc, accOffset);
110+
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, a, acc, accOffset);
109111
}
110112
bsMatEntriesUsed++;
111113
}
@@ -162,7 +164,7 @@ public static void mulAddMatTransXMMat(int mVecLimbs, byte[] mat, long[] bsMat,
162164
int a = mat[c * matCols + r] & 0xFF;
163165
// For acc: add into the m-vector at index (r * bsMatCols + k)
164166
int accOffset = (r * bsMatCols + k) * mVecLimbs;
165-
GF16Utils.mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, a, acc, accOffset);
167+
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, a, acc, accOffset);
166168
}
167169
}
168170
}
@@ -181,5 +183,280 @@ public static void mVecAdd(int mVecLimbs, long[] src, int srcOffset, long[] dest
181183
}
182184
}
183185

186+
/**
187+
* Multiplies a matrix (given as a byte array) with a bit‐sliced matrix (given as a long array)
188+
* and accumulates the result into the acc array.
189+
*
190+
* <p>
191+
* The operation iterates over the rows and columns of the matrix. For each element in the matrix,
192+
* it multiplies a corresponding vector (from bsMat) by the scalar value (from mat) and adds the
193+
* result to the accumulator vector in acc.
194+
* </p>
195+
*
196+
* @param mVecLimbs the number of limbs (elements) in each vector
197+
* @param mat the matrix as a byte array with dimensions [matRows x matCols]
198+
* @param bsMat the bit‐sliced matrix as a long array
199+
* @param acc the accumulator array (long[]) where results are accumulated
200+
* @param matRows the number of rows in the matrix
201+
* @param matCols the number of columns in the matrix
202+
* @param bsMatCols the number of columns in the bit‐sliced matrix (per block)
203+
*/
204+
public static void mulAddMatXMMat(int mVecLimbs, byte[] mat, long[] bsMat, long[] acc,
205+
int matRows, int matCols, int bsMatCols)
206+
{
207+
for (int r = 0; r < matRows; r++)
208+
{
209+
for (int c = 0; c < matCols; c++)
210+
{
211+
// Retrieve the scalar from the matrix for row r and column c.
212+
byte matVal = mat[r * matCols + c];
213+
for (int k = 0; k < bsMatCols; k++)
214+
{
215+
// Compute the starting index for the vector in bsMat.
216+
int bsMatOffset = mVecLimbs * (c * bsMatCols + k);
217+
// Compute the starting index for the accumulator vector in acc.
218+
int accOffset = mVecLimbs * (r * bsMatCols + k);
219+
// Multiply the vector by the scalar and add the result to the accumulator.
220+
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, matVal, acc, accOffset);
221+
}
222+
}
223+
}
224+
}
225+
226+
public static void mulAddMatXMMat(int mVecLimbs, byte[] mat, long[] bsMat, int bsMatOff, long[] acc,
227+
int matRows, int matCols, int bsMatCols)
228+
{
229+
for (int r = 0; r < matRows; r++)
230+
{
231+
for (int c = 0; c < matCols; c++)
232+
{
233+
// Retrieve the scalar from the matrix for row r and column c.
234+
byte matVal = mat[r * matCols + c];
235+
for (int k = 0; k < bsMatCols; k++)
236+
{
237+
// Compute the starting index for the vector in bsMat.
238+
int bsMatOffset = mVecLimbs * (c * bsMatCols + k) + bsMatOff;
239+
// Compute the starting index for the accumulator vector in acc.
240+
int accOffset = mVecLimbs * (r * bsMatCols + k);
241+
// Multiply the vector by the scalar and add the result to the accumulator.
242+
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, matVal, acc, accOffset);
243+
}
244+
}
245+
}
246+
}
247+
248+
/**
249+
* Multiplies m (possibly upper triangular) matrices with the transpose of a single matrix
250+
* and adds the result to the accumulator.
251+
*
252+
* <p>
253+
* For each row {@code r} in the bit‑sliced matrix and for each column {@code c} (starting from
254+
* {@code triangular * r}) in the bit‑sliced matrix, this method iterates over all rows {@code k}
255+
* of the single matrix, and for each element, it multiplies the vector (from {@code bsMat})
256+
* by the scalar (from {@code mat}) and adds the result to the corresponding vector in {@code acc}.
257+
* </p>
258+
*
259+
* @param mVecLimbs the number of limbs (elements) in each vector.
260+
* @param bsMat the bit‑sliced matrix stored as a long array.
261+
* @param mat the matrix stored as a byte array.
262+
* @param acc the accumulator array where the results are added.
263+
* @param bsMatRows the number of rows in the bit‑sliced matrix.
264+
* @param bsMatCols the number of columns in the bit‑sliced matrix.
265+
* @param matRows the number of rows in the matrix.
266+
* @param triangular if non‑zero, indicates that the matrix is upper triangular (i.e. the loop for {@code c}
267+
* starts at {@code triangular * r}).
268+
*/
269+
public static void mulAddMUpperTriangularMatXMatTrans(int mVecLimbs, long[] bsMat, byte[] mat, long[] acc,
270+
int bsMatRows, int bsMatCols, int matRows, int triangular)
271+
{
272+
int bsMatEntriesUsed = 0;
273+
for (int r = 0; r < bsMatRows; r++)
274+
{
275+
// For upper triangular, start c at triangular * r; otherwise, triangular is zero.
276+
for (int c = triangular * r; c < bsMatCols; c++)
277+
{
278+
for (int k = 0; k < matRows; k++)
279+
{
280+
int bsMatOffset = mVecLimbs * bsMatEntriesUsed;
281+
int accOffset = mVecLimbs * (r * matRows + k);
282+
// Get the matrix element at row k and column c
283+
byte matVal = mat[k * bsMatCols + c];
284+
mVecMulAdd(mVecLimbs, bsMat, bsMatOffset, matVal, acc, accOffset);
285+
}
286+
bsMatEntriesUsed++;
287+
}
288+
}
289+
}
290+
291+
/**
292+
* Multiplies a vector (from bsMat) by an unsigned scalar (from mat) and adds the result
293+
* to the corresponding vector in acc.
294+
*
295+
* <p>
296+
* This method corresponds to the C function <code>m_vec_mul_add</code>.
297+
* It processes {@code mVecLimbs} elements starting from the given offsets in the source and accumulator arrays.
298+
* </p>
299+
*
300+
* @param mVecLimbs the number of limbs (elements) in the vector
301+
* @param bsMat the source array (bit-sliced matrix) of long values
302+
* @param bsMatOffset the starting index in bsMat for the vector
303+
* @param scalar the scalar value (from mat), as a byte
304+
* @param acc the accumulator array where the result is added
305+
* @param accOffset the starting index in the accumulator array for the current vector
306+
*/
307+
public static void mVecMulAdd(int mVecLimbs, long[] bsMat, int bsMatOffset, byte scalar, long[] acc, int accOffset)
308+
{
309+
for (int i = 0; i < mVecLimbs; i++)
310+
{
311+
acc[accOffset + i] ^= gf16vMulU64(bsMat[bsMatOffset + i], scalar);
312+
}
313+
}
314+
315+
/**
316+
* GF(16) multiplication mod x^4 + x + 1.
317+
* <p>
318+
* This method multiplies two elements in GF(16) (represented as integers 0–15)
319+
* using carryless multiplication followed by reduction modulo x^4 + x + 1.
320+
*
321+
* @param a an element in GF(16) (only the lower 4 bits are used)
322+
* @param b an element in GF(16) (only the lower 4 bits are used)
323+
* @return the product a * b in GF(16)
324+
*/
325+
public static int mulF(int a, int b)
326+
{
327+
// In C there is a conditional XOR with unsigned_char_blocker to work around
328+
// compiler-specific behavior. In Java we can omit it (or define it as needed).
329+
// a ^= unsignedCharBlocker; // Omitted in Java
330+
331+
// Perform carryless multiplication:
332+
// Multiply b by each bit of a and XOR the results.
333+
int p = ((a & 1) * b) ^
334+
((a & 2) * b) ^
335+
((a & 4) * b) ^
336+
((a & 8) * b);
337+
338+
// Reduce modulo f(X) = x^4 + x + 1.
339+
// Extract the upper nibble (bits 4 to 7).
340+
int topP = p & 0xF0;
341+
// The reduction: XOR p with (topP shifted right by 4 and by 3) and mask to 4 bits.
342+
int out = (p ^ (topP >> 4) ^ (topP >> 3)) & 0x0F;
343+
return out;
344+
}
345+
346+
/**
347+
* Performs a GF(16) carryless multiplication of a nibble (lower 4 bits of a)
348+
* with a 64-bit word b, then reduces modulo the polynomial x⁴ + x + 1 on each byte.
349+
*
350+
* @param a a GF(16) element (only the low 4 bits are used)
351+
* @param b a 64-bit word representing 16 GF(16) elements (packed 4 bits per element)
352+
* @return the reduced 64-bit word after multiplication
353+
*/
354+
public static long mulFx8(byte a, long b)
355+
{
356+
// Convert 'a' to an unsigned int so that bit operations work as expected.
357+
int aa = a & 0xFF;
358+
// Carryless multiplication: for each bit in 'aa' (considering only the lower 4 bits),
359+
// if that bit is set, multiply 'b' (by 1, 2, 4, or 8) and XOR the result.
360+
long p = ((aa & 1) * b)
361+
^ ((aa & 2) * b)
362+
^ ((aa & 4) * b)
363+
^ ((aa & 8) * b);
364+
365+
// Reduction mod (x^4 + x + 1): process each byte in parallel.
366+
long topP = p & 0xf0f0f0f0f0f0f0f0L;
367+
long out = (p ^ (topP >> 4) ^ (topP >> 3)) & 0x0f0f0f0f0f0f0f0fL;
368+
return out;
369+
}
370+
371+
public static void matMul(byte[] a, byte[] b, byte[] c,
372+
int colrowAB, int rowA, int colB)
373+
{
374+
int cIndex = 0;
375+
for (int i = 0; i < rowA; i++)
376+
{
377+
int aRowStart = i * colrowAB;
378+
for (int j = 0; j < colB; j++)
379+
{
380+
c[cIndex++] = lincomb(a, aRowStart, b, j, colrowAB, colB);
381+
}
382+
}
383+
}
384+
385+
public static void matMul(byte[] a, int aOff, byte[] b, int bOff, byte[] c, int cOff,
386+
int colrowAB, int rowA, int colB)
387+
{
388+
int cIndex = 0;
389+
for (int i = 0; i < rowA; i++)
390+
{
391+
int aRowStart = i * colrowAB;
392+
for (int j = 0; j < colB; j++)
393+
{
394+
c[cOff + cIndex++] = lincomb(a, aOff + aRowStart, b, bOff + j, colrowAB, colB);
395+
}
396+
}
397+
}
398+
399+
400+
private static byte lincomb(byte[] a, int aStart, byte[] b, int bStart,
401+
int colrowAB, int colB)
402+
{
403+
byte result = 0;
404+
for (int k = 0; k < colrowAB; k++)
405+
{
406+
result ^= mulF(a[aStart + k], b[bStart + k * colB]);
407+
}
408+
return result;
409+
}
410+
411+
public static void matAdd(byte[] a, int aOff, byte[] b, int bOff, byte[] c, int cOff, int m, int n)
412+
{
413+
for (int i = 0; i < m; i++)
414+
{
415+
for (int j = 0; j < n; j++)
416+
{
417+
int idx = i * n + j;
418+
c[idx + cOff] = (byte)(a[idx + aOff] ^ b[idx + bOff]);
419+
}
420+
}
421+
}
422+
423+
// Define the blocker constant as needed (set to 0 if not used).
424+
private static final byte UNSIGNED_CHAR_BLOCKER = 0;
425+
426+
/**
427+
* Returns 0x00 if a equals b, otherwise returns 0xFF.
428+
* This operation is performed in constant time.
429+
*
430+
* @param a an 8-bit value
431+
* @param b an 8-bit value
432+
* @return 0x00 if a == b, 0xFF if a != b
433+
*/
434+
public static byte ctCompare8(byte a, byte b)
435+
{
436+
// Compute the difference between a and b using XOR.
437+
// Masking with 0xFF ensures we work with values in 0..255.
438+
int diff = (a ^ b) & 0xFF;
439+
// Negate the difference.
440+
int negDiff = -diff;
441+
// Right shift by 31 bits (since 8*sizeof(uint32_t)-1 equals 31 for 32-bit integers).
442+
// If diff is 0, then -diff is 0, and shifting yields 0.
443+
// If diff is nonzero, -diff is negative, so the arithmetic shift yields -1 (0xFFFFFFFF),
444+
// which when cast to a byte becomes 0xFF.
445+
int result = negDiff >> 31;
446+
// XOR with UNSIGNED_CHAR_BLOCKER (assumed 0 here) and cast to byte.
447+
return (byte)(result ^ UNSIGNED_CHAR_BLOCKER);
448+
}
449+
450+
public static void efUnpackMVector(int legs, long[] packedRow, int packedRowOff, byte[] out)
451+
{
452+
int outIndex = 0;
453+
byte[] bytes = new byte[out.length >> 1];
454+
Pack.longToLittleEndian(packedRow, packedRowOff, out.length >> 4, bytes, 0);
455+
for (int i = 0; i < legs * 16; i += 2)
456+
{
457+
out[outIndex++] = (byte)(bytes[i / 2] & 0x0F); // Lower nibble
458+
out[outIndex++] = (byte)((bytes[i / 2] >> 4) & 0x0F); // Upper nibble
459+
}
460+
}
184461
}
185462

0 commit comments

Comments
 (0)