Skip to content

Commit 7d33d39

Browse files
committed
Refactoring in pqc.crypto.slhdsa (sphincsplus)
1 parent 906b6f1 commit 7d33d39

File tree

6 files changed

+77
-89
lines changed

6 files changed

+77
-89
lines changed

core/src/main/java/org/bouncycastle/pqc/crypto/slhdsa/ADRS.java

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55

66
class ADRS
77
{
8-
public static final int WOTS_HASH = 0;
9-
public static final int WOTS_PK = 1;
10-
public static final int TREE = 2;
11-
public static final int FORS_TREE = 3;
12-
public static final int FORS_PK = 4;
13-
public static final int WOTS_PRF = 5;
14-
public static final int FORS_PRF = 6;
15-
8+
static final int WOTS_HASH = 0;
9+
static final int WOTS_PK = 1;
10+
static final int TREE = 2;
11+
static final int FORS_TREE = 3;
12+
static final int FORS_PK = 4;
13+
static final int WOTS_PRF = 5;
14+
static final int FORS_PRF = 6;
15+
1616
static final int OFFSET_LAYER = 0;
1717
static final int OFFSET_TREE = 4;
1818
static final int OFFSET_TREE_HGT = 24;
@@ -21,7 +21,7 @@ class ADRS
2121
static final int OFFSET_KP_ADDR = 20;
2222
static final int OFFSET_CHAIN_ADDR = 24;
2323
static final int OFFSET_HASH_ADDR = 28;
24-
24+
2525
final byte[] value = new byte[32];
2626

2727
ADRS()
@@ -59,11 +59,6 @@ public void setTreeHeight(int height)
5959
Pack.intToBigEndian(height, value, OFFSET_TREE_HGT);
6060
}
6161

62-
public int getTreeHeight()
63-
{
64-
return Pack.bigEndianToInt(value, OFFSET_TREE_HGT);
65-
}
66-
6762
public void setTreeIndex(int index)
6863
{
6964
Pack.intToBigEndian(index, value, OFFSET_TREE_INDEX);
@@ -87,11 +82,6 @@ public void changeType(int type)
8782
Pack.intToBigEndian(type, value, OFFSET_TYPE);
8883
}
8984

90-
public int getType()
91-
{
92-
return Pack.bigEndianToInt(value, OFFSET_TYPE);
93-
}
94-
9585
public void setKeyPairAddress(int keyPairAddr)
9686
{
9787
Pack.intToBigEndian(keyPairAddr, value, OFFSET_KP_ADDR);

core/src/main/java/org/bouncycastle/pqc/crypto/slhdsa/Fors.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,12 @@ public Fors(SLHDSAEngine engine)
1818
// Output: n-byte root node - top node on Stack
1919
byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam)
2020
{
21-
LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
22-
23-
if (s % (1 << z) != 0)
21+
if ((s >>> z) << z != s)
2422
{
2523
return null;
2624
}
2725

26+
LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
2827
ADRS adrs = new ADRS(adrsParam);
2928

3029
for (int idx = 0; idx < (1 << z); idx++)
@@ -42,19 +41,23 @@ byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam)
4241

4342
adrs.setTreeHeight(1);
4443

44+
int adrsTreeHeight = 1;
45+
int adrsTreeIndex = s + idx;
46+
4547
// while ( Top node on Stack has same height as node )
46-
while (!stack.isEmpty()
47-
&& ((NodeEntry)stack.get(0)).nodeHeight == adrs.getTreeHeight())
48+
while (!stack.isEmpty() && ((NodeEntry)stack.get(0)).nodeHeight == adrsTreeHeight)
4849
{
49-
adrs.setTreeIndex((adrs.getTreeIndex() - 1) / 2);
50-
NodeEntry current = ((NodeEntry)stack.remove(0));
50+
adrsTreeIndex = (adrsTreeIndex - 1) / 2;
51+
adrs.setTreeIndex(adrsTreeIndex);
5152

53+
NodeEntry current = ((NodeEntry)stack.remove(0));
5254
node = engine.H(pkSeed, adrs, current.nodeValue, node);
53-
//topmost node is now one layer higher
54-
adrs.setTreeHeight(adrs.getTreeHeight() + 1);
55+
56+
// topmost node is now one layer higher
57+
adrs.setTreeHeight(++adrsTreeHeight);
5558
}
5659

57-
stack.add(0, new NodeEntry(node, adrs.getTreeHeight()));
60+
stack.add(0, new NodeEntry(node, adrsTreeHeight));
5861
}
5962

6063
return ((NodeEntry)stack.get(0)).nodeValue;

core/src/main/java/org/bouncycastle/pqc/crypto/slhdsa/HT.java

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -144,21 +144,18 @@ SIG_XMSS xmss_sign(byte[] M, byte[] skSeed, int idx, byte[] pkSeed, ADRS paramAd
144144
return new SIG_XMSS(sig, AUTH);
145145
}
146146

