Skip to content

Commit 92cc716

Browse files
committed
refactor of ML-DSA verification to use digest as accumulator.
1 parent 499cb95 commit 92cc716

File tree

4 files changed

+242
-126
lines changed

4 files changed

+242
-126
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/HashMLDSASigner.java

Lines changed: 23 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,16 @@ public void init(boolean forSigning, CipherParameters param)
6363
{
6464
pubKey = (MLDSAPublicKeyParameters)param;
6565

66+
engine = pubKey.getParameters().getEngine(this.random);
67+
68+
byte[] ctx = pubKey.getContext();
69+
if (ctx.length > 255)
70+
{
71+
throw new IllegalArgumentException("context too long");
72+
}
73+
74+
engine.initVerify(pubKey.rho, pubKey.t1, true, ctx);
75+
6676
initDigest(pubKey);
6777
}
6878

@@ -114,12 +124,23 @@ public byte[] generateSignature() throws CryptoException, DataLengthException
114124

115125
byte[] ds_message = Arrays.concatenate(digestOidEncoding, hash);
116126

117-
return engine.signInternal(ds_message, ds_message.length, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, rnd);
127+
msgDigest.update(ds_message, 0, ds_message.length);
128+
129+
return engine.generateSignature(msgDigest, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, rnd);
118130
}
119131

120132
@Override
121133
public boolean verifySignature(byte[] signature)
122134
{
135+
SHAKEDigest msgDigest = engine.getShake256Digest();
136+
byte[] hash = new byte[digest.getDigestSize()];
137+
138+
digest.doFinal(hash, 0);
139+
140+
byte[] ds_message = Arrays.concatenate(digestOidEncoding, hash);
141+
142+
msgDigest.update(ds_message, 0, ds_message.length);
143+
123144
MLDSAEngine engine = pubKey.getParameters().getEngine(random);
124145

125146
byte[] ctx = pubKey.getContext();
@@ -128,17 +149,7 @@ public boolean verifySignature(byte[] signature)
128149
throw new RuntimeException("Context too long");
129150
}
130151

131-
byte[] hash = new byte[digest.getDigestSize()];
132-
digest.doFinal(hash, 0);
133-
134-
byte[] ds_message = new byte[1 + 1 + ctx.length + + digestOidEncoding.length + hash.length];
135-
ds_message[0] = 1;
136-
ds_message[1] = (byte)ctx.length;
137-
System.arraycopy(ctx, 0, ds_message, 2, ctx.length);
138-
System.arraycopy(digestOidEncoding, 0, ds_message, 2 + ctx.length, digestOidEncoding.length);
139-
System.arraycopy(hash, 0, ds_message, 2 + ctx.length + digestOidEncoding.length, hash.length);
140-
141-
return engine.verifyInternal(signature, signature.length, ds_message, ds_message.length, pubKey.rho, pubKey.t1);
152+
return engine.verifyInternal(signature, signature.length, msgDigest, pubKey.rho, pubKey.t1);
142153
}
143154

