1818#include < array>
1919#include < memory>
2020#include < string>
21+ #include < utility>
2122#include < vector>
2223
2324#include " open_spiel/abseil-cpp/absl/strings/str_cat.h"
@@ -52,12 +53,13 @@ const GameType kGameType{
5253 /* provides_observation_string=*/ true ,
5354 /* provides_observation_tensor=*/ true ,
5455 /* parameter_specification=*/
55- {{" size" , GameParameter (kDefaultSize )},
56- {" dims" , GameParameter (kDefaultDims )},
57- {" connect" , GameParameter (kDefaultConnect )},
58- {" anti" , GameParameter (kDefaultAnti )},
59- {" wrap" , GameParameter (kDefaultWrap )}
60- }
56+ {
57+ {" size" , GameParameter (kDefaultSize )},
58+ {" dims" , GameParameter (kDefaultDims )},
59+ {" connect" , GameParameter (kDefaultConnect )},
60+ {" anti" , GameParameter (kDefaultAnti )},
61+ {" wrap" , GameParameter (kDefaultWrap )}
62+ }
6163};
6264
6365std::shared_ptr<const Game> Factory (const GameParameters& params) {
@@ -83,41 +85,39 @@ int StoneToInt(Stone s) {
8385 case Stone::kBlack : return 0 ;
8486 case Stone::kWhite : return 1 ;
8587 }
86- SpielFatalError (" Unknown stone." );
87- return 0 ; // never happens
88-
88+ SpielFatalError (" Unknown stone." );
89+ return 0 ; // This never happens
8990}
9091
91-
9292GomokuGame::GomokuGame (const GameParameters& params)
9393 : Game(kGameType , params),
9494 size_ (ParameterValue<int >(" size" )),
9595 dims_(ParameterValue<int >(" dims" )),
9696 connect_(ParameterValue<int >(" connect" )),
9797 anti_(ParameterValue<bool >(" anti" )),
9898 wrap_(ParameterValue<bool >(" wrap" )) {
99- total_size_ = 1 ;
100- for (int i = 0 ; i < dims_; ++i) {
101- total_size_ *= size_;
102- }
103- strides_.resize (dims_);
99+ total_size_ = 1 ;
100+ for (int i = 0 ; i < dims_; ++i) {
101+ total_size_ *= size_;
102+ }
103+ strides_.resize (dims_);
104104 std::size_t stride = 1 ;
105105 for (int d = dims_ - 1 ; d >= 0 ; --d) {
106106 strides_[d] = stride;
107107 stride *= size_;
108108 }
109- absl::BitGen gen (absl::SeedSeq{ 52616 } );
109+ std::mt19937_64 gen (52616 );
110110 zobrist_table_.resize (total_size_);
111111 for (int i = 0 ; i < total_size_; ++i) {
112- zobrist_table_[i][0 ] = absl::Uniform< uint64_t >( gen);
113- zobrist_table_[i][1 ] = absl::Uniform< uint64_t >( gen);
112+ zobrist_table_[i][0 ] = gen ( );
113+ zobrist_table_[i][1 ] = gen ( );
114114 }
115- player_to_move_hash_ = absl::Uniform< uint64_t >( gen);
115+ player_to_move_hash_ = gen ( );
116116}
117117
118118
119119std::vector<int > GomokuGame::UnflattenAction (Action action_id) const {
120- SPIEL_CHECK_LT (action_id, NumDistinctActions ());
120+ SPIEL_CHECK_LT (action_id, NumDistinctActions ());
121121 std::vector<int > coord (dims_);
122122 std::size_t index = action_id;
123123
@@ -160,18 +160,18 @@ uint64_t GomokuState::HashValue() const {
160160}
161161
162162void GomokuState::DoApplyAction (Action move) {
163- SPIEL_CHECK_EQ (board_.AtIndex (move), Stone::kEmpty );
163+ SPIEL_CHECK_EQ (board_.AtIndex (move), Stone::kEmpty );
164164 board_.AtIndex (move) = current_player_ == kBlackPlayer
165165 ? Stone::kBlack
166166 : Stone::kWhite ;
167- const auto * gomoku = static_cast <const GomokuGame*>(game_.get ());
168- zobrist_hash_ ^=
169- gomoku->ZobristTable ()[move][current_player_];
170- zobrist_hash_ ^=
171- gomoku->PlayerToMoveHash ();
167+ const auto * gomoku = static_cast <const GomokuGame*>(game_.get ());
168+ zobrist_hash_ ^=
169+ gomoku->ZobristTable ()[move][current_player_];
170+ zobrist_hash_ ^=
171+ gomoku->PlayerToMoveHash ();
172172 current_player_ = 1 - current_player_;
173173 move_count_ += 1 ;
174- CheckWinFromLastMove (move);
174+ CheckWinFromLastMove (move);
175175}
176176
177177void GomokuState::CheckWinFromLastMove (Action last_move) {
@@ -190,8 +190,9 @@ void GomokuState::CheckWinFromLastMove(Action last_move) {
190190
191191 // forward direction
192192 {
193- auto c = start;
194- while (count < connect_ && board_.Step (c, dir) &&
193+ Coord c = start;
194+ while (static_cast <int >(line.size ()) < connect_ &&
195+ board_.Step (c, dir) &&
195196 board_.At (c) == stone) {
196197 ++count;
197198 }
@@ -202,32 +203,45 @@ void GomokuState::CheckWinFromLastMove(Action last_move) {
202203 auto neg_dir = dir;
203204 for (int & v : neg_dir) v = -v;
204205
205- auto c = start;
206- while (count < connect_ && board_.Step (c, neg_dir) &&
206+ Coord c = start;
207+ while (static_cast <int >(line.size ()) < connect_ &&
208+ board_.Step (c, neg_dir) &&
207209 board_.At (c) == stone) {
208210 ++count;
209211 }
210212 }
211- // currrent player just moved
212- if (count >= connect_) {
213- terminal_ = true ;
214- if (current_player_ == 0 ){
215- black_score_ = 1.0 ;
216- white_score_ = -1.0 ;
217- } else {
218- black_score_ = -1.0 ;
219- white_score_ = 1.0 ;
220- }
221- if (gomoku->Anti ()){
222- black_score_ *= -1 ;
223- white_score_ *= -1 ;
224- }
225- return ;
213+ if (static_cast <int >(line.size ()) >= connect_) {
214+ line.resize (connect_); // arbitrary truncation is fine
215+ return line;
226216 }
227217 }
228- if (move_count_ + initial_stones_ == board_.NumCells ()) {
229- terminal_ = true ;
230- }
218+
219+ return absl::nullopt ;
220+ }
221+
222+ void GomokuState::CheckWinFromLastMove (Action last_move) {
223+ auto maybe_line = FindWinLineFromLastMove (last_move);
224+ if (maybe_line.has_value ()) {
225+ terminal_ = true ;
226+ winning_line_ = *maybe_line;
227+ // current_player_ flips before this is called.
228+ if (current_player_ == 0 ) {
229+ black_score_ = -1.0 ;
230+ white_score_ = 1.0 ;
231+ } else {
232+ black_score_ = 1.0 ;
233+ white_score_ = -1.0 ;
234+ }
235+ }
236+ const auto * gomoku = static_cast <const GomokuGame*>(game_.get ());
237+ if (gomoku->Anti ()) {
238+ black_score_ *= -1 ;
239+ white_score_ *= -1 ;
240+ }
241+ }
242+
243+ const std::vector<Grid<Stone>::Coord>& GomokuState::WinningLine () const {
244+ return winning_line_;
231245}
232246
233247std::vector<Action> GomokuState::LegalActions () const {
@@ -270,7 +284,7 @@ std::string GomokuState::ToString() const {
270284}
271285
272286std::vector<double > GomokuState::Returns () const {
273- return {black_score_, white_score_};
287+ return {black_score_, white_score_};
274288}
275289
276290
@@ -287,12 +301,12 @@ std::string GomokuState::ObservationString(Player player) const {
287301}
288302
289303std::vector<int > GomokuGame::ObservationTensorShape () const {
290- return {2 , total_size_};
304+ return {3 , total_size_};
291305}
292306
293307inline absl::optional<Player> StoneOwner (Stone stone) {
294308 if (stone == Stone::kEmpty ) return absl::nullopt ;
295- if (stone == Stone::kBlack ) return static_cast <Player>(0 );
309+ if (stone == Stone::kBlack ) return static_cast <Player>(0 );
296310 return static_cast <Player>(1 );
297311}
298312
@@ -302,27 +316,32 @@ void GomokuState::ObservationTensor(Player player,
302316 SPIEL_CHECK_LT (player, num_players_);
303317
304318 const int num_cells = board_.NumCells ();
305- SPIEL_CHECK_EQ (values.size (), 2 * num_cells);
319+ SPIEL_CHECK_EQ (values.size (), 3 * num_cells);
306320
307321 std::fill (values.begin (), values.end (), 0 .0f );
308322
309- const Player opponent = Opponent (player);
310-
323+ // Plane 0: black stones
324+ // Plane 1: white stones
311325 for (int i = 0 ; i < num_cells; ++i) {
312- const Stone stone = board_.AtIndex (i);
313- const auto owner = StoneOwner (stone);
314- if (owner.has_value ()) {
315- if (*owner == player) {
316- values[i] = 1 .0f ;
317- } else {
318- values[num_cells + i] = 1 .0f ;
319- }
320- }
326+ const Stone stone = board_.AtIndex (i);
327+ if (stone == Stone::kBlack ) {
328+ values[i] = 1 .0f ;
329+ } else if (stone == Stone::kWhite ) {
330+ values[num_cells + i] = 1 .0f ;
331+ }
332+ }
333+
334+ // Plane 2: player to move
335+ // Convention: 1.0 if Black to move, 0.0 if White to move
336+ if (CurrentPlayer () == Player{0 }) {
337+ std::fill (values.begin () + 2 * num_cells,
338+ values.begin () + 3 * num_cells,
339+ 1 .0f );
321340 }
322341}
323342
324343void GomokuState::UndoAction (Player player, Action move) {
325- board_.AtIndex (move) = Stone::kEmpty ;
344+ board_.AtIndex (move) = Stone::kEmpty ;
326345 current_player_ = player;
327346 move_count_ -= 1 ;
328347 history_.pop_back ();
@@ -346,11 +365,11 @@ std::string GomokuGame::ActionToString(Player player,
346365
347366
348367int GomokuGame::MaxGameLength () const {
349- return total_size_;
368+ return total_size_;
350369}
351370
352371int GomokuGame::NumDistinctActions () const {
353- return total_size_;
372+ return total_size_;
354373}
355374
356375GomokuState::GomokuState (std::shared_ptr<const Game> game,
@@ -367,63 +386,53 @@ GomokuState::GomokuState(std::shared_ptr<const Game> game,
367386 current_player_(kBlackPlayer ),
368387 move_count_(0 ),
369388 initial_stones_(0 ),
370- black_score_(0.0 ),
371- white_score_(0.0 ){
389+ black_score_(0.0 ),
390+ white_score_(0.0 ) {
372391 if (state_str.empty ()) {
373392 board_.Fill (Stone::kEmpty );
374393 current_player_ = kBlackPlayer ;
375394 return ;
376395 }
377- const auto * gomoku = static_cast <const GomokuGame*>(game_.get ());
378- const std::size_t expected =
396+ const auto * gomoku = static_cast <const GomokuGame*>(game_.get ());
397+ const std::size_t expected =
379398 1 + board_.NumCells (); // size^dims
380399
381- SPIEL_CHECK_EQ (state_str.size (), expected);
400+ SPIEL_CHECK_EQ (state_str.size (), expected);
382401 switch (state_str[0 ]) {
383402 case ' W' :
384403 current_player_ = kWhitePlayer ;
385- break ;
386- case ' B' :
387- current_player_ = kBlackPlayer ;
388- break ;
404+ break ;
405+ case ' B' :
406+ current_player_ = kBlackPlayer ;
407+ break ;
389408 default :
390409 SpielFatalError (" Invalid player char in state string" );
391410 }
392- for (std::size_t i = 0 ; i < board_.NumCells (); ++i) {
411+ for (std::size_t i = 0 ; i < board_.NumCells (); ++i) {
393412 char c = state_str[i + 1 ];
394413 Stone s;
395414
396415 switch (c) {
397416 case ' b' :
398- s = Stone::kBlack ;
399- initial_stones_++;
400- break ;
417+ s = Stone::kBlack ;
418+ initial_stones_++;
419+ break ;
401420 case ' w' :
402- s = Stone::kWhite ;
403- initial_stones_++;
404- break ;
421+ s = Stone::kWhite ;
422+ initial_stones_++;
423+ break ;
405424 case ' .' :
406- s = Stone::kEmpty ;
407- break ;
425+ s = Stone::kEmpty ;
426+ break ;
408427 case ' ' : s = Stone::kEmpty ;
409- break ;
428+ break ;
410429 default :
411430 SpielFatalError (" Invalid board char in state string" );
412431 }
413432 board_.AtIndex (i) = s;
414433 }
415- for (int i = 0 ; i < board_.NumCells (); ++i) {
416- if (board_.AtIndex (i) != Stone::kEmpty ) {
417- zobrist_hash_ ^=
418- gomoku->ZobristTable ()[i][StoneToInt (board_.AtIndex (i))];
419- }
420- }
421- if (current_player_ == 1 ) {
422- zobrist_hash_ ^= gomoku->PlayerToMoveHash ();
423- }
434+ zobrist_hash_ = ComputeZobrist (board_);
424435}
425436
426-
427-
428437} // namespace gomoku
429438} // namespace open_spiel
0 commit comments