-
Notifications
You must be signed in to change notification settings - Fork 677
Expand file tree
/
Copy pathnneval.h
More file actions
277 lines (234 loc) · 8.66 KB
/
nneval.h
File metadata and controls
277 lines (234 loc) · 8.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
#ifndef NEURALNET_NNEVAL_H_
#define NEURALNET_NNEVAL_H_
#include <memory>
#include "../core/global.h"
#include "../core/commontypes.h"
#include "../core/logger.h"
#include "../core/multithread.h"
#include "../core/threadsafequeue.h"
#include "../game/board.h"
#include "../game/boardhistory.h"
#include "../neuralnet/nninputs.h"
#include "../neuralnet/sgfmetadata.h"
#include "../neuralnet/nninterface.h"
#include "../search/mutexpool.h"
class NNEvaluator;

// Table of cached neural net outputs, looked up by Hash128.
// The slot array and the MutexPool are held through raw pointers; their allocation and release
// presumably happen in the out-of-line constructor/destructor. NOTE(review): confirm in nneval.cpp.
class NNCacheTable {
 private:
  // A single slot of the table, holding a (possibly empty) shared NNOutput.
  struct Entry {
    std::shared_ptr<NNOutput> ptr;
    Entry();
    ~Entry();
  };

  Entry* entries;         // Slot array
  MutexPool* mutexPool;   // Locks used to guard access to the slots
  uint64_t tableSize;     // Number of slots (presumably 2^sizePowerOfTwo)
  uint64_t tableMask;     // Presumably tableSize-1, for mapping a hash to a slot
  uint32_t mutexPoolMask; // Presumably (mutex pool size)-1, for mapping a slot to a mutex

 public:
  // sizePowerOfTwo / mutexPoolSizePowerOfTwo: log2 of the table size and mutex pool size (per their names).
  NNCacheTable(int sizePowerOfTwo, int mutexPoolSizePowerOfTwo);
  ~NNCacheTable();

  // Copying is disabled; the table refers to its storage through raw pointers.
  NNCacheTable(const NNCacheTable& other) = delete;
  NNCacheTable& operator=(const NNCacheTable& other) = delete;

  // get, set and clear are thread-safe.
  // get: looks up nnHash. On a miss, ret is set to nullptr; the bool result presumably reports hit/miss.
  bool get(Hash128 nnHash, std::shared_ptr<NNOutput>& ret);
  // set: stores p. There is no key parameter, so the key is presumably read from p itself.
  void set(const std::shared_ptr<NNOutput>& p);
  // clear: empties the table.
  void clear();
};
// Per-client request/result buffer. Every thread that requests evaluations should allocate one
// of these once and re-use it for all of its evals.
// Copying is deleted (it also owns a mutex and a condition variable, which are non-copyable).
struct NNResultBuf {
  // --- Hand-off between the requesting thread and the server thread.
  // Presumably the client waits on clientWaitingForResult under resultMutex until hasResult is set.
  std::condition_variable clientWaitingForResult;
  std::mutex resultMutex;
  bool hasResult;

  // --- Per-request settings consumed by the server (judging by the names).
  bool includeOwnerMap;
  int boardXSizeForServer;
  int boardYSizeForServer;

  // --- Encoded input rows: spatial features, global features and optional metadata features.
  std::vector<float> rowSpatialBuf;
  std::vector<float> rowGlobalBuf;
  std::vector<float> rowMetaBuf;
  bool hasRowMeta; // Presumably whether rowMetaBuf is populated for this request — confirm.

  // --- The evaluation output.
  std::shared_ptr<NNOutput> result;

  bool errorLogLockout;  // Set once an error has been logged so only one error is logged, avoiding log spam
  int symmetry;          // Symmetry applied for this eval
  double policyOptimism; // Policy optimism applied for this eval

  NNResultBuf();
  ~NNResultBuf();
  NNResultBuf(const NNResultBuf& other) = delete;
  NNResultBuf& operator=(const NNResultBuf& other) = delete;
};
// Per-server-thread scratch state. Each server thread should allocate one of these once and
// re-use it across batches.
struct NNServerBuf {
  // Backend input buffers, held by raw pointer; allocation and release presumably happen in the
  // out-of-line constructor/destructor. NOTE(review): confirm in nneval.cpp.
  InputBuffers* inputBuffers;

  // Sets up buffers for the given evaluator and loaded model.
  NNServerBuf(const NNEvaluator& nneval, const LoadedModel* model);
  ~NNServerBuf();

  // Copying is disabled.
  NNServerBuf(const NNServerBuf& other) = delete;
  NNServerBuf& operator=(const NNServerBuf& other) = delete;
};
class ONNXModelHeader;