144155
/**

core/src/main/java/org/bouncycastle/pqc/crypto/mldsa/MLDSAEngine.java

Lines changed: 126 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,7 @@ SHAKEDigest getShake256Digest()
309309
{
310310
return new SHAKEDigest(shake256Digest);
311311
}
312+
312313
void initSign(byte[] tr, boolean isPreHash, byte[] ctx)
313314
{
314315
this.shake256Digest.update(tr, 0, TrBytes);
@@ -320,6 +321,26 @@ void initSign(byte[] tr, boolean isPreHash, byte[] ctx)
320321
}
321322
}
322323

324+
void initVerify(byte[] rho, byte[] encT1, boolean isPreHash, byte[] ctx)
325+
{
326+
byte[] mu = new byte[TrBytes];
327+
328+
shake256Digest.update(rho, 0, rho.length);
329+
shake256Digest.update(encT1, 0, encT1.length);
330+
shake256Digest.doFinal(mu, 0, TrBytes);
331+
// System.out.println("mu before = ");
332+
// Helper.printByteArray(mu);
333+
334+
shake256Digest.update(mu, 0, TrBytes);
335+
336+
if (ctx != null)
337+
{
338+
this.shake256Digest.update((isPreHash) ? (byte)1 : (byte)0);
339+
this.shake256Digest.update((byte)ctx.length);
340+
this.shake256Digest.update(ctx, 0, ctx.length);
341+
}
342+
}
343+
323344
public byte[] signInternal(byte[] msg, int msglen, byte[] rho, byte[] key, byte[] t0Enc, byte[] s1Enc, byte[] s2Enc, byte[] rnd)
324345
{
325346
SHAKEDigest shake256 = new SHAKEDigest(shake256Digest);
@@ -428,6 +449,107 @@ byte[] generateSignature(SHAKEDigest shake256Digest, byte[] rho, byte[] key, byt
428449
return null;
429450
}
430451

452+
public boolean verifyInternal(byte[] sig, int siglen, SHAKEDigest shake256Digest, byte[] rho, byte[] encT1)
453+
{
454+
if (siglen != CryptoBytes)
455+
{
456+
return false;
457+
}
458+
459+
// System.out.println("publickey = ");
460+
// Helper.printByteArray(publicKey);
461+
byte[] buf,
462+
mu = new byte[CrhBytes],
463+
c,
464+
c2 = new byte[DilithiumCTilde];
465+
Poly cp = new Poly(this);
466+
PolyVecMatrix aMatrix = new PolyVecMatrix(this);
467+
PolyVecL z = new PolyVecL(this);
468+
PolyVecK t1 = new PolyVecK(this), w1 = new PolyVecK(this), h = new PolyVecK(this);
469+
470+
t1 = Packing.unpackPublicKey(t1, encT1, this);
471+
472+
// System.out.println(t1.toString("t1"));
473+
474+
// System.out.println("rho = ");
475+
// Helper.printByteArray(rho);
476+
477+
if (!Packing.unpackSignature(z, h, sig, this))
478+
{
479+
return false;
480+
}
481+
c = Arrays.copyOfRange(sig, 0, DilithiumCTilde);
482+
483+
// System.out.println(z.toString("z"));
484+
// System.out.println(h.toString("h"));
485+
486+
if (z.checkNorm(getDilithiumGamma1() - getDilithiumBeta()))
487+
{
488+
return false;
489+
}
490+
491+
shake256Digest.doFinal(mu, 0);
492+
493+
// System.out.println("mu after = ");
494+
// Helper.printByteArray(mu);
495+
496+
// Matrix-vector multiplication; compute Az - c2^dt1
497+
cp.challenge(Arrays.copyOfRange(c, 0, DilithiumCTilde)); // use only first DilithiumCTilde of c.
498+
// System.out.println("cp = ");
499+
// System.out.println(cp.toString());
500+
501+
aMatrix.expandMatrix(rho);
502+
// System.out.println(aMatrix.toString("aMatrix = "));
503+
504+
505+
z.polyVecNtt();
506+
aMatrix.pointwiseMontgomery(w1, z);
507+
508+
cp.polyNtt();
509+
// System.out.println("cp = ");
510+
// System.out.println(cp.toString());
511+
512+
t1.shiftLeft();
513+
t1.polyVecNtt();
514+
t1.pointwisePolyMontgomery(cp, t1);
515+
516+
// System.out.println(t1.toString("t1"));
517+
518+
w1.subtract(t1);
519+
w1.reduce();
520+
w1.invNttToMont();
521+
522+
// System.out.println(w1.toString("w1 before caddq"));
523+
524+
// Reconstruct w1
525+
w1.conditionalAddQ();
526+
// System.out.println(w1.toString("w1 before hint"));
527+
w1.useHint(w1, h);
528+
// System.out.println(w1.toString("w1"));
529+
530+
buf = w1.packW1();
531+
532+
// System.out.println("buf = ");
533+
// Helper.printByteArray(buf);
534+
535+
// System.out.println("mu = ");
536+
// Helper.printByteArray(mu);
537+
538+
SHAKEDigest shakeDigest256 = new SHAKEDigest(256);
539+
shakeDigest256.update(mu, 0, CrhBytes);
540+
shakeDigest256.update(buf, 0, DilithiumK * DilithiumPolyW1PackedBytes);
541+
shakeDigest256.doFinal(c2, 0, DilithiumCTilde);
542+
543+
// System.out.println("c = ");
544+
// Helper.printByteArray(c);
545+
546+
// System.out.println("c2 = ");
547+
// Helper.printByteArray(c2);
548+
549+
550+
return Arrays.constantTimeAreEqual(c, c2);
551+
}
552+
431553
public boolean verifyInternal(byte[] sig, int siglen, byte[] msg, int msglen, byte[] rho, byte[] encT1)
432554
{
433555
if (siglen != CryptoBytes)
@@ -468,13 +590,13 @@ public boolean verifyInternal(byte[] sig, int siglen, byte[] msg, int msglen, by
468590
}
469591

470592
// Compute crh(crh(rho, t1), msg)
471-
shake256Digest.update(rho, 0, rho.length);
472-
shake256Digest.update(encT1, 0, encT1.length);
473-
shake256Digest.doFinal(mu, 0, TrBytes);
593+
// shake256Digest.update(rho, 0, rho.length);
594+
// shake256Digest.update(encT1, 0, encT1.length);
595+
// shake256Digest.doFinal(mu, 0, TrBytes);
474596
// System.out.println("mu before = ");
475597
// Helper.printByteArray(mu);
476598

477-
shake256Digest.update(mu, 0, TrBytes);
599+
//shake256Digest.update(mu, 0, TrBytes);
478600
shake256Digest.update(msg, 0, msglen);
479601
shake256Digest.doFinal(mu, 0);
480602

Lines changed: 29 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
package org.bouncycastle.pqc.crypto.mldsa;
22

3-
import java.io.ByteArrayOutputStream;
43
import java.security.SecureRandom;
54

65
import org.bouncycastle.crypto.CipherParameters;
@@ -21,9 +20,6 @@ public class MLDSASigner
2120

2221
private SecureRandom random;
2322

24-
// TODO: temporary
25-
private ByteArrayOutputStream bOut = new ByteArrayOutputStream();
26-
2723
public MLDSASigner()
2824
{
2925
}
@@ -62,8 +58,19 @@ public void init(boolean forSigning, CipherParameters param)
6258
else
6359
{
6460
pubKey = (MLDSAPublicKeyParameters)param;
65-
engine = null;
66-
msgDigest = null;
61+
62+
engine = pubKey.getParameters().getEngine(random);
63+
64+
byte[] ctx = pubKey.getContext();
65+
if (ctx.length > 255)
66+
{
67+
throw new IllegalArgumentException("context too long");
68+
}
69+
70+
engine.initVerify(pubKey.rho, pubKey.t1, false, ctx);
71+
72+
msgDigest = engine.getShake256Digest();
73+
6774
isPreHash = pubKey.getParameters().isPreHash();
6875
}
6976

@@ -75,26 +82,12 @@ public void init(boolean forSigning, CipherParameters param)
7582

7683
public void update(byte b)
7784
{
78-
if (msgDigest != null)
79-
{
80-
msgDigest.update(b);
81-
}
82-
else
83-
{
84-
bOut.write(b);
85-
}
85+
msgDigest.update(b);
8686
}
8787

8888
public void update(byte[] in, int off, int len)
8989
{
90-
if (msgDigest != null)
91-
{
92-
msgDigest.update(in, off, len);
93-
}
94-
else
95-
{
96-
bOut.write(in, off, len);
97-
}
90+
msgDigest.update(in, off, len);
9891
}
9992

10093
public byte[] generateSignature()
@@ -106,32 +99,25 @@ public byte[] generateSignature()
10699
random.nextBytes(rnd);
107100
}
108101

109-
return engine.generateSignature(msgDigest, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, rnd);
102+
byte[] sig = engine.generateSignature(msgDigest, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, rnd);
103+
104+
reset();
105+
106+
return sig;
110107
}
111108

112109
public boolean verifySignature(byte[] signature)
113110
{
114-
boolean isTrue = verifySignature(bOut.toByteArray(), signature);
115-
116-
bOut.reset();
111+
boolean isTrue = engine.verifyInternal(signature, signature.length, msgDigest, pubKey.rho, pubKey.t1);
117112

113+
reset();
114+
118115
return isTrue;
119116
}
120117

121118
public void reset()
122119
{
123-
bOut.reset();
124-
}
125-
126-
byte[] generateSignature(byte[] message)
127-
{
128-
byte[] rnd = new byte[MLDSAEngine.RndBytes];
129-
if (random != null)
130-
{
131-
random.nextBytes(rnd);
132-
}
133-
134-
return engine.signInternal(message, message.length, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, rnd);
120+
msgDigest = engine.getShake256Digest();
135121
}
136122

137123
protected byte[] internalGenerateSignature(byte[] message, byte[] random)
@@ -143,29 +129,16 @@ protected byte[] internalGenerateSignature(byte[] message, byte[] random)
143129
return engine.signInternal(message, message.length, privKey.rho, privKey.k, privKey.t0, privKey.s1, privKey.s2, random);
144130
}
145131

146-
boolean verifySignature(byte[] message, byte[] signature)
132+
protected boolean internalVerifySignature(byte[] message, byte[] signature)
147133
{
148134
MLDSAEngine engine = pubKey.getParameters().getEngine(random);
149135

150-
byte[] ctx = pubKey.getContext();
151-
if (ctx.length > 255)
152-
{
153-
throw new RuntimeException("Context too long");
154-
}
136+
engine.initVerify(pubKey.rho, pubKey.t1, false, null);
155137

156-
byte[] ds_message = new byte[1 + 1 + ctx.length + message.length];
157-
ds_message[0] = 0;
158-
ds_message[1] = (byte)ctx.length;
159-
System.arraycopy(ctx, 0, ds_message, 2, ctx.length);
160-
System.arraycopy(message, 0, ds_message, 2 + ctx.length, message.length);
138+
SHAKEDigest msgDigest = engine.getShake256Digest();
161139

162-
return engine.verifyInternal(signature, signature.length, ds_message, ds_message.length, pubKey.rho, pubKey.t1);
163-
}
164-
165-
public boolean internalVerifySignature(byte[] message, byte[] signature)
166-
{
167-
MLDSAEngine engine = pubKey.getParameters().getEngine(random);
140+
msgDigest.update(message, 0, message.length);
168141

169-
return engine.verifyInternal(signature, signature.length, message, message.length, pubKey.rho, pubKey.t1);
142+
return engine.verifyInternal(signature, signature.length, msgDigest, pubKey.rho, pubKey.t1);
170143
}
171144
}

0 commit comments

Comments
 (0)