Skip to content

Commit e9cfec2

Browse files
committed
finish fixing merge
1 parent 702b150 commit e9cfec2

File tree

6 files changed

+91
-29
lines changed

6 files changed

+91
-29
lines changed

open_spiel/games/gomoku/gomoku.cc

Lines changed: 65 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -174,42 +174,44 @@ void GomokuState::DoApplyAction(Action move) {
174174
CheckWinFromLastMove(move);
175175
}
176176

177-
void GomokuState::CheckWinFromLastMove(Action last_move) {
177+
absl::optional<std::vector<Grid<Stone>::Coord>>
178+
GomokuState::FindWinLineFromLastMove(Action last_move) const {
179+
using Coord = Grid<Stone>::Coord;
178180

179-
const Grid<Stone>::Coord start =
180-
board_.Unflatten(last_move);
181+
const Coord start = board_.Unflatten(last_move);
181182
const Stone stone = board_.At(start);
182-
const auto* gomoku = static_cast<const GomokuGame*>(game_.get());
183183

184184
SPIEL_CHECK_NE(stone, Stone::kEmpty);
185185

186186
for (const auto& dir : board_.Directions()) {
187187
if (!board_.IsCanonical(dir)) continue;
188188

189-
int count = 1; // include the starting stone
189+
std::vector<Coord> line;
190+
line.push_back(start);
190191

191-
// forward direction
192+
// forward
192193
{
193194
Coord c = start;
194195
while (static_cast<int>(line.size()) < connect_ &&
195196
board_.Step(c, dir) &&
196197
board_.At(c) == stone) {
197-
++count;
198+
line.push_back(c);
198199
}
199200
}
200201

201-
// backward direction
202+
// backward
202203
{
203-
auto neg_dir = dir;
204+
Coord neg_dir = dir;
204205
for (int& v : neg_dir) v = -v;
205206

206207
Coord c = start;
207208
while (static_cast<int>(line.size()) < connect_ &&
208209
board_.Step(c, neg_dir) &&
209210
board_.At(c) == stone) {
210-
++count;
211+
line.insert(line.begin(), c);
211212
}
212213
}
214+
213215
if (static_cast<int>(line.size()) >= connect_) {
214216
line.resize(connect_); // arbitrary truncation is fine
215217
return line;
@@ -372,6 +374,57 @@ int GomokuGame::NumDistinctActions() const {
372374
return total_size_;
373375
}
374376

377+
uint64_t GomokuState::ComputeZobrist(
378+
const Grid<Stone>& grid) const {
379+
const auto* gomoku = static_cast<const GomokuGame*>(game_.get());
380+
uint64_t h = 0;
381+
for (int i = 0; i < grid.NumCells(); ++i) {
382+
Stone s = grid.AtIndex(i);
383+
if (s != Stone::kEmpty) {
384+
h ^= gomoku->ZobristTable()[i][StoneToInt(s)];
385+
}
386+
}
387+
if (current_player_ == 1) {
388+
h ^= gomoku->PlayerToMoveHash();
389+
}
390+
return h;
391+
}
392+
393+
uint64_t GomokuState::SymmetricHash() const {
394+
uint64_t best = ComputeZobrist(board_);
395+
396+
// --- Basic rotations ---
397+
for (auto [i, j] : board_.GenRotations()) {
398+
Grid<Stone> rotated = board_;
399+
for (int r = 0; r < 3; ++r) { // 90, 180, 270
400+
rotated = rotated.ApplyRotation(i, j);
401+
best = std::min(best, ComputeZobrist(rotated));
402+
}
403+
}
404+
405+
// --- Reflections ---
406+
if (symmetry_policy_.allow_reflections) {
407+
for (int axis = 0; axis < dims_; ++axis) {
408+
Grid<Stone> refl = board_.ApplyReflection(axis);
409+
best = std::min(best, ComputeZobrist(refl));
410+
411+
// --- Reflection + rotations ---
412+
if (symmetry_policy_.allow_reflection_rotations) {
413+
for (auto [i, j] : board_.GenRotations()) {
414+
Grid<Stone> rotated = refl;
415+
for (int r = 0; r < 3; ++r) {
416+
rotated = rotated.ApplyRotation(i, j);
417+
best = std::min(best, ComputeZobrist(rotated));
418+
}
419+
}
420+
}
421+
}
422+
}
423+
424+
return best;
425+
}
426+
427+
375428
GomokuState::GomokuState(std::shared_ptr<const Game> game,
376429
const std::string& state_str)
377430
: State(game),
@@ -434,5 +487,7 @@ GomokuState::GomokuState(std::shared_ptr<const Game> game,
434487
zobrist_hash_ = ComputeZobrist(board_);
435488
}
436489

490+
491+
437492
} // namespace gomoku
438493
} // namespace open_spiel

open_spiel/games/gomoku/gomoku.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,8 +49,6 @@ inline constexpr int kNumPlayers = 2;
4949
inline constexpr int kBlackPlayer = 0;
5050
inline constexpr int kWhitePlayer = 1;
5151

52-
53-
//
5452
enum class Stone {
5553
kEmpty,
5654
kBlack,

open_spiel/games/gomoku/gomoku_grid.h

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@ class Grid {
3434
wrap_(wrap),
3535
strides_(dims),
3636
data_(ComputeTotalSize(size, dims)) {
37-
if (size_ == 0 || dims_ == 0) {
38-
throw std::invalid_argument("Grid size and dims must be >= 1");
39-
}
40-
37+
SPIEL_CHECK_GE(size, 1);
38+
SPIEL_CHECK_GE(dims, 1);
4139
// stride[d] = size^(dims - d - 1)
4240
strides_[dims_ - 1] = 1;
4341
for (std::size_t i = dims_ - 1; i > 0; --i) {

open_spiel/games/gomoku/gomoku_test.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ void TestObservationTensor() {
7070
// Make two moves: Black 0, White 1
7171
state->ApplyAction(0);
7272
// white on move
73-
state->ObservationTensor(/*player=*/0, absl::MakeSpan(obs));
73+
state->ObservationTensor(/*player=*/0, absl::MakeSpan(obs));
7474
SPIEL_CHECK_EQ(obs[2 * num_cells], 0.0f);
7575
state->ApplyAction(1);
7676

open_spiel/python/pybind11/games_gomoku.cc

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ namespace py = ::pybind11;
2727
using open_spiel::Game;
2828
using open_spiel::Action;
2929
using open_spiel::State;
30+
using open_spiel::gomoku::SymmetryPolicy;
3031
using open_spiel::gomoku::GomokuGame;
3132
using open_spiel::gomoku::GomokuState;
3233

@@ -55,10 +56,14 @@ void open_spiel::init_pyspiel_games_gomoku(::pybind11::module &m) {
5556
[](const std::string& data) {
5657
return std::dynamic_pointer_cast<GomokuGame>(
5758
std::const_pointer_cast<Game>(LoadGame(data)));
58-
}));;
59+
}));
5960

6061
py::classh<GomokuState, State>(m, "GomokuState")
6162
.def("hash_value", &GomokuState::HashValue)
63+
.def("symmetric_hash", &GomokuState::SymmetricHash)
64+
.def("winning_line", &GomokuState::WinningLine,
65+
py::return_value_policy::reference_internal)
66+
6267
.def(py::pickle(
6368
[](const GomokuState& state) {
6469
return SerializeGameAndState(*state.GetGame(), state);
@@ -67,5 +72,18 @@ void open_spiel::init_pyspiel_games_gomoku(::pybind11::module &m) {
6772
auto game_and_state = DeserializeGameAndState(data);
6873
return dynamic_cast<GomokuState*>(
6974
game_and_state.second.release());
70-
}));
75+
}))
76+
.def("set_symmetry_policy",
77+
&GomokuState::SetSymmetryPolicy)
78+
.def("get_symmetry_policy",
79+
&GomokuState::GetSymmetryPolicy,
80+
py::return_value_policy::reference_internal);
81+
82+
py::class_<SymmetryPolicy>(m, "SymmetryPolicy")
83+
.def(py::init<>())
84+
.def_readwrite("allow_reflections",
85+
&SymmetryPolicy::allow_reflections)
86+
.def_readwrite("allow_reflection_rotations",
87+
&SymmetryPolicy::allow_reflection_rotations);
88+
7189
}

