Skip to content

Commit 702b150

Browse files
committed
Merge in changes mistakenly added to master
1 parent 6e15f22 commit 702b150

File tree

7 files changed

+463
-189
lines changed

7 files changed

+463
-189
lines changed

docs/games.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@ Status | Game
4040
🟢 | [First-price Sealed-Bid Auction](https://en.wikipedia.org/wiki/First-price_sealed-bid_auction) | 2-10 | ❌ | ❌ | Agents submit bids simultaneously; highest bid wins, and that's the price paid.
4141
🟢 | [Gin Rummy](https://en.wikipedia.org/wiki/Gin_rummy) | 2 | ❌ | ❌ | Players score points by forming specific sets with the cards in their hands.
4242
🟢 | [Go](https://en.wikipedia.org/wiki/Go_\(game\)) | 2 | ✅ | ✅ | Players place tokens on the board with the goal of encircling territory.
43+
🔶 | [Gomoku](https://en.wikipedia.org/wiki/Gomoku) | 2 | ✅ | ✅ | Try to get 5 stones in a row.
4344
🟢 | [Goofspiel](https://en.wikipedia.org/wiki/Goofspiel) | 2-10 | ❌ | ❌ | Players bid with their cards to win other cards.
4445
🟢 | [Hanabi](https://en.wikipedia.org/wiki/Hanabi_\(card_game\)) | 2-5 | ❌ | ❌ | Players can see only other players' pieces, and everyone must cooperate to win. References: [Bard et al. '19, The Hanabi Challenge: A New Frontier for AI Research](https://arxiv.org/abs/1902.00506). Implemented via [Hanabi Learning Environment](https://github.com/deepmind/hanabi-learning-environment).
4546
🟢 | [Havannah](https://en.wikipedia.org/wiki/Havannah_\(board_game\)) | 2 | ✅ | ✅ | Players add tokens to a hex grid to try and form a winning structure.

open_spiel/games/gomoku/gomoku.cc

Lines changed: 106 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <array>
1919
#include <memory>
2020
#include <string>
21+
#include <utility>
2122
#include <vector>
2223

2324
#include "open_spiel/abseil-cpp/absl/strings/str_cat.h"
@@ -52,12 +53,13 @@ const GameType kGameType{
5253
/*provides_observation_string=*/true,
5354
/*provides_observation_tensor=*/true,
5455
/*parameter_specification=*/
55-
{{"size", GameParameter(kDefaultSize)},
56-
{"dims", GameParameter(kDefaultDims)},
57-
{"connect", GameParameter(kDefaultConnect)},
58-
{"anti", GameParameter(kDefaultAnti)},
59-
{"wrap", GameParameter(kDefaultWrap)}
60-
}
56+
{
57+
{"size", GameParameter(kDefaultSize)},
58+
{"dims", GameParameter(kDefaultDims)},
59+
{"connect", GameParameter(kDefaultConnect)},
60+
{"anti", GameParameter(kDefaultAnti)},
61+
{"wrap", GameParameter(kDefaultWrap)}
62+
}
6163
};
6264

6365
std::shared_ptr<const Game> Factory(const GameParameters& params) {
@@ -83,41 +85,39 @@ int StoneToInt(Stone s) {
8385
case Stone::kBlack: return 0;
8486
case Stone::kWhite: return 1;
8587
}
86-
SpielFatalError("Unknown stone.");
87-
return 0; // never happens
88-
88+
SpielFatalError("Unknown stone.");
89+
return 0; // This never happens
8990
}
9091

91-
9292
GomokuGame::GomokuGame(const GameParameters& params)
9393
: Game(kGameType, params),
9494
size_(ParameterValue<int>("size")),
9595
dims_(ParameterValue<int>("dims")),
9696
connect_(ParameterValue<int>("connect")),
9797
anti_(ParameterValue<bool>("anti")),
9898
wrap_(ParameterValue<bool>("wrap")) {
99-
total_size_ = 1;
100-
for (int i = 0; i < dims_; ++i) {
101-
total_size_ *= size_;
102-
}
103-
strides_.resize(dims_);
99+
total_size_ = 1;
100+
for (int i = 0; i < dims_; ++i) {
101+
total_size_ *= size_;
102+
}
103+
strides_.resize(dims_);
104104
std::size_t stride = 1;
105105
for (int d = dims_ - 1; d >= 0; --d) {
106106
strides_[d] = stride;
107107
stride *= size_;
108108
}
109-
absl::BitGen gen(absl::SeedSeq{52616});
109+
std::mt19937_64 gen(52616);
110110
zobrist_table_.resize(total_size_);
111111
for (int i = 0; i < total_size_; ++i) {
112-
zobrist_table_[i][0] = absl::Uniform<uint64_t>(gen);
113-
zobrist_table_[i][1] = absl::Uniform<uint64_t>(gen);
112+
zobrist_table_[i][0] = gen();
113+
zobrist_table_[i][1] = gen();
114114
}
115-
player_to_move_hash_ = absl::Uniform<uint64_t>(gen);
115+
player_to_move_hash_ = gen();
116116
}
117117

118118

119119
std::vector<int> GomokuGame::UnflattenAction(Action action_id) const {
120-
SPIEL_CHECK_LT(action_id, NumDistinctActions());
120+
SPIEL_CHECK_LT(action_id, NumDistinctActions());
121121
std::vector<int> coord(dims_);
122122
std::size_t index = action_id;
123123

@@ -160,18 +160,18 @@ uint64_t GomokuState::HashValue() const {
160160
}
161161

162162
void GomokuState::DoApplyAction(Action move) {
163-
SPIEL_CHECK_EQ(board_.AtIndex(move), Stone::kEmpty);
163+
SPIEL_CHECK_EQ(board_.AtIndex(move), Stone::kEmpty);
164164
board_.AtIndex(move) = current_player_ == kBlackPlayer
165165
? Stone::kBlack
166166
: Stone::kWhite;
167-
const auto* gomoku = static_cast<const GomokuGame*>(game_.get());
168-
zobrist_hash_ ^=
169-
gomoku->ZobristTable()[move][current_player_];
170-
zobrist_hash_ ^=
171-
gomoku->PlayerToMoveHash();
167+
const auto* gomoku = static_cast<const GomokuGame*>(game_.get());
168+
zobrist_hash_ ^=
169+
gomoku->ZobristTable()[move][current_player_];
170+
zobrist_hash_ ^=
171+
gomoku->PlayerToMoveHash();
172172
current_player_ = 1 - current_player_;
173173
move_count_ += 1;
174-
CheckWinFromLastMove(move);
174+
CheckWinFromLastMove(move);
175175
}
176176

177177
void GomokuState::CheckWinFromLastMove(Action last_move) {
@@ -190,8 +190,9 @@ void GomokuState::CheckWinFromLastMove(Action last_move) {
190190

191191
// forward direction
192192
{
193-
auto c = start;
194-
while (count < connect_ && board_.Step(c, dir) &&
193+
Coord c = start;
194+
while (static_cast<int>(line.size()) < connect_ &&
195+
board_.Step(c, dir) &&
195196
board_.At(c) == stone) {
196197
++count;
197198
}
@@ -202,32 +203,45 @@ void GomokuState::CheckWinFromLastMove(Action last_move) {
202203
auto neg_dir = dir;
203204
for (int& v : neg_dir) v = -v;
204205

205-
auto c = start;
206-
while (count < connect_ && board_.Step(c, neg_dir) &&
206+
Coord c = start;
207+
while (static_cast<int>(line.size()) < connect_ &&
208+
board_.Step(c, neg_dir) &&
207209
board_.At(c) == stone) {
208210
++count;
209211
}
210212
}
211-
// current player just moved
212-
if (count >= connect_) {
213-
terminal_ = true;
214-
if (current_player_ == 0){
215-
black_score_ = 1.0;
216-
white_score_ = -1.0;
217-
} else {
218-
black_score_ = -1.0;
219-
white_score_ = 1.0;
220-
}
221-
if (gomoku->Anti()){
222-
black_score_ *= -1;
223-
white_score_ *= -1;
224-
}
225-
return;
213+
if (static_cast<int>(line.size()) >= connect_) {
214+
line.resize(connect_); // arbitrary truncation is fine
215+
return line;
226216
}
227217
}
228-
if (move_count_ + initial_stones_ == board_.NumCells()) {
229-
terminal_ = true;
230-
}
218+
219+
return absl::nullopt;
220+
}
221+
222+
void GomokuState::CheckWinFromLastMove(Action last_move) {
223+
auto maybe_line = FindWinLineFromLastMove(last_move);
224+
if (maybe_line.has_value()) {
225+
terminal_ = true;
226+
winning_line_ = *maybe_line;
227+
// current_player_ flips before this is called.
228+
if (current_player_ == 0) {
229+
black_score_ = -1.0;
230+
white_score_ = 1.0;
231+
} else {
232+
black_score_ = 1.0;
233+
white_score_ = -1.0;
234+
}
235+
}
236+
const auto* gomoku = static_cast<const GomokuGame*>(game_.get());
237+
if (gomoku->Anti()) {
238+
black_score_ *= -1;
239+
white_score_ *= -1;
240+
}
241+
}
242+
243+
const std::vector<Grid<Stone>::Coord>& GomokuState::WinningLine() const {
244+
return winning_line_;
231245
}
232246

233247
std::vector<Action> GomokuState::LegalActions() const {
@@ -270,7 +284,7 @@ std::string GomokuState::ToString() const {
270284
}
271285

272286
std::vector<double> GomokuState::Returns() const {
273-
return {black_score_, white_score_};
287+
return {black_score_, white_score_};
274288
}
275289

276290

@@ -287,12 +301,12 @@ std::string GomokuState::ObservationString(Player player) const {
287301
}
288302

289303
std::vector<int> GomokuGame::ObservationTensorShape() const {
290-
return {2, total_size_};
304+
return {3, total_size_};
291305
}
292306

293307
inline absl::optional<Player> StoneOwner(Stone stone) {
294308
if (stone == Stone::kEmpty) return absl::nullopt;
295-
if (stone == Stone::kBlack) return static_cast<Player>(0);
309+
if (stone == Stone::kBlack) return static_cast<Player>(0);
296310
return static_cast<Player>(1);
297311
}
298312

@@ -302,27 +316,32 @@ void GomokuState::ObservationTensor(Player player,
302316
SPIEL_CHECK_LT(player, num_players_);
303317

304318
const int num_cells = board_.NumCells();
305-
SPIEL_CHECK_EQ(values.size(), 2 * num_cells);
319+
SPIEL_CHECK_EQ(values.size(), 3 * num_cells);
306320

307321
std::fill(values.begin(), values.end(), 0.0f);
308322

309-
const Player opponent = Opponent(player);
310-
323+
// Plane 0: black stones
324+
// Plane 1: white stones
311325
for (int i = 0; i < num_cells; ++i) {
312-
const Stone stone = board_.AtIndex(i);
313-
const auto owner = StoneOwner(stone);
314-
if (owner.has_value()) {
315-
if (*owner == player) {
316-
values[i] = 1.0f;
317-
} else {
318-
values[num_cells + i] = 1.0f;
319-
}
320-
}
326+
const Stone stone = board_.AtIndex(i);
327+
if (stone == Stone::kBlack) {
328+
values[i] = 1.0f;
329+
} else if (stone == Stone::kWhite) {
330+
values[num_cells + i] = 1.0f;
331+
}
332+
}
333+
334+
// Plane 2: player to move
335+
// Convention: 1.0 if Black to move, 0.0 if White to move
336+
if (CurrentPlayer() == Player{0}) {
337+
std::fill(values.begin() + 2 * num_cells,
338+
values.begin() + 3 * num_cells,
339+
1.0f);
321340
}
322341
}
323342

324343
void GomokuState::UndoAction(Player player, Action move) {
325-
board_.AtIndex(move) = Stone::kEmpty;
344+
board_.AtIndex(move) = Stone::kEmpty;
326345
current_player_ = player;
327346
move_count_ -= 1;
328347
history_.pop_back();
@@ -346,11 +365,11 @@ std::string GomokuGame::ActionToString(Player player,
346365

347366

348367
int GomokuGame::MaxGameLength() const {
349-
return total_size_;
368+
return total_size_;
350369
}
351370

352371
int GomokuGame::NumDistinctActions() const {
353-
return total_size_;
372+
return total_size_;
354373
}
355374

356375
GomokuState::GomokuState(std::shared_ptr<const Game> game,
@@ -367,63 +386,53 @@ GomokuState::GomokuState(std::shared_ptr<const Game> game,
367386
current_player_(kBlackPlayer),
368387
move_count_(0),
369388
initial_stones_(0),
370-
black_score_(0.0),
371-
white_score_(0.0){
389+
black_score_(0.0),
390+
white_score_(0.0) {
372391
if (state_str.empty()) {
373392
board_.Fill(Stone::kEmpty);
374393
current_player_ = kBlackPlayer;
375394
return;
376395
}
377-
const auto* gomoku = static_cast<const GomokuGame*>(game_.get());
378-
const std::size_t expected =
396+
const auto* gomoku = static_cast<const GomokuGame*>(game_.get());
397+
const std::size_t expected =
379398
1 + board_.NumCells(); // size^dims
380399

381-
SPIEL_CHECK_EQ(state_str.size(), expected);
400+
SPIEL_CHECK_EQ(state_str.size(), expected);
382401
switch (state_str[0]) {
383402
case 'W':
384403
current_player_ = kWhitePlayer;
385-
break;
386-
case 'B':
387-
current_player_ = kBlackPlayer;
388-
break;
404+
break;
405+
case 'B':
406+
current_player_ = kBlackPlayer;
407+
break;
389408
default:
390409
SpielFatalError("Invalid player char in state string");
391410
}
392-
for (std::size_t i = 0; i < board_.NumCells(); ++i) {
411+
for (std::size_t i = 0; i < board_.NumCells(); ++i) {
393412
char c = state_str[i + 1];
394413
Stone s;
395414

396415
switch (c) {
397416
case 'b':
398-
s = Stone::kBlack;
399-
initial_stones_++;
400-
break;
417+
s = Stone::kBlack;
418+
initial_stones_++;
419+
break;
401420
case 'w':
402-
s = Stone::kWhite;
403-
initial_stones_++;
404-
break;
421+
s = Stone::kWhite;
422+
initial_stones_++;
423+
break;
405424
case '.':
406-
s = Stone::kEmpty;
407-
break;
425+
s = Stone::kEmpty;
426+
break;
408427
case ' ': s = Stone::kEmpty;
409-
break;
428+
break;
410429
default:
411430
SpielFatalError("Invalid board char in state string");
412431
}
413432
board_.AtIndex(i) = s;
414433
}
415-
for (int i = 0; i < board_.NumCells(); ++i) {
416-
if (board_.AtIndex(i) != Stone::kEmpty) {
417-
zobrist_hash_ ^=
418-
gomoku->ZobristTable()[i][StoneToInt(board_.AtIndex(i))];
419-
}
420-
}
421-
if (current_player_ == 1) {
422-
zobrist_hash_ ^= gomoku->PlayerToMoveHash();
423-
}
434+
zobrist_hash_ = ComputeZobrist(board_);
424435
}
425436

426-
427-
428437
} // namespace gomoku
429438
} // namespace open_spiel

0 commit comments

Comments
 (0)