Skip to content

Commit 4ed7d13

Browse files
committed
Refactoring in pqc.crypto.mldsa
- reduced allocations
1 parent b77d970 commit 4ed7d13

File tree

6 files changed

+183
-393
lines changed

6 files changed

+183
-393
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/MLDSAEngine.java

Lines changed: 22 additions & 174 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ byte[][] generateKeyPairInternal(byte[] seed)
269269

270270
s1hat = new PolyVecL(this);
271271

272-
s1.copyPolyVecL(s1hat);
272+
s1.copyTo(s1hat);
273273
s1hat.polyVecNtt();
274274

275275
// System.out.println(s1hat.toString("s1hat"));
@@ -323,7 +323,7 @@ byte[] deriveT1(byte[] rho, byte[] key, byte[] tr, byte[] s1Enc, byte[] s2Enc, b
323323

324324
s1hat = new PolyVecL(this);
325325

326-
s1.copyPolyVecL(s1hat);
326+
s1.copyTo(s1hat);
327327
s1hat.polyVecNtt();
328328

329329
// System.out.println(s1hat.toString("s1hat"));
@@ -382,7 +382,7 @@ void initVerify(byte[] rho, byte[] encT1, boolean isPreHash, byte[] ctx)
382382
}
383383
}
384384

385-
public byte[] signInternal(byte[] msg, int msglen, byte[] rho, byte[] key, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, byte[] rnd)
385+
byte[] signInternal(byte[] msg, int msglen, byte[] rho, byte[] key, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, byte[] rnd)
386386
{
387387
SHAKEDigest shake256 = new SHAKEDigest(shake256Digest);
388388

@@ -397,7 +397,6 @@ byte[] generateSignature(SHAKEDigest shake256Digest, byte[] rho, byte[] key, byt
397397

398398
shake256Digest.doFinal(mu, 0, CrhBytes);
399399

400-
int n;
401400
byte[] outSig = new byte[CryptoBytes];
402401
byte[] rhoPrime = new byte[CrhBytes];
403402
short nonce = 0;
@@ -428,7 +427,7 @@ byte[] generateSignature(SHAKEDigest shake256Digest, byte[] rho, byte[] key, byt
428427
// Sample intermediate vector
429428
y.uniformGamma1(rhoPrime, nonce++);
430429

431-
y.copyPolyVecL(z);
430+
y.copyTo(z);
432431
z.polyVecNtt();
433432

434433
// Matrix-vector multiplication
@@ -440,13 +439,13 @@ byte[] generateSignature(SHAKEDigest shake256Digest, byte[] rho, byte[] key, byt
440439
w1.conditionalAddQ();
441440
w1.decompose(w0);
442441

443-
System.arraycopy(w1.packW1(), 0, outSig, 0, DilithiumK * DilithiumPolyW1PackedBytes);
442+
w1.packW1(this, outSig, 0);
444443

445444
shake256Digest.update(mu, 0, CrhBytes);
446445
shake256Digest.update(outSig, 0, DilithiumK * DilithiumPolyW1PackedBytes);
447446
shake256Digest.doFinal(outSig, 0, DilithiumCTilde);
448447

449-
cp.challenge(Arrays.copyOfRange(outSig, 0, DilithiumCTilde)); // uses only the first DilithiumCTilde bytes of sig
448+
cp.challenge(outSig, 0, DilithiumCTilde);
450449
cp.polyNtt();
451450

452451
// Compute z, reject if it reveals secret
@@ -478,235 +477,84 @@ byte[] generateSignature(SHAKEDigest shake256Digest, byte[] rho, byte[] key, byt
478477

479478
w0.addPolyVecK(h);
480479
w0.conditionalAddQ();
481-
n = h.makeHint(w0, w1);
480+
int n = h.makeHint(w0, w1);
482481
if (n > DilithiumOmega)
483482
{
484483
continue;
485484
}
486485

487-
return Packing.packSignature(outSig, z, h, this);
486+
Packing.packSignature(outSig, z, h, this);
487+
return outSig;
488488
}
489489

490+
// TODO[pqc] Shouldn't this throw an exception here (or in caller)?
490491
return null;
491492
}
492493

493-
public boolean verifyInternal(byte[] sig, int siglen, SHAKEDigest shake256Digest, byte[] rho, byte[] encT1)
494+
boolean verifyInternal(byte[] sig, int siglen, SHAKEDigest shake256Digest, byte[] rho, byte[] encT1)
494495
{
495496
if (siglen != CryptoBytes)
496497
{
497498
return false;
498499
}
499500

500-
// System.out.println("publickey = ");
501-
// Helper.printByteArray(publicKey);
502-
byte[] buf,
503-
mu = new byte[CrhBytes],
504-
c,
505-
c2 = new byte[DilithiumCTilde];
506-
Poly cp = new Poly(this);
507-
PolyVecMatrix aMatrix = new PolyVecMatrix(this);
501+
PolyVecK h = new PolyVecK(this);
508502
PolyVecL z = new PolyVecL(this);
509-
PolyVecK t1 = new PolyVecK(this), w1 = new PolyVecK(this), h = new PolyVecK(this);
510-
511-
t1 = Packing.unpackPublicKey(t1, encT1, this);
512-
513-
// System.out.println(t1.toString("t1"));
514-
515-
// System.out.println("rho = ");
516-
// Helper.printByteArray(rho);
517503

518504
if (!Packing.unpackSignature(z, h, sig, this))
519505
{
520506
return false;
521507
}
522-
c = Arrays.copyOfRange(sig, 0, DilithiumCTilde);
523-
524-
// System.out.println(z.toString("z"));
525-
// System.out.println(h.toString("h"));
526508

527509
if (z.checkNorm(getDilithiumGamma1() - getDilithiumBeta()))
528510
{
529511
return false;
530512
}
531513

532-
shake256Digest.doFinal(mu, 0);
533-
534-
// System.out.println("mu after = ");
535-
// Helper.printByteArray(mu);
536-
537-
// Matrix-vector multiplication; compute Az - c2^dt1
538-
cp.challenge(Arrays.copyOfRange(c, 0, DilithiumCTilde)); // use only first DilithiumCTilde of c.
539-
// System.out.println("cp = ");
540-
// System.out.println(cp.toString());
541-
542-
aMatrix.expandMatrix(rho);
543-
// System.out.println(aMatrix.toString("aMatrix = "));
544-
545-
546-
z.polyVecNtt();
547-
aMatrix.pointwiseMontgomery(w1, z);
548-
549-
cp.polyNtt();
550-
// System.out.println("cp = ");
551-
// System.out.println(cp.toString());
552-
553-
t1.shiftLeft();
554-
t1.polyVecNtt();
555-
t1.pointwisePolyMontgomery(cp, t1);
556-
557-
// System.out.println(t1.toString("t1"));
558-
559-
w1.subtract(t1);
560-
w1.reduce();
561-
w1.invNttToMont();
562-
563-
// System.out.println(w1.toString("w1 before caddq"));
564-
565-
// Reconstruct w1
566-
w1.conditionalAddQ();
567-
// System.out.println(w1.toString("w1 before hint"));
568-
w1.useHint(w1, h);
569-
// System.out.println(w1.toString("w1"));
570-
571-
buf = w1.packW1();
572-
573-
// System.out.println("buf = ");
574-
// Helper.printByteArray(buf);
575-
576-
// System.out.println("mu = ");
577-
// Helper.printByteArray(mu);
578-
579-
SHAKEDigest shakeDigest256 = new SHAKEDigest(256);
580-
shakeDigest256.update(mu, 0, CrhBytes);
581-
shakeDigest256.update(buf, 0, DilithiumK * DilithiumPolyW1PackedBytes);
582-
shakeDigest256.doFinal(c2, 0, DilithiumCTilde);
583-
584-
// System.out.println("c = ");
585-
// Helper.printByteArray(c);
514+
byte[] buf = new byte[Math.max(CrhBytes + DilithiumK * DilithiumPolyW1PackedBytes, DilithiumCTilde)];
586515

587-
// System.out.println("c2 = ");
588-
// Helper.printByteArray(c2);
516+
// Mu
517+
shake256Digest.doFinal(buf, 0);
589518

590-
591-
return Arrays.constantTimeAreEqual(c, c2);
592-
}
593-
594-
public boolean verifyInternal(byte[] sig, int siglen, byte[] msg, int msglen, byte[] rho, byte[] encT1)
595-
{
596-
if (siglen != CryptoBytes)
597-
{
598-
return false;
599-
}
600-
601-
// System.out.println("publickey = ");
602-
// Helper.printByteArray(publicKey);
603-
byte[] buf,
604-
mu = new byte[CrhBytes],
605-
c,
606-
c2 = new byte[DilithiumCTilde];
607519
Poly cp = new Poly(this);
608520
PolyVecMatrix aMatrix = new PolyVecMatrix(this);
609-
PolyVecL z = new PolyVecL(this);
610-
PolyVecK t1 = new PolyVecK(this), w1 = new PolyVecK(this), h = new PolyVecK(this);
521+
PolyVecK t1 = new PolyVecK(this), w1 = new PolyVecK(this);
611522

612523
t1 = Packing.unpackPublicKey(t1, encT1, this);
613524

614-
// System.out.println(t1.toString("t1"));
615-
616-
// System.out.println("rho = ");
617-
// Helper.printByteArray(rho);
618-
619-
if (!Packing.unpackSignature(z, h, sig, this))
620-
{
621-
return false;
622-
}
623-
c = Arrays.copyOfRange(sig, 0, DilithiumCTilde);
624-
625-
// System.out.println(z.toString("z"));
626-
// System.out.println(h.toString("h"));
627-
628-
if (z.checkNorm(getDilithiumGamma1() - getDilithiumBeta()))
629-
{
630-
return false;
631-
}
632-
633-
// Compute crh(crh(rho, t1), msg)
634-
// shake256Digest.update(rho, 0, rho.length);
635-
// shake256Digest.update(encT1, 0, encT1.length);
636-
// shake256Digest.doFinal(mu, 0, TrBytes);
637-
// System.out.println("mu before = ");
638-
// Helper.printByteArray(mu);
639-
640-
//shake256Digest.update(mu, 0, TrBytes);
641-
shake256Digest.update(msg, 0, msglen);
642-
shake256Digest.doFinal(mu, 0);
643-
644-
// System.out.println("mu after = ");
645-
// Helper.printByteArray(mu);
646-
647525
// Matrix-vector multiplication; compute Az - c2^dt1
648-
cp.challenge(Arrays.copyOfRange(c, 0, DilithiumCTilde)); // use only first DilithiumCTilde of c.
649-
// System.out.println("cp = ");
650-
// System.out.println(cp.toString());
526+
cp.challenge(sig, 0, DilithiumCTilde);
651527

652528
aMatrix.expandMatrix(rho);
653-
// System.out.println(aMatrix.toString("aMatrix = "));
654-
655529

656530
z.polyVecNtt();
657531
aMatrix.pointwiseMontgomery(w1, z);
658532

659533
cp.polyNtt();
660-
// System.out.println("cp = ");
661-
// System.out.println(cp.toString());
662534

663535
t1.shiftLeft();
664536
t1.polyVecNtt();
665537
t1.pointwisePolyMontgomery(cp, t1);
666538

667-
// System.out.println(t1.toString("t1"));
668-
669539
w1.subtract(t1);
670540
w1.reduce();
671541
w1.invNttToMont();
672542

673-
// System.out.println(w1.toString("w1 before caddq"));
674-
675-
// Reconstruct w1
676543
w1.conditionalAddQ();
677-
// System.out.println(w1.toString("w1 before hint"));
678544
w1.useHint(w1, h);
679-
// System.out.println(w1.toString("w1"));
680-
681-
buf = w1.packW1();
682-
683-
// System.out.println("buf = ");
684-
// Helper.printByteArray(buf);
685545

686-
// System.out.println("mu = ");
687-
// Helper.printByteArray(mu);
546+
w1.packW1(this, buf, CrhBytes);
688547

689-
SHAKEDigest shakeDigest256 = new SHAKEDigest(256);
690-
shakeDigest256.update(mu, 0, CrhBytes);
691-
shakeDigest256.update(buf, 0, DilithiumK * DilithiumPolyW1PackedBytes);
692-
shakeDigest256.doFinal(c2, 0, DilithiumCTilde);
548+
shake256Digest.update(buf, 0, CrhBytes + DilithiumK * DilithiumPolyW1PackedBytes);
549+
shake256Digest.doFinal(buf, 0, DilithiumCTilde);
693550

694-
// System.out.println("c = ");
695-
// Helper.printByteArray(c);
696-
697-
// System.out.println("c2 = ");
698-
// Helper.printByteArray(c2);
699-
700-
701-
return Arrays.constantTimeAreEqual(c, c2);
551+
return Arrays.constantTimeAreEqual(DilithiumCTilde, sig, 0, buf, 0);
702552
}
703553

704-
public byte[][] generateKeyPair()
554+
byte[][] generateKeyPair()
705555
{
706556
byte[] seedBuf = new byte[SeedBytes];
707557
random.nextBytes(seedBuf);
708558
return generateKeyPairInternal(seedBuf);
709-
710559
}
711-
712560
}

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/Packing.java

Lines changed: 12 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44

55
class Packing
66
{
7-
87
static byte[] packPublicKey(PolyVecK t1, MLDSAEngine engine)
98
{
109
byte[] out = new byte[engine.getCryptoPublicKeyBytes() - MLDSAEngine.SeedBytes];
@@ -62,7 +61,6 @@ static byte[][] packSecretKey(byte[] rho, byte[] tr, byte[] key, PolyVecK t0, Po
6261
* @param engine
6362
* @return Byte matrix where byte[0] = rho, byte[1] = tr, byte[2] = key
6463
*/
65-
6664
static void unpackSecretKey(PolyVecK t0, PolyVecL s1, PolyVecK s2, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, MLDSAEngine engine)
6765
{
6866
for (int i = 0; i < engine.getDilithiumL(); ++i)
@@ -81,40 +79,32 @@ static void unpackSecretKey(PolyVecK t0, PolyVecL s1, PolyVecK s2, byte[] t0Enc,
8179
}
8280
}
8381

84-
static byte[] packSignature(byte[] c, PolyVecL z, PolyVecK h, MLDSAEngine engine)
82+
static void packSignature(byte[] sig, PolyVecL z, PolyVecK h, MLDSAEngine engine)
8583
{
86-
int i, j, k, end = 0;
87-
byte[] outBytes = new byte[engine.getCryptoBytes()];
88-
89-
System.arraycopy(c, 0, outBytes, 0, engine.getDilithiumCTilde());
90-
end += engine.getDilithiumCTilde();
91-
92-
for (i = 0; i < engine.getDilithiumL(); ++i)
84+
int end = engine.getDilithiumCTilde();
85+
for (int i = 0; i < engine.getDilithiumL(); ++i)
9386
{
94-
System.arraycopy(z.getVectorIndex(i).zPack(), 0, outBytes, end + i * engine.getDilithiumPolyZPackedBytes(), engine.getDilithiumPolyZPackedBytes());
87+
z.getVectorIndex(i).zPack(sig, end);
88+
end += engine.getDilithiumPolyZPackedBytes();
9589
}
96-
end += engine.getDilithiumL() * engine.getDilithiumPolyZPackedBytes();
9790

98-
for (i = 0; i < engine.getDilithiumOmega() + engine.getDilithiumK(); ++i)
91+
for (int i = 0; i < engine.getDilithiumOmega() + engine.getDilithiumK(); ++i)
9992
{
100-
outBytes[end + i] = 0;
93+
sig[end + i] = 0;
10194
}
10295

103-
k = 0;
104-
for (i = 0; i < engine.getDilithiumK(); ++i)
96+
int k = 0;
97+
for (int i = 0; i < engine.getDilithiumK(); ++i)
10598
{
106-
for (j = 0; j < MLDSAEngine.DilithiumN; ++j)
99+
for (int j = 0; j < MLDSAEngine.DilithiumN; ++j)
107100
{
108101
if (h.getVectorIndex(i).getCoeffIndex(j) != 0)
109102
{
110-
outBytes[end + k++] = (byte)j;
103+
sig[end + k++] = (byte)j;
111104
}
112105
}
113-
outBytes[end + engine.getDilithiumOmega() + i] = (byte)k;
106+
sig[end + engine.getDilithiumOmega() + i] = (byte)k;
114107
}
115-
116-
return outBytes;
117-
118108
}
119109

120110
static boolean unpackSignature(PolyVecL z, PolyVecK h, byte[] sig, MLDSAEngine engine)
@@ -161,5 +151,4 @@ static boolean unpackSignature(PolyVecL z, PolyVecK h, byte[] sig, MLDSAEngine e
161151
}
162152
return true;
163153
}
164-
165154
}

0 commit comments

Comments
 (0)