Skip to content

Commit 8dffb9b

Browse files
committed
sphincsplus refactoring
1 parent 1f1dc86 commit 8dffb9b

File tree

2 files changed

+33
-29
lines changed

2 files changed

+33
-29
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/sphincsplus/SPHINCSPlusEngine.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -605,7 +605,7 @@ public byte[] T_l(byte[] pkSeed, ADRS adrs, byte[] m)
605605

606606
byte[] PRF(byte[] pkSeed, byte[] skSeed, ADRS adrs)
607607
{
608-
byte[] rv = new byte[64];
608+
byte[] rv = new byte[32];
609609
harakaS512Digest.update(adrs.value, 0, adrs.value.length);
610610
harakaS512Digest.update(skSeed, 0, skSeed.length);
611611
harakaS512Digest.doFinal(rv, 0);

core/src/main/java/org/bouncycastle/pqc/crypto/sphincsplus/WotsPlus.java

Lines changed: 32 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -42,9 +42,8 @@ byte[] pkGen(byte[] skSeed, byte[] pkSeed, ADRS paramAdrs)
4242
return engine.T_l(pkSeed, wotspkADRS, Arrays.concatenate(tmp));
4343
}
4444

45-
// #Input: Input string X, start index i, number of steps s, public seed PK.seed,
46-
// address ADRS
47-
// #Output: value of F iterated s times on X
45+
// #Input: Input string X, start index i, number of steps s, public seed PK.seed, address ADRS
46+
// #Output: value of F iterated s times on X
4847
byte[] chain(byte[] X, int i, int s, byte[] pkSeed, ADRS adrs)
4948
{
5049
if (s == 0)
@@ -55,36 +54,42 @@ byte[] chain(byte[] X, int i, int s, byte[] pkSeed, ADRS adrs)
5554
{
5655
return null;
5756
}
58-
byte[] tmp = chain(X, i, s - 1, pkSeed, adrs);
59-
adrs.setHashAddress(i + s - 1);
60-
tmp = engine.F(pkSeed, adrs, tmp);
61-
62-
return tmp;
57+
byte[] result = X;
58+
for (int j = 0; j < s; ++j)
59+
{
60+
adrs.setHashAddress(i + j);
61+
result = engine.F(pkSeed, adrs, result);
62+
}
63+
return result;
6364
}
6465

65-
//
6666
// #Input: Message M, secret seed SK.seed, public seed PK.seed, address ADRS
6767
// #Output: WOTS+ signature sig
6868
public byte[] sign(byte[] M, byte[] skSeed, byte[] pkSeed, ADRS paramAdrs)
6969
{
7070
ADRS adrs = new ADRS(paramAdrs);
7171

72-
int csum = 0;
72+
int[] msg = new int[engine.WOTS_LEN];
73+
7374
// convert message to base w
74-
int[] msg = base_w(M, w, engine.WOTS_LEN1);
75+
base_w(M, 0, w, msg, 0, engine.WOTS_LEN1);
76+
7577
// compute checksum
78+
int csum = 0;
7679
for (int i = 0; i < engine.WOTS_LEN1; i++)
7780
{
7881
csum += w - 1 - msg[i];
7982
}
83+
8084
// convert csum to base w
8185
if ((engine.WOTS_LOGW % 8) != 0)
8286
{
8387
csum = csum << (8 - ((engine.WOTS_LEN2 * engine.WOTS_LOGW) % 8));
8488
}
8589
int len_2_bytes = (engine.WOTS_LEN2 * engine.WOTS_LOGW + 7) / 8;
86-
byte[] bytes = Pack.intToBigEndian(csum);
87-
msg = Arrays.concatenate(msg, base_w(Arrays.copyOfRange(bytes, 4 - len_2_bytes, bytes.length), w, engine.WOTS_LEN2));
90+
byte[] csum_bytes = Pack.intToBigEndian(csum);
91+
base_w(csum_bytes, 4 - len_2_bytes, w, msg, engine.WOTS_LEN1, engine.WOTS_LEN2);
92+
8893
byte[][] sig = new byte[engine.WOTS_LEN][];
8994
for (int i = 0; i < engine.WOTS_LEN; i++)
9095
{
@@ -105,45 +110,44 @@ public byte[] sign(byte[] M, byte[] skSeed, byte[] pkSeed, ADRS paramAdrs)
105110
//
106111
// Input: len_X-byte string X, int w, output length out_len
107112
// Output: out_len int array basew
108-
int[] base_w(byte[] X, int w, int out_len)
113+
void base_w(byte[] X, int XOff, int w, int[] output, int outOff, int outLen)
109114
{
110-
int in = 0;
111-
int out = 0;
112115
int total = 0;
113116
int bits = 0;
114-
int[] output = new int[out_len];
115117

116-
for (int consumed = 0; consumed < out_len; consumed++)
118+
for (int consumed = 0; consumed < outLen; consumed++)
117119
{
118120
if (bits == 0)
119121
{
120-
total = X[in];
121-
in++;
122+
total = X[XOff++];
122123
bits += 8;
123124
}
124125
bits -= engine.WOTS_LOGW;
125-
output[out] = ((total >>> bits) & (w - 1));
126-
out++;
126+
output[outOff++] = ((total >>> bits) & (w - 1));
127127
}
128-
return output;
129128
}
130129

131130
public byte[] pkFromSig(byte[] sig, byte[] M, byte[] pkSeed, ADRS adrs)
132131
{
133-
int csum = 0;
134132
ADRS wotspkADRS = new ADRS(adrs);
133+
134+
int[] msg = new int[engine.WOTS_LEN];
135+
135136
// convert message to base w
136-
int[] msg = base_w(M, w, engine.WOTS_LEN1);
137+
base_w(M, 0, w, msg, 0, engine.WOTS_LEN1);
138+
137139
// compute checksum
140+
int csum = 0;
138141
for (int i = 0; i < engine.WOTS_LEN1; i++ )
139142
{
140143
csum += w - 1 - msg[i];
141144
}
145+
142146
// convert csum to base w
143147
csum = csum << (8 - ((engine.WOTS_LEN2 * engine.WOTS_LOGW) % 8));
144148
int len_2_bytes = (engine.WOTS_LEN2 * engine.WOTS_LOGW + 7) / 8;
145-
146-
msg = Arrays.concatenate(msg, base_w(Arrays.copyOfRange(Pack.intToBigEndian(csum), 4 - len_2_bytes, 4), w, engine.WOTS_LEN2));
149+
byte[] csum_bytes = Pack.intToBigEndian(csum);
150+
base_w(csum_bytes, 4 - len_2_bytes, w, msg, engine.WOTS_LEN1, engine.WOTS_LEN2);
147151

148152
byte[] sigI = new byte[engine.N];
149153
byte[][] tmp = new byte[engine.WOTS_LEN][];
@@ -152,7 +156,7 @@ public byte[] pkFromSig(byte[] sig, byte[] M, byte[] pkSeed, ADRS adrs)
152156
adrs.setChainAddress(i);
153157
System.arraycopy(sig, i * engine.N, sigI, 0, engine.N);
154158
tmp[i] = chain(sigI, msg[i], w - 1 - msg[i], pkSeed, adrs);
155-
} // f6be78d057cc8056907ad2bf83cc8be7
159+
}
156160

157161
wotspkADRS.setType(ADRS.WOTS_PK);
158162
wotspkADRS.setKeyPairAddress(adrs.getKeyPairAddress());

0 commit comments

Comments
 (0)