Skip to content

Commit aa109f7

Browse files
🦈 IMP: Avoid Duplicating Neural Network Weights and Biases
Bench: 4929964
1 parent 3379740 commit aa109f7

File tree

3 files changed

+42
-19
lines changed

3 files changed

+42
-19
lines changed

‎CMakeLists.txt‎

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ include(${CPM_DOWNLOAD_LOCATION})
8484
CPMAddPackage(
8585
NAME MantaRay
8686
GITHUB_REPOSITORY TheBlackPlague/MantaRay
87-
GIT_TAG 1c547d179c7cdf4f24e4f8a1fea228c6178d0ae2
87+
GIT_TAG ddc063e1fa688ca3bc9de793e4d5a6813d901289
8888
OPTIONS
8989
"BUILD_TEST OFF"
9090
"BUILD_MB OFF"

‎src/Engine/Evaluation.h‎

Lines changed: 32 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,13 @@ namespace StockDory
2121
class Evaluation
2222
{
2323

24-
static inline std::vector<Aurora> NNs;
24+
static inline Aurora NN = [] -> Aurora
25+
{
26+
MantaRay::BinaryMemoryStream stream (_NeuralNetworkBinaryData, sizeof _NeuralNetworkBinaryData);
27+
return Aurora(stream);
28+
}();
29+
30+
static inline std::vector<AuroraStack> ThreadLocalStack;
2531

2632
public:
2733
static std::string Name()
@@ -31,58 +37,70 @@ namespace StockDory
3137

3238
static void Initialize()
3339
{
34-
const size_t threadCount = ThreadPool.Size() + 1;
40+
const size_t threadCount = ThreadPool.Size();
3541

36-
NNs.clear();
37-
NNs.reserve(threadCount);
42+
ThreadLocalStack.clear();
43+
ThreadLocalStack.reserve(threadCount);
3844

3945
for (size_t i = 0; i < threadCount; i++) {
40-
MantaRay::BinaryMemoryStream stream (_NeuralNetworkBinaryData, sizeof _NeuralNetworkBinaryData);
41-
NNs.emplace_back(stream);
46+
ThreadLocalStack.emplace_back();
47+
NN.Refresh(*ThreadLocalStack[i]);
4248
}
4349
}
4450

4551
static void ResetNetworkState(const size_t threadId = 0)
4652
{
47-
NNs[threadId].Reset();
48-
NNs[threadId].Refresh();
53+
AuroraStack& stack = ThreadLocalStack[threadId];
54+
55+
stack.Reset();
56+
NN.Refresh(*stack);
4957
}
5058

5159
[[clang::always_inline]]
5260
static void PreMove(const size_t threadId = 0)
5361
{
54-
NNs[threadId].Push();
62+
AuroraStack& stack = ThreadLocalStack[threadId];
63+
stack++;
5564
}
5665

5766
[[clang::always_inline]]
5867
static void PreUndoMove(const size_t threadId = 0)
5968
{
60-
NNs[threadId].Pop();
69+
AuroraStack& stack = ThreadLocalStack[threadId];
70+
stack--;
6171
}
6272

6373
[[clang::always_inline]]
6474
static void Activate(const Piece piece, const Color color, const Square sq, const size_t threadId = 0)
6575
{
66-
NNs[threadId].Insert(piece, color, sq);
76+
AuroraStack& stack = ThreadLocalStack[threadId];
77+
78+
NN.Insert(piece, color, sq, *stack);
6779
}
6880

6981
[[clang::always_inline]]
7082
static void Deactivate(const Piece piece, const Color color, const Square sq, const size_t threadId = 0)
7183
{
72-
NNs[threadId].Remove(piece, color, sq);
84+
AuroraStack& stack = ThreadLocalStack[threadId];
85+
86+
NN.Remove(piece, color, sq, *stack);
7387
}
7488

7589
[[clang::always_inline]]
7690
static void Transition(const Piece piece, const Color color, const Square from, const Square to,
7791
const size_t threadId = 0)
7892
{
79-
NNs[threadId].Move(piece, color, from, to);
93+
AuroraStack& stack = ThreadLocalStack[threadId];
94+
95+
NN.Move(piece, color, from, to, *stack);
8096
}
8197

8298
[[clang::always_inline]]
8399
static Score Evaluate(const Color color, const size_t threadId = 0)
84100
{
85-
return NNs[threadId].Evaluate(color);
101+
AuroraStack& stack = ThreadLocalStack[threadId];
102+
103+
return NN.Evaluate(color, *stack);
86104
}
87105

88106
};

‎src/Engine/NetworkArchitecture.h‎

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,18 +8,23 @@
88

99
#include <MantaRay/Backend/Kernel/Activation/ClippedReLU.h>
1010
#include <MantaRay/Frontend/Architecture/Perspective.h>
11+
#include <MantaRay/Frontend/Architecture/Common/AccumulatorStack.h>
1112

1213
// Activation Function:
1314
constexpr auto ClippedReLU = &MantaRay::ClippedReLU<MantaRay::i16, 0, 255>::Activate;
1415

15-
constexpr size_t AccumulatorStackSize = StockDory::MaxDepth * 4;
16-
1716
// Architecture:
1817
using Starshard = MantaRay::Perspective<
19-
MantaRay::i16, MantaRay::i32, ClippedReLU, 768, 256, 1, AccumulatorStackSize, 400, 255, 64
18+
MantaRay::i16, MantaRay::i32, ClippedReLU, 768, 256, 1, 400, 255, 64
2019
>;
2120
using Aurora = MantaRay::Perspective<
22-
MantaRay::i16, MantaRay::i32, ClippedReLU, 768, 384, 1, AccumulatorStackSize, 400, 255, 64
21+
MantaRay::i16, MantaRay::i32, ClippedReLU, 768, 384, 1, 400, 255, 64
2322
>;
2423

24+
// Accumulator Stack:
25+
constexpr size_t AccumulatorStackSize = StockDory::MaxDepth * 4;
26+
27+
using StarshardStack = MantaRay::AccumulatorStack<MantaRay::i16, 256, AccumulatorStackSize>;
28+
using AuroraStack = MantaRay::AccumulatorStack<MantaRay::i16, 384, AccumulatorStackSize>;
29+
2530
#endif //STOCKDORY_NETWORKARCHITECTURE_H

0 commit comments

Comments
 (0)