open_spiel/python/tests/games_gomoku_test.py

Lines changed: 3 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -54,14 +54,11 @@ def test_gomoku_game_funs(self):
5454
self.assertEqual(coord, move, f"Coord {coord} move {move}")
5555

5656
def test_gomoku_hash(self):
57-
game = pyspiel.load_game("gomoku")
57+
game = pyspiel.load_game("gomoku(size=3,connect=3)")
5858
state = game.new_initial_state()
5959
hash0 = state.hash_value()
6060
self.assertEqual(hash0, 0, f"Initial board hash {hash0}")
6161
state.apply_action(1)
62-
<<<<<<< HEAD
63-
print("hash", state.hash_value())
64-
=======
6562
hash1 = state.hash_value()
6663
sym1 = state.symmetric_hash()
6764
# States related by symmetry shoud have different hashes
@@ -78,15 +75,15 @@ def test_gomoku_hash(self):
7875
state = game.new_initial_state()
7976
policy = state.get_symmetry_policy()
8077
print("policy", policy)
81-
self.assertEqual(policy.allow_reflections, False, f"Wrong symmetry ploicy")
78+
self.assertEqual(policy.allow_reflections, False, f"Wrong symmetry policy")
8279
# verify policy is correct here
8380
state.apply_action(0)
8481
state.apply_action(1)
8582
sym11 = state.symmetric_hash()
8683
# set symmetry policy here
8784
policy.allow_reflections = True
8885
policy = state.get_symmetry_policy()
89-
self.assertEqual(policy.allow_reflections, True, f"Wrong symmetry ploicy")
86+
self.assertEqual(policy.allow_reflections, True, f"Wrong symmetry policy")
9087
sym12 = state.symmetric_hash()
9188

9289
state = game.new_initial_state()
@@ -102,7 +99,6 @@ def test_gomoku_hash(self):
10299
self.assertNotEqual(sym11, sym21, f"Hash1 {sym11} Hash2 {sym21}")
103100
self.assertEqual(sym12, sym22, f"Hash1 {sym12} Hash2 {sym22}")
104101

105-
>>>>>>> 09f40196 (Cleaned up lint, added tests)
106102

107103
def test_gommoku_game_sim(self):
108104
game = pyspiel.load_game("gomoku")
@@ -115,8 +111,6 @@ def test_gommoku_game_sim(self):
115111
state.apply_action(action)
116112
mc += 1
117113

118-
<<<<<<< HEAD
119-
=======
120114
def test_winning_line(self):
121115
game = pyspiel.load_game("gomoku(size=3,connect=3)")
122116
state = game.new_initial_state()
@@ -167,6 +161,5 @@ def test_consistent_hash(self):
167161
hash2 = state.hash_value()
168162
self.assertEqual(hash1, hash2, f"Hash1 {hash1} Hash2 {hash2}")
169163

170-
>>>>>>> 09f40196 (Cleaned up lint, added tests)
171164
if __name__ == "__main__":
172165
absltest.main()

0 commit comments

Comments
 (0)