Skip to content

Commit 8be6b14

Browse files
anematode authored and vondele committed
Network loading refactoring
Closes official-stockfish#6523. No functional change.
1 parent d678f83 commit 8be6b14

File tree

4 files changed

+51
-66
lines changed

4 files changed

+51
-66
lines changed

src/engine.cpp

Lines changed: 4 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -57,14 +57,10 @@ Engine::Engine(std::optional<std::string> path) :
5757
numaContext(NumaConfig::from_system()),
5858
states(new std::deque<StateInfo>(1)),
5959
threads(),
60-
networks(
61-
numaContext,
62-
// Heap-allocate because sizeof(NN::Networks) is large
63-
std::make_unique<NN::Networks>(
64-
std::make_unique<NN::NetworkBig>(NN::EvalFile{EvalFileDefaultNameBig, "None", ""},
65-
NN::EmbeddedNNUEType::BIG),
66-
std::make_unique<NN::NetworkSmall>(NN::EvalFile{EvalFileDefaultNameSmall, "None", ""},
67-
NN::EmbeddedNNUEType::SMALL))) {
60+
networks(numaContext,
61+
// Heap-allocate because sizeof(NN::Networks) is large
62+
std::make_unique<NN::Networks>(NN::EvalFile{EvalFileDefaultNameBig, "None", ""},
63+
NN::EvalFile{EvalFileDefaultNameSmall, "None", ""})) {
6864

6965
pos.set(StartFEN, false, &states->back());
7066

src/nnue/network.h

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,6 @@
2828
#include <string>
2929
#include <string_view>
3030
#include <tuple>
31-
#include <utility>
3231

3332
#include "../misc.h"
3433
#include "../types.h"
@@ -130,9 +129,9 @@ using NetworkSmall = Network<SmallNetworkArchitecture, SmallFeatureTransformer>;
130129

131130

132131
struct Networks {
133-
Networks(std::unique_ptr<NetworkBig>&& nB, std::unique_ptr<NetworkSmall>&& nS) :
134-
big(std::move(*nB)),
135-
small(std::move(*nS)) {}
132+
Networks(EvalFile bigFile, EvalFile smallFile) :
133+
big(bigFile, EmbeddedNNUEType::BIG),
134+
small(smallFile, EmbeddedNNUEType::SMALL) {}
136135

137136
NetworkBig big;
138137
NetworkSmall small;

src/nnue/nnue_common.h

Lines changed: 39 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -169,52 +169,56 @@ inline void write_little_endian(std::ostream& stream, const IntType* values, std
169169
write_little_endian<IntType>(stream, values[i]);
170170
}
171171

172-
173172
// Read N signed integers from the stream s, putting them in the array out.
174173
// The stream is assumed to be compressed using the signed LEB128 format.
175174
// See https://en.wikipedia.org/wiki/LEB128 for a description of the compression scheme.
176-
template<typename IntType, std::size_t Count>
177-
inline void read_leb_128(std::istream& stream, std::array<IntType, Count>& out) {
178-
179-
// Check the presence of our LEB128 magic string
180-
char leb128MagicString[Leb128MagicStringSize];
181-
stream.read(leb128MagicString, Leb128MagicStringSize);
182-
assert(strncmp(Leb128MagicString, leb128MagicString, Leb128MagicStringSize) == 0);
175+
template<typename BufType, typename IntType, std::size_t Count>
176+
inline void read_leb_128_detail(std::istream& stream,
177+
std::array<IntType, Count>& out,
178+
std::uint32_t& bytes_left,
179+
BufType& buf,
180+
std::uint32_t& buf_pos) {
183181

184182
static_assert(std::is_signed_v<IntType>, "Not implemented for unsigned types");
183+
static_assert(sizeof(IntType) <= 4, "Not implemented for types larger than 32 bit");
185184

186-
const std::uint32_t BUF_SIZE = 4096;
187-
std::uint8_t buf[BUF_SIZE];
188-
189-
auto bytes_left = read_little_endian<std::uint32_t>(stream);
190-
191-
std::uint32_t buf_pos = BUF_SIZE;
192-
for (std::size_t i = 0; i < Count; ++i)
185+
IntType result = 0;
186+
size_t shift = 0, i = 0;
187+
while (i < Count)
193188
{
194-
IntType result = 0;
195-
size_t shift = 0;
196-
do
189+
if (buf_pos == buf.size())
197190
{
198-
if (buf_pos == BUF_SIZE)
199-
{
200-
stream.read(reinterpret_cast<char*>(buf), std::min(bytes_left, BUF_SIZE));
201-
buf_pos = 0;
202-
}
191+
stream.read(reinterpret_cast<char*>(buf.data()),
192+
std::min(std::size_t(bytes_left), buf.size()));
193+
buf_pos = 0;
194+
}
203195

204-
std::uint8_t byte = buf[buf_pos++];
205-
--bytes_left;
206-
result |= (byte & 0x7f) << shift;
207-
shift += 7;
196+
std::uint8_t byte = buf[buf_pos++];
197+
--bytes_left;
198+
result |= (byte & 0x7f) << (shift % 32);
199+
shift += 7;
208200

209-
if ((byte & 0x80) == 0)
210-
{
211-
out[i] = (sizeof(IntType) * 8 <= shift || (byte & 0x40) == 0)
212-
? result
213-
: result | ~((1 << shift) - 1);
214-
break;
215-
}
216-
} while (shift < sizeof(IntType) * 8);
201+
if ((byte & 0x80) == 0)
202+
{
203+
out[i++] = (shift >= 32 || (byte & 0x40) == 0) ? result : result | ~((1 << shift) - 1);
204+
result = 0;
205+
shift = 0;
206+
}
217207
}
208+
}
209+
210+
template<typename... Arrays>
211+
inline void read_leb_128(std::istream& stream, Arrays&... outs) {
212+
// Check the presence of our LEB128 magic string
213+
char leb128MagicString[Leb128MagicStringSize];
214+
stream.read(leb128MagicString, Leb128MagicStringSize);
215+
assert(strncmp(Leb128MagicString, leb128MagicString, Leb128MagicStringSize) == 0);
216+
217+
auto bytes_left = read_little_endian<std::uint32_t>(stream);
218+
std::array<std::uint8_t, 8192> buf;
219+
std::uint32_t buf_pos = buf.size();
220+
221+
(read_leb_128_detail(stream, outs, bytes_left, buf, buf_pos), ...);
218222

219223
assert(bytes_left == 0);
220224
}

src/nnue/nnue_feature_transformer.h

Lines changed: 5 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -152,35 +152,21 @@ class FeatureTransformer {
152152
}
153153

154154
// Read network parameters
155-
// TODO: This is ugly. Currently LEB128 on the entire L1 necessitates
156-
// reading the weights into a combined array, and then splitting.
157155
bool read_parameters(std::istream& stream) {
158-
read_leb_128<BiasType>(stream, biases);
156+
read_leb_128(stream, biases);
159157

160158
if (UseThreats)
161159
{
162160
read_little_endian<ThreatWeightType>(stream, threatWeights.data(),
163161
ThreatInputDimensions * HalfDimensions);
164-
read_leb_128<WeightType>(stream, weights);
162+
read_leb_128(stream, weights);
165163

166-
auto combinedPsqtWeights =
167-
std::make_unique<std::array<PSQTWeightType, TotalInputDimensions * PSQTBuckets>>();
168-
169-
read_leb_128<PSQTWeightType>(stream, *combinedPsqtWeights);
170-
171-
std::copy(combinedPsqtWeights->begin(),
172-
combinedPsqtWeights->begin() + ThreatInputDimensions * PSQTBuckets,
173-
std::begin(threatPsqtWeights));
174-
175-
std::copy(combinedPsqtWeights->begin() + ThreatInputDimensions * PSQTBuckets,
176-
combinedPsqtWeights->begin()
177-
+ (ThreatInputDimensions + InputDimensions) * PSQTBuckets,
178-
std::begin(psqtWeights));
164+
read_leb_128(stream, threatPsqtWeights, psqtWeights);
179165
}
180166
else
181167
{
182-
read_leb_128<WeightType>(stream, weights);
183-
read_leb_128<PSQTWeightType>(stream, psqtWeights);
168+
read_leb_128(stream, weights);
169+
read_leb_128(stream, psqtWeights);
184170
}
185171

186172
permute_weights();

0 commit comments

Comments
 (0)