@@ -174,42 +174,44 @@ void GomokuState::DoApplyAction(Action move) {
174174 CheckWinFromLastMove (move);
175175}
176176
177- void GomokuState::CheckWinFromLastMove (Action last_move) {
177+ absl::optional<std::vector<Grid<Stone>::Coord>>
178+ GomokuState::FindWinLineFromLastMove (Action last_move) const {
179+ using Coord = Grid<Stone>::Coord;
178180
179- const Grid<Stone>::Coord start =
180- board_.Unflatten (last_move);
181+ const Coord start = board_.Unflatten (last_move);
181182 const Stone stone = board_.At (start);
182- const auto * gomoku = static_cast <const GomokuGame*>(game_.get ());
183183
184184 SPIEL_CHECK_NE (stone, Stone::kEmpty );
185185
186186 for (const auto & dir : board_.Directions ()) {
187187 if (!board_.IsCanonical (dir)) continue ;
188188
189- int count = 1 ; // include the starting stone
189+ std::vector<Coord> line;
190+ line.push_back (start);
190191
191- // forward direction
192+ // forward
192193 {
193194 Coord c = start;
194195 while (static_cast <int >(line.size ()) < connect_ &&
195196 board_.Step (c, dir) &&
196197 board_.At (c) == stone) {
197- ++count ;
198+ line. push_back (c) ;
198199 }
199200 }
200201
201- // backward direction
202+ // backward
202203 {
203- auto neg_dir = dir;
204+ Coord neg_dir = dir;
204205 for (int & v : neg_dir) v = -v;
205206
206207 Coord c = start;
207208 while (static_cast <int >(line.size ()) < connect_ &&
208209 board_.Step (c, neg_dir) &&
209210 board_.At (c) == stone) {
210- ++count ;
211+ line. insert (line. begin (), c) ;
211212 }
212213 }
214+
213215 if (static_cast <int >(line.size ()) >= connect_) {
214216 line.resize (connect_); // arbitrary truncation is fine
215217 return line;
@@ -372,6 +374,57 @@ int GomokuGame::NumDistinctActions() const {
372374 return total_size_;
373375}
374376
377+ uint64_t GomokuState::ComputeZobrist (
378+ const Grid<Stone>& grid) const {
379+ const auto * gomoku = static_cast <const GomokuGame*>(game_.get ());
380+ uint64_t h = 0 ;
381+ for (int i = 0 ; i < grid.NumCells (); ++i) {
382+ Stone s = grid.AtIndex (i);
383+ if (s != Stone::kEmpty ) {
384+ h ^= gomoku->ZobristTable ()[i][StoneToInt (s)];
385+ }
386+ }
387+ if (current_player_ == 1 ) {
388+ h ^= gomoku->PlayerToMoveHash ();
389+ }
390+ return h;
391+ }
392+
393+ uint64_t GomokuState::SymmetricHash () const {
394+ uint64_t best = ComputeZobrist (board_);
395+
396+ // --- Basic rotations ---
397+ for (auto [i, j] : board_.GenRotations ()) {
398+ Grid<Stone> rotated = board_;
399+ for (int r = 0 ; r < 3 ; ++r) { // 90, 180, 270
400+ rotated = rotated.ApplyRotation (i, j);
401+ best = std::min (best, ComputeZobrist (rotated));
402+ }
403+ }
404+
405+ // --- Reflections ---
406+ if (symmetry_policy_.allow_reflections ) {
407+ for (int axis = 0 ; axis < dims_; ++axis) {
408+ Grid<Stone> refl = board_.ApplyReflection (axis);
409+ best = std::min (best, ComputeZobrist (refl));
410+
411+ // --- Reflection + rotations ---
412+ if (symmetry_policy_.allow_reflection_rotations ) {
413+ for (auto [i, j] : board_.GenRotations ()) {
414+ Grid<Stone> rotated = refl;
415+ for (int r = 0 ; r < 3 ; ++r) {
416+ rotated = rotated.ApplyRotation (i, j);
417+ best = std::min (best, ComputeZobrist (rotated));
418+ }
419+ }
420+ }
421+ }
422+ }
423+
424+ return best;
425+ }
426+
427+
375428GomokuState::GomokuState (std::shared_ptr<const Game> game,
376429 const std::string& state_str)
377430 : State(game),
@@ -434,5 +487,7 @@ GomokuState::GomokuState(std::shared_ptr<const Game> game,
434487 zobrist_hash_ = ComputeZobrist (board_);
435488}
436489
490+
491+
437492} // namespace gomoku
438493} // namespace open_spiel
0 commit comments