147-
//
148-
// Input: Secret seed SK.seed, start index s, target node height z, public seed
149-
//PK.seed, address ADRS
147+
// Input: Secret seed SK.seed, start index s, target node height z, public seed PK.seed, address ADRS
150148
// Output: n-byte root node - top node on Stack
151149
byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam)
152150
{
153-
ADRS adrs = new ADRS(adrsParam);
154-
155-
LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
156-
157-
if (s % (1 << z) != 0)
151+
if ((s >>> z) << z != s)
158152
{
159153
return null;
160154
}
161155

156+
LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
157+
ADRS adrs = new ADRS(adrsParam);
158+
162159
for (int idx = 0; idx < (1 << z); idx++)
163160
{
164161
adrs.setTypeAndClear(ADRS.WOTS_HASH);
@@ -169,19 +166,23 @@ byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam)
169166
adrs.setTreeHeight(1);
170167
adrs.setTreeIndex(s + idx);
171168

169+
int adrsTreeHeight = 1;
170+
int adrsTreeIndex = s + idx;
171+
172172
// while ( Top node on Stack has same height as node )
173-
while (!stack.isEmpty()
174-
&& ((NodeEntry)stack.get(0)).nodeHeight == adrs.getTreeHeight())
173+
while (!stack.isEmpty() && ((NodeEntry)stack.get(0)).nodeHeight == adrsTreeHeight)
175174
{
176-
adrs.setTreeIndex((adrs.getTreeIndex() - 1) / 2);
177-
NodeEntry current = ((NodeEntry)stack.remove(0));
175+
adrsTreeIndex = (adrsTreeIndex - 1) / 2;
176+
adrs.setTreeIndex(adrsTreeIndex);
178177

178+
NodeEntry current = ((NodeEntry)stack.remove(0));
179179
node = engine.H(pkSeed, adrs, current.nodeValue, node);
180-
//topmost node is now one layer higher
181-
adrs.setTreeHeight(adrs.getTreeHeight() + 1);
180+
181+
// topmost node is now one layer higher
182+
adrs.setTreeHeight(++adrsTreeHeight);
182183
}
183184

184-
stack.add(0, new NodeEntry(node, adrs.getTreeHeight()));
185+
stack.add(0, new NodeEntry(node, adrsTreeHeight));
185186
}
186187

187188
return ((NodeEntry)stack.get(0)).nodeValue;

core/src/main/java/org/bouncycastle/pqc/crypto/sphincsplus/ADRS.java

Lines changed: 9 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,14 @@
55

66
class ADRS
77
{
8-
public static final int WOTS_HASH = 0;
9-
public static final int WOTS_PK = 1;
10-
public static final int TREE = 2;
11-
public static final int FORS_TREE = 3;
12-
public static final int FORS_PK = 4;
13-
public static final int WOTS_PRF = 5;
14-
public static final int FORS_PRF = 6;
15-
8+
static final int WOTS_HASH = 0;
9+
static final int WOTS_PK = 1;
10+
static final int TREE = 2;
11+
static final int FORS_TREE = 3;
12+
static final int FORS_PK = 4;
13+
static final int WOTS_PRF = 5;
14+
static final int FORS_PRF = 6;
15+
1616
static final int OFFSET_LAYER = 0;
1717
static final int OFFSET_TREE = 4;
1818
static final int OFFSET_TREE_HGT = 24;
@@ -21,7 +21,7 @@ class ADRS
2121
static final int OFFSET_KP_ADDR = 20;
2222
static final int OFFSET_CHAIN_ADDR = 24;
2323
static final int OFFSET_HASH_ADDR = 28;
24-
24+
2525
final byte[] value = new byte[32];
2626

2727
ADRS()
@@ -59,11 +59,6 @@ public void setTreeHeight(int height)
5959
Pack.intToBigEndian(height, value, OFFSET_TREE_HGT);
6060
}
6161

62-
public int getTreeHeight()
63-
{
64-
return Pack.bigEndianToInt(value, OFFSET_TREE_HGT);
65-
}
66-
6762
public void setTreeIndex(int index)
6863
{
6964
Pack.intToBigEndian(index, value, OFFSET_TREE_INDEX);
@@ -87,11 +82,6 @@ public void changeType(int type)
8782
Pack.intToBigEndian(type, value, OFFSET_TYPE);
8883
}
8984

90-
public int getType()
91-
{
92-
return Pack.bigEndianToInt(value, OFFSET_TYPE);
93-
}
94-
9585
public void setKeyPairAddress(int keyPairAddr)
9686
{
9787
Pack.intToBigEndian(keyPairAddr, value, OFFSET_KP_ADDR);

core/src/main/java/org/bouncycastle/pqc/crypto/sphincsplus/Fors.java

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -17,13 +17,12 @@ public Fors(SPHINCSPlusEngine engine)
1717
// Output: n-byte root node - top node on Stack
1818
byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam)
1919
{
20-
LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
21-
22-
if (s % (1 << z) != 0)
20+
if ((s >>> z) << z != s)
2321
{
2422
return null;
2523
}
2624

25+
LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
2726
ADRS adrs = new ADRS(adrsParam);
2827

2928
for (int idx = 0; idx < (1 << z); idx++)
@@ -41,19 +40,23 @@ byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam)
4140

