Skip to content

Commit 9b8da21

Browse files
committed
Update utils
Bench 7992749
1 parent 8a126e2 commit 9b8da21

File tree

6 files changed

+83
-67
lines changed

6 files changed

+83
-67
lines changed

src/utils/DumpGames.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#include "Common.hpp"
2+
#include "GameCollection.hpp"
3+
4+
#include <filesystem>
5+
#include <fstream>
6+
7+
static bool DumpGames(const std::string& path)
8+
{
9+
std::vector<Move> moves;
10+
11+
if (!std::filesystem::exists(path))
12+
{
13+
return false;
14+
}
15+
16+
FileInputStream gamesFile(path.c_str());
17+
if (!gamesFile.IsOpen())
18+
{
19+
std::cout << "ERROR: Failed to load selfplay data file: " << path << std::endl;
20+
return false;
21+
}
22+
23+
Game game;
24+
while (GameCollection::ReadGame(gamesFile, game, moves))
25+
{
26+
std::cout << game.ToPGN() << std::endl << std::endl;
27+
}
28+
29+
return true;
30+
}
31+
32+
void DumpGames(const std::vector<std::string>& args)
33+
{
34+
for (const auto& path : args)
35+
{
36+
DumpGames(path);
37+
}
38+
}

src/utils/Main.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ extern bool RunPerformanceTests(const std::vector<std::string>& paths);
1212
extern void SelfPlay(const std::vector<std::string>& args);
1313
extern void PrepareTrainingData(const std::vector<std::string>& args);
1414
extern void PlainTextToTrainingData(const std::vector<std::string>& args);
15+
extern void DumpGames(const std::vector<std::string>& args);
1516
extern void GenerateEndgamePositions();
1617
extern bool TestNetwork();
1718
extern bool TrainNetwork();
@@ -65,6 +66,8 @@ int main(int argc, const char* argv[])
6566
PrepareTrainingData(args);
6667
else if (toolName == "plainTextToTrainingData")
6768
PlainTextToTrainingData(args);
69+
else if (toolName == "dumpGames")
70+
DumpGames(args);
6871
else if (toolName == "testNetwork")
6972
TestNetwork();
7073
else if (toolName == "validateEndgame")

src/utils/NetworkTrainer.cpp

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ using namespace threadpool;
3939
static const uint32_t cMaxIterations = 1'000'000'000;
4040
static const uint32_t cNumTrainingVectorsPerIteration = 512 * 1024;
4141
static const uint32_t cNumValidationVectorsPerIteration = 128 * 1024;
42-
static const uint32_t cBatchSize = 64 * 1024;
42+
static const uint32_t cBatchSize = 32 * 1024;
4343
#ifdef USE_VIRTUAL_FEATURES
4444
static const uint32_t cNumVirtualFeatures = 12 * 64;
4545
#endif // USE_VIRTUAL_FEATURES
@@ -443,6 +443,8 @@ void NetworkTrainer::Validate(size_t iteration)
443443
"8/8/8/5B1p/5p1r/4kP2/6K1/8 w - - 0 1", // should be 0
444444
"8/8/8/p7/K5R1/1n6/1k1r4/8 w - - 0 1", // should be 0
445445
"8/8/2k3N1/8/Nn2N3/4K3/8/7n w - - 0 1", // should be 1
446+
"rnbqk1nr/3p1pbp/p1pPp1p1/PpP5/1P6/8/4PPPP/1NBQKBNR w kq - 1 9", // should be 1?
447+
"rn1qkbnr/pbp1p3/1p1pPp1p/5PpP/6P1/8/PPPP4/RNBQKBN1 w Qkq - 1 9", // should be 1?
446448
};
447449

448450
for (const char* testPosition : s_testPositions)
@@ -651,13 +653,13 @@ bool NetworkTrainer::UnpackNetwork()
651653

