Skip to content

Commit cee7543

Browse files
author
miskibin
committed
feat: Refactor engine interface to use BaseBoard and enhance legal moves retrieval in server
1 parent 7c140cf commit cee7543

File tree

2 files changed

+13
-11
lines changed

2 files changed

+13
-11
lines changed

draughts/engine.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from loguru import logger
66
import numpy as np
77

8+
from draughts.boards.base import BaseBoard
89
from draughts.boards.standard import Board, Move, Figure
910
from draughts.models import Color
1011

@@ -63,7 +64,7 @@ class Engine(ABC):
6364

6465
@abstractmethod
6566
def get_best_move(
66-
self, board: Board, with_evaluation: bool
67+
self, board: BaseBoard, with_evaluation: bool
6768
) -> Move | tuple[Move, float]:
6869
"""
6970
Returns best move for given board.
@@ -138,7 +139,7 @@ def _get_piece_index(self, piece):
138139
# Map piece value to 0-4 index
139140
return piece + 2
140141

141-
def compute_hash(self, board: Board) -> int:
142+
def compute_hash(self, board: BaseBoard) -> int:
142143
h = 0
143144
for i, piece in enumerate(board._pos):
144145
if piece != 0:
@@ -147,7 +148,7 @@ def compute_hash(self, board: Board) -> int:
147148
h ^= self.zobrist_turn
148149
return h
149150

150-
def evaluate(self, board: Board) -> float:
151+
def evaluate(self, board: BaseBoard) -> float:
151152
"""
152153
Evaluation function with material and PST.
153154
Returns score from the perspective of the side to move.
@@ -181,7 +182,7 @@ def evaluate(self, board: Board) -> float:
181182
return -score
182183
return score
183184

184-
def get_best_move(self, board: Board, with_evaluation: bool = False) -> Move | tuple[Move, float]:
185+
def get_best_move(self, board: BaseBoard, with_evaluation: bool = False) -> Move | tuple[Move, float]:
185186
self.start_time = time.time()
186187
self.nodes = 0
187188
self.stop_search = False
@@ -240,7 +241,7 @@ def get_best_move(self, board: Board, with_evaluation: bool = False) -> Move | t
240241
return best_move, float(best_score)
241242
return best_move
242243

243-
def negamax(self, board: Board, depth: int, alpha: float, beta: float, h: int) -> float:
244+
def negamax(self, board: BaseBoard, depth: int, alpha: float, beta: float, h: int) -> float:
244245
self.nodes += 1
245246

246247
# Check time
@@ -337,7 +338,7 @@ def negamax(self, board: Board, depth: int, alpha: float, beta: float, h: int) -
337338

338339
return best_value
339340

340-
def quiescence_search(self, board: Board, alpha: float, beta: float, h: int, qs_depth: int = 0) -> float:
341+
def quiescence_search(self, board: BaseBoard, alpha: float, beta: float, h: int, qs_depth: int = 0) -> float:
341342
"""Search captures until position is quiet."""
342343
self.nodes += 1
343344

@@ -377,7 +378,7 @@ def quiescence_search(self, board: Board, alpha: float, beta: float, h: int, qs_
377378

378379
return alpha
379380

380-
def _update_hash(self, current_hash: int, board: Board, move: Move) -> int:
381+
def _update_hash(self, current_hash: int, board: BaseBoard, move: Move) -> int:
381382
# XOR out source
382383
start_sq = move.square_list[0]
383384
piece = board._pos[start_sq]
@@ -405,7 +406,7 @@ def _update_hash(self, current_hash: int, board: Board, move: Move) -> int:
405406

406407
return current_hash
407408

408-
def _order_moves(self, moves: List[Move], board: Board | None = None, h: int = 0, depth: int = 0) -> List[Move]:
409+
def _order_moves(self, moves: List[Move], board: BaseBoard | None = None, h: int = 0, depth: int = 0) -> List[Move]:
409410
# 1. PV Move from TT
410411
tt_entry = self.tt.get(h)
411412
pv_move = tt_entry[3] if tt_entry else None
@@ -438,7 +439,7 @@ def score_move(move):
438439
moves.sort(key=score_move, reverse=True)
439440
return moves
440441

441-
def _order_captures(self, moves: List[Move], board: Board) -> List[Move]:
442+
def _order_captures(self, moves: List[Move], board: BaseBoard) -> List[Move]:
442443
# Sort by number of captures
443444
moves.sort(key=lambda m: len(m.captured_list), reverse=True)
444445
return moves

draughts/server/server.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ def get_position(self, request: Request) -> PositionResponse:
187187
def get_legal_moves(self) -> dict:
188188
"""Get all legal moves for the current position."""
189189
with self._lock:
190-
moves_dict = defaultdict(list)
190+
moves_dict: dict[int, list[int]] = defaultdict(list)
191191
for move in list(self.board.legal_moves):
192192
moves_dict[int(move.square_list[0])].extend(
193193
map(int, move.square_list[1:])
@@ -235,7 +235,8 @@ def get_best_move(self, request: Request) -> PositionResponse:
235235
if not legal_moves:
236236
return self.position_json
237237

238-
move = engine.get_best_move(self.board)
238+
result = engine.get_best_move(self.board, with_evaluation=False)
239+
move = result if not isinstance(result, tuple) else result[0]
239240

240241
# Validate move is legal (handles stale TT or overlapping requests)
241242
if move not in legal_moves:

0 commit comments

Comments
 (0)