4241
adrs.setTreeHeight(1);
4342

43+
int adrsTreeHeight = 1;
44+
int adrsTreeIndex = s + idx;
45+
4446
// while ( Top node on Stack has same height as node )
45-
while (!stack.isEmpty()
46-
&& ((NodeEntry)stack.get(0)).nodeHeight == adrs.getTreeHeight())
47+
while (!stack.isEmpty() && ((NodeEntry)stack.get(0)).nodeHeight == adrsTreeHeight)
4748
{
48-
adrs.setTreeIndex((adrs.getTreeIndex() - 1) / 2);
49-
NodeEntry current = ((NodeEntry)stack.remove(0));
49+
adrsTreeIndex = (adrsTreeIndex - 1) / 2;
50+
adrs.setTreeIndex(adrsTreeIndex);
5051

52+
NodeEntry current = ((NodeEntry)stack.remove(0));
5153
node = engine.H(pkSeed, adrs, current.nodeValue, node);
52-
//topmost node is now one layer higher
53-
adrs.setTreeHeight(adrs.getTreeHeight() + 1);
54+
55+
// topmost node is now one layer higher
56+
adrs.setTreeHeight(++adrsTreeHeight);
5457
}
5558

56-
stack.add(0, new NodeEntry(node, adrs.getTreeHeight()));
59+
stack.add(0, new NodeEntry(node, adrsTreeHeight));
5760
}
5861

5962
return ((NodeEntry)stack.get(0)).nodeValue;

core/src/main/java/org/bouncycastle/pqc/crypto/sphincsplus/HT.java

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -144,21 +144,18 @@ SIG_XMSS xmss_sign(byte[] M, byte[] skSeed, int idx, byte[] pkSeed, ADRS paramAd
144144
return new SIG_XMSS(sig, AUTH);
145145
}
146146

147-
//
148-
// Input: Secret seed SK.seed, start index s, target node height z, public seed
149-
//PK.seed, address ADRS
147+
// Input: Secret seed SK.seed, start index s, target node height z, public seed PK.seed, address ADRS
150148
// Output: n-byte root node - top node on Stack
151149
byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam)
152150
{
153-
ADRS adrs = new ADRS(adrsParam);
154-
155-
LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
156-
157-
if (s % (1 << z) != 0)
151+
if ((s >>> z) << z != s)
158152
{
159153
return null;
160154
}
161155

156+
LinkedList<NodeEntry> stack = new LinkedList<NodeEntry>();
157+
ADRS adrs = new ADRS(adrsParam);
158+
162159
for (int idx = 0; idx < (1 << z); idx++)
163160
{
164161
adrs.setTypeAndClear(ADRS.WOTS_HASH);
@@ -169,21 +166,25 @@ byte[] treehash(byte[] skSeed, int s, int z, byte[] pkSeed, ADRS adrsParam)
169166
adrs.setTreeHeight(1);
170167
adrs.setTreeIndex(s + idx);
171168

169+
int adrsTreeHeight = 1;
170+
int adrsTreeIndex = s + idx;
171+
172172
// while ( Top node on Stack has same height as node )
173-
while (!stack.isEmpty()
174-
&& ((NodeEntry)stack.get(0)).nodeHeight == adrs.getTreeHeight())
173+
while (!stack.isEmpty() && ((NodeEntry)stack.get(0)).nodeHeight == adrsTreeHeight)
175174
{
176-
adrs.setTreeIndex((adrs.getTreeIndex() - 1) / 2);
177-
NodeEntry current = ((NodeEntry)stack.remove(0));
175+
adrsTreeIndex = (adrsTreeIndex - 1) / 2;
176+
adrs.setTreeIndex(adrsTreeIndex);
178177

178+
NodeEntry current = ((NodeEntry)stack.remove(0));
179179
node = engine.H(pkSeed, adrs, current.nodeValue, node);
180-
//topmost node is now one layer higher
181-
adrs.setTreeHeight(adrs.getTreeHeight() + 1);
180+
181+
// topmost node is now one layer higher
182+
adrs.setTreeHeight(++adrsTreeHeight);
182183
}
183184

184-
stack.add(0, new NodeEntry(node, adrs.getTreeHeight()));
185+
stack.add(0, new NodeEntry(node, adrsTreeHeight));
185186
}
186-
187+
187188
return ((NodeEntry)stack.get(0)).nodeValue;
188189
}
189190

0 commit comments

Comments
 (0)