652654
static volatile float g_learningRateScale = 0.5f;
653655
static volatile float g_lambdaScale = 0.0f;
654-
static volatile float g_weightDecay = 1.0f / 512.0f;
656+
static volatile float g_weightDecay = 0.0f; // 1.0f / 512.0f;
655657

656658
bool NetworkTrainer::Train()
657659
{
658660
InitNetwork();
659661

660-
if (!m_packedNet.LoadFromFile("eval-61.pnn"))
662+
if (!m_packedNet.LoadFromFile("eval-64-38-9.pnn"))
661663
{
662664
std::cout << "ERROR: Failed to load packed network" << std::endl;
663665
return false;
@@ -688,7 +690,7 @@ bool NetworkTrainer::Train()
688690
size_t epoch = 0;
689691
for (size_t iteration = 0; iteration < cMaxIterations; ++iteration)
690692
{
691-
const float warmup = iteration < 20.0f ? (float)(iteration + 1) / 20.0f : 1.0f;
693+
const float warmup = iteration < 50.0f ? (float)(iteration + 1) / 50.0f : 1.0f;
692694
const float learningRate = g_learningRateScale * warmup * std::lerp(minLearningRate, maxLearningRate, expf(-0.0005f * (float)iteration));
693695
const float lambda = g_lambdaScale * std::lerp(minLambda, maxLambda, expf(-0.0005f * (float)iteration));
694696

src/utils/PrepareTrainingData.cpp

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,8 @@ using namespace threadpool;
1919

2020
static std::mutex g_mutex;
2121

22-
static constexpr int32_t c_ScoreTreshold = 1000;
23-
static constexpr int32_t c_EvalTreshold = 500;
22+
static constexpr int32_t c_ScoreTreshold = 1600;
23+
static constexpr int32_t c_EvalTreshold = 800;
2424

2525
static bool IsPositionImbalanced(const Position& pos, ScoreType moveScore)
2626
{
@@ -87,9 +87,8 @@ static bool ConvertGamesToTrainingData(const std::string& inputPath, const std::
8787

8888
if (move.IsQuiet() && // best move must be quiet
8989
pos.GetNumPieces() >= 4 && // skip known endgames
90-
//(i + 1 >= game.GetMoves().size() || moves[i + 1].IsQuiet()) && // next best move must be quiet
91-
!pos.IsInCheck() && // skip check positions
92-
!IsPositionImbalanced(pos, moveScore)) // skip imbalanced positions
90+
!pos.IsInCheck() /* && // skip check positions
91+
!IsPositionImbalanced(pos, moveScore)*/) // skip imbalanced positions
9392
{
9493
PositionEntry entry{};
9594

src/utils/SelfPlay.cpp

Lines changed: 24 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,16 @@
2424

2525
static const bool randomizeOrder = true;
2626
static const uint32_t c_printPgnFrequency = 1;
27-
static const uint32_t c_minNodes = 80000;
28-
static const uint32_t c_maxNodes = 80000;
29-
static const uint32_t c_maxDepth = 32;
30-
static const int32_t c_maxEval = 3000;
31-
static const int32_t c_openingMaxEval = 1000;
32-
static const int32_t c_multiPv = 1;
33-
static const int32_t c_multiPvMaxPly = 0;
34-
static const int32_t c_multiPvScoreTreshold = 50;
35-
static const uint32_t c_minRandomMoves = 2;
36-
static const uint32_t c_maxRandomMoves = 2;
27+
28+
static const uint32_t c_minNodes = 25'000;
29+
static const uint32_t c_maxNodes = 30'000;
30+
static const uint32_t c_maxDepth = 40;
31+
32+
static const int32_t c_maxEval = 2000;
33+
static const int32_t c_openingMaxEval = 300;
34+
35+
static const uint32_t c_minRandomMoves = 6;
36+
static const uint32_t c_maxRandomMoves = 10;
3737

3838
bool LoadOpeningPositions(const std::string& path, std::vector<PackedPosition>& outPositions)
3939
{
@@ -65,8 +65,6 @@ bool LoadOpeningPositions(const std::string& path, std::vector<PackedPosition>&
6565
outPositions.push_back(packedPos);
6666
}
6767

68-
std::cout << "Loaded " << outPositions.size() << " opening positions" << std::endl;
69-
7068
return true;
7169
}
7270

@@ -107,7 +105,7 @@ static bool SelfPlayThreadFunc(
107105
const std::vector<PackedPosition>& openingPositions,
108106
SelfPlayStats& stats)
109107
{
110-
const size_t c_transpositionTableSize = 2ull * 1024ull * 1024ull;
108+
const size_t c_transpositionTableSize = 4ull * 1024ull * 1024ull;
111109

112110
std::random_device rd;
113111
std::mt19937 gen(rd());
@@ -175,7 +173,6 @@ static bool SelfPlayThreadFunc(
175173
search.Clear();
176174
game.Reset(openingPos);
177175

178-
int32_t multiPvScoreTreshold = c_multiPvScoreTreshold;
179176
int32_t halfMoveNumber = 0;
180177
uint32_t drawScoreCounter = 0;
181178
uint32_t whiteWinsCounter = 0;
@@ -190,9 +187,9 @@ static bool SelfPlayThreadFunc(
190187
searchParam.useRootTablebase = false;
191188
searchParam.evalRandomization = 1;
192189
searchParam.seed = searchSeed;
193-
searchParam.numPvLines = (halfMoveNumber < c_multiPvMaxPly) ? c_multiPv : 1;
194190
searchParam.limits.maxDepth = c_maxDepth;
195191
searchParam.limits.maxNodesSoft = c_minNodes + (c_maxNodes - c_minNodes) * std::max(0, 80 - halfMoveNumber) / 80;
192+
if (halfMoveNumber < 10) searchParam.limits.maxNodesSoft *= 2; // more nodes in the first moves
196193
searchParam.limits.maxNodes = 5 * searchParam.limits.maxNodesSoft;
197194

198195
searchResult.clear();
@@ -202,39 +199,13 @@ static bool SelfPlayThreadFunc(
202199
ASSERT(!searchResult.empty());
203200

204201
// skip game if starting position is unbalanced
205-
if (halfMoveNumber == 0 && std::abs(searchResult.begin()->score) > c_openingMaxEval)
202+
if (halfMoveNumber == 0 && std::abs(searchResult.begin()->score) * 100 / wld::NormalizeToPawnValue > c_openingMaxEval)
206203
break;
207204

208-
// sort moves by score
209-
std::sort(searchResult.begin(), searchResult.end(), [](const PvLine& a, const PvLine& b)
210-
{
211-
return a.score > b.score;
212-
});
205+
ASSERT(!searchResult.front().moves.empty());
206+
Move move = searchResult.front().moves.front();
213207

214-
// if one of the move is much worse than the best candidate, ignore it and the rest
215-
for (size_t i = 1; i < searchResult.size(); ++i)
216-
{
217-
ASSERT(searchResult[i].score <= searchResult[0].score);
218-
const int32_t diff = std::abs((int32_t)searchResult[i].score - (int32_t)searchResult[0].score);
219-
if (diff > multiPvScoreTreshold)
220-
{
221-
searchResult.erase(searchResult.begin() + i, searchResult.end());
222-
break;
223-
}
224-
}
225-
226-
// select random move
227-
// TODO prefer moves with higher score
228-
std::uniform_int_distribution<size_t> distrib(0, searchResult.size() - 1);
229-
const size_t moveIndex = distrib(gen);
230-
ASSERT(!searchResult[moveIndex].moves.empty());
231-
Move move = searchResult[moveIndex].moves.front();
232-
233-
// reduce threshold of picking worse move
234-
// this way the game will be more random at the beginning and there will be less blunders later in the game
235-
multiPvScoreTreshold = std::max(10, multiPvScoreTreshold - 2);
236-
237-
ScoreType moveScore = searchResult[moveIndex].score;
208+
ScoreType moveScore = searchResult.front().score;
238209
ScoreType eval = Evaluate(game.GetPosition());
239210

240211
if (game.GetSideToMove() == Black)
@@ -261,12 +232,12 @@ static bool SelfPlayThreadFunc(
261232
}
262233

263234
// adjudicate win
264-
if (halfMoveNumber >= 20)
235+
if (halfMoveNumber >= 40)
265236
{
266237
if (moveScore > c_maxEval && eval > c_maxEval / 4)
267238
{
268239
whiteWinsCounter++;
269-
if (whiteWinsCounter > 3) game.SetScore(Game::Score::WhiteWins);
240+
if (whiteWinsCounter > 4) game.SetScore(Game::Score::WhiteWins);
270241
}
271242
else
272243
{
@@ -276,7 +247,7 @@ static bool SelfPlayThreadFunc(
276247
if (moveScore < -c_maxEval && eval < -c_maxEval / 4)
277248
{
278249
blackWinsCounter++;
279-
if (blackWinsCounter > 3) game.SetScore(Game::Score::BlackWins);
250+
if (blackWinsCounter > 4) game.SetScore(Game::Score::BlackWins);
280251
}
281252
else
282253
{
@@ -311,12 +282,12 @@ static bool SelfPlayThreadFunc(
311282
// save game
312283
if (halfMoveNumber > 0)
313284
{
314-
writer.WriteGame(game);
315-
316285
GameMetadata metadata;
317286
metadata.roundNumber = index;
318287
game.SetMetadata(metadata);
319288

289+
writer.WriteGame(game);
290+
320291
if (threadIndex == 0 && c_printPgnFrequency != 0 && (index % c_printPgnFrequency == 0))
321292
{
322293
const std::string pgn = game.ToPGN(true);
@@ -353,10 +324,11 @@ void SelfPlay(const std::vector<std::string>& args)
353324

354325
std::cout << "Loading opening positions..." << std::endl;
355326
std::vector<PackedPosition> openingPositions;
356-
if (!args.empty())
327+
for (const std::string& path : args)
357328
{
358-
LoadOpeningPositions(args[0], openingPositions);
329+
LoadOpeningPositions(path, openingPositions);
359330
}
331+
std::cout << "Loaded " << openingPositions.size() << " opening positions" << std::endl;
360332

361333
alignas(CACHELINE_SIZE) SelfPlayStats stats;
362334

src/utils/TrainerCommon.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -156,10 +156,9 @@ bool TrainingDataLoader::InputFileContext::FetchNextPosition(std::mt19937& gen,
156156
}
157157
else
158158
{
159-
// skip drawn game based half-move counter
160-
if (outEntry.wdlScore == (uint8_t)Game::Score::Draw)
159+
// skip based on half-move counter
161160
{
162-
const float hmcSkipProb = (float)outEntry.pos.halfMoveCount / 100.0f;
161+
const float hmcSkipProb = sqrtf((float)outEntry.pos.halfMoveCount / 100.0f);
163162
std::bernoulli_distribution skippingDistr(hmcSkipProb);
164163
if (skippingDistr(gen))
165164
continue;
@@ -176,9 +175,12 @@ bool TrainingDataLoader::InputFileContext::FetchNextPosition(std::mt19937& gen,
176175
if (numPieces <= 3)
177176
continue;
178177

179-
//const float pieceCountSkipProb = Sqr(static_cast<float>(numPieces - 28) / 40.0f);
180-
//if (pieceCountSkipProb > 0.0f && std::bernoulli_distribution(pieceCountSkipProb)(gen))
181-
// continue;
178+
if (CheckInsufficientMaterial(outPosition))
179+
continue;
180+
181+
const float pieceCountSkipProb = Sqr(static_cast<float>(numPieces - 26) / 50.0f);
182+
if (pieceCountSkipProb > 0.0f && std::bernoulli_distribution(pieceCountSkipProb)(gen))
183+
continue;
182184
}
183185
}
184186

0 commit comments

Comments
 (0)