// Batched neural net evaluator with an output cache.
//
// Client threads call evaluate(), which queues their NNResultBuf (see queryQueue) and waits for the
// result. Server threads are started by spawnServerThreads() and stopped by killServerThreads(); each
// runs serve(), which presumably pulls queued requests and evaluates them in batches.
// NOTE(review): batching details live in the implementation file; confirm there.
class NNEvaluator {
 public:
  // Constructs the evaluator. Server threads are NOT started here; call spawnServerThreads().
  // Parameter notes (hedged where the header alone does not establish the meaning):
  //  - modelName / modelFileName: returned by getModelName() / getModelFileName().
  //  - expectedSha256: presumably a checksum the model file is verified against — confirm.
  //  - maxBatchSize: upper bound for the modifiable current batch size (see setCurrentBatchSize).
  //  - nnXLen / nnYLen / requireExactNNLen: the board dimensions the net is run at; requireExactNNLen
  //    presumably requires evaluated boards to match them exactly — confirm.
  //  - nnCacheSizePowerOfTwo / nnMutexPoolSizePowerOfTwo: presumably log2 sizes for the NNCacheTable.
  //  - debugSkipNeuralNet: presumably skips the real net (see isNeuralNetLess) — confirm.
  //  - numThreads / gpuIdxByServerThread: server thread count and the GPU index for each thread.
  //  - randSeed / doRandomize / defaultSymmetry: symmetry selection; see spawnServerThreads().
  NNEvaluator(
    const std::string& modelName,
    const std::string& modelFileName,
    const std::string& expectedSha256,
    Logger* logger,
    int maxBatchSize,
    int nnXLen,
    int nnYLen,
    bool requireExactNNLen,
    bool inputsUseNHWC,
    int nnCacheSizePowerOfTwo,
    int nnMutexPoolSizePowerOfTwo,
    bool debugSkipNeuralNet,
    const std::string& openCLTunerFile,
    const std::string& homeDataDirOverride,
    bool openCLReTunePerBoardSize,
    enabled_t useFP16Mode,
    enabled_t useNHWCMode,
    int numThreads,
    const std::vector<int>& gpuIdxByServerThread,
    const std::string& randSeed,
    bool doRandomize,
    int defaultSymmetry
  );
  ~NNEvaluator();

  NNEvaluator(const NNEvaluator& other) = delete;
  NNEvaluator& operator=(const NNEvaluator& other) = delete;

  // Model identity and configuration accessors.
  std::string getModelName() const;
  std::string getModelFileName() const;
  std::string getInternalModelName() const;
  std::string getAbbrevInternalModelName() const;
  Logger* getLogger();
  bool isNeuralNetLess() const;
  int getMaxBatchSize() const;
  int getCurrentBatchSize() const;
  void setCurrentBatchSize(int batchSize);
  bool requiresSGFMetadata() const;
  int getNumGpus() const;
  int getNumServerThreads() const;
  std::set<int> getGpuIdxs() const;
  int getNNXLen() const;
  int getNNYLen() const;
  int getModelVersion() const;
  double getTrunkSpatialConvDepth() const;
  enabled_t getUsingFP16Mode() const;
  enabled_t getUsingNHWCMode() const;

  // Check if the loaded neural net supports shorttermError fields.
  bool supportsShorttermError() const;

  // Return the "nearest" ruleset to desiredRules that this model supports.
  // Sets supported to true if desiredRules itself was exactly supported, false if some
  // modifications had to be made.
  Rules getSupportedRules(const Rules& desiredRules, bool& supported);

  // Clear all entries cached in the table.
  void clearCache();

  // Queue a position for the next neural net batch evaluation and wait for it.
  // Upon evaluation, the result is supplied in buf; the shared_ptr there (buf.result) can be
  // grabbed via std::move if desired.
  // skipCache: presumably bypasses the output cache for this eval — confirm.
  // includeOwnerMap: whether ownership output is requested (stored in buf.includeOwnerMap, presumably).
  // This function is threadsafe.
  void evaluate(
    Board& board,
    const BoardHistory& history,
    Player nextPlayer,
    const MiscNNInputParams& nnInputParams,
    NNResultBuf& buf,
    bool skipCache,
    bool includeOwnerMap
  );
  // Same as above, additionally passing SGF metadata for models that use it (see requiresSGFMetadata()).
  // NOTE(review): whether sgfMeta may be nullptr is not visible here — confirm.
  void evaluate(
    Board& board,
    const BoardHistory& history,
    Player nextPlayer,
    const SGFMetadata* sgfMeta,
    const MiscNNInputParams& nnInputParams,
    NNResultBuf& buf,
    bool skipCache,
    bool includeOwnerMap
  );

  // Judging by its name and parameters: evaluates the position under up to numSymmetriesToSample
  // symmetries (chosen using rand) and averages the outputs.
  // NOTE(review): confirm the semantics and who owns the returned pointer in the implementation.
  std::shared_ptr<NNOutput>* averageMultipleSymmetries(
    Board& board,
    const BoardHistory& history,
    Player nextPlayer,
    const SGFMetadata* sgfMeta,
    const MiscNNInputParams& baseNNInputParams,
    NNResultBuf& buf,
    bool includeOwnerMap,
    Rand& rand,
    int numSymmetriesToSample
  );

  // If there is at least one evaluate ongoing, wait until at least one finishes.
  // Returns immediately if there isn't one ongoing right now.
  void waitForNextNNEvalIfAny();

  // Actually spawn threads to handle evaluations.
  // If doRandomize, uses randSeed as a seed, further randomized per-thread.
  // If not doRandomize, uses defaultSymmetry for all nn evaluations, unless a symmetry is requested
  // in MiscNNInputParams.
  // This function itself is not threadsafe.
  void spawnServerThreads();

  // Kill spawned server threads, then join and free them.
  // Not threadsafe: calls to this function and to spawnServerThreads must be made from a single
  // thread, never concurrently with each other.
  void killServerThreads();

  // Set the number of server threads and which gpu each uses.
  // Only call this if threads are not spawned yet, or have been killed.
  void setNumThreads(const std::vector<int>& gpuIdxByServerThr);

  // After spawnServerThreads has returned, check whether any server thread is using FP16.
  bool isAnyThreadUsingFP16() const;

  // These are thread-safe. Setting them in the middle of operation might only affect future
  // neural net evals, rather than any in-flight.
  bool getDoRandomize() const;
  int getDefaultSymmetry() const;
  void setDoRandomize(bool b);
  void setDefaultSymmetry(int s);

  // Some stats.
  uint64_t numRowsProcessed() const;
  uint64_t numBatchesProcessed() const;
  double averageProcessedBatchSize() const;
  void clearStats();

 private:
  // NOTE: the declaration order of these members fixes their initialization order in the
  // constructor; do not reorder.
  const std::string modelName;
  const std::string modelFileName;
  int nnXLen;
  int nnYLen;
  bool requireExactNNLen;
  int policySize;
  const bool inputsUseNHWC;
  const enabled_t usingFP16Mode;
  const enabled_t usingNHWCMode;
  int numThreads;
  std::vector<int> gpuIdxByServerThread;
  const std::string randSeed;
  const bool debugSkipNeuralNet;

  ComputeContext* computeContext;
  LoadedModel* loadedModel;
  NNCacheTable* nnCacheTable;
  Logger* logger;

  std::string internalModelName;
  int modelVersion;
  int inputsVersion;
  int numInputMetaChannels;
  ModelPostProcessParams postProcessParams;

  int numServerThreadsEverSpawned;
  std::vector<std::thread*> serverThreads;

  const int maxBatchSize;

  // Counters for statistics.
  std::atomic<uint64_t> m_numRowsProcessed;
  std::atomic<uint64_t> m_numBatchesProcessed;

  mutable std::mutex bufferMutex;

  // Everything in this section is protected under bufferMutex --------------------------------------
  bool isKilled;                                       // Flag used for killing server threads
  int numServerThreadsStartingUp;                      // Counter for waiting until server threads are spawned
  std::condition_variable mainThreadWaitingForSpawn;   // Condvar for waiting until server threads are spawned
  std::vector<int> serverThreadsIsUsingFP16;
  int numOngoingEvals;                                 // Current number of ongoing evals.
  int numWaitingEvals;                                 // Current number of things waiting for finish.
  int numEvalsToAwaken;                                // Current number of things in waitingForFinish that should be woken up. Used to avoid spurious wakeups.
  std::condition_variable waitingForFinish;            // Condvar for waiting for at least one ongoing eval to finish.
  // --------------------------------------------------------------------------------------------------

  // Randomization settings for symmetries.
  std::atomic<bool> currentDoRandomize;
  std::atomic<int> currentDefaultSymmetry;
  // Modifiable batch size, smaller than maxBatchSize.
  std::atomic<int> currentBatchSize;

  // Queued up requests.
  ThreadSafeQueue<NNResultBuf*> queryQueue;

  friend class ONNXModelHeader;

 public:
  // Helper, for internal use only (run by each server thread).
  void serve(NNServerBuf& buf, Rand& rand, int gpuIdxForThisThread, int serverThreadIdx);
};
#endif // NEURALNET_NNEVAL_H_