
Commit 00220d0

Author: miskibin

feat: Introduce AI agent framework and enhance board features for ML

- Added a new AI agent interface with `Agent`, `BaseAgent`, and `AgentEngine` classes for building custom AI players.
- Implemented `BoardFeatures` dataclass to extract relevant features from board positions for machine learning applications.
- Enhanced `BaseBoard` with methods for tensor representation and legal move masking to support neural network integration.
- Updated documentation to include a comprehensive guide on writing custom AI agents and utilizing the new features.
- Added tests to ensure the correctness of AI features and agent functionality.
1 parent 49128d4 commit 00220d0

File tree

9 files changed

+1456
-0
lines changed


docs/source/ai.rst

Lines changed: 324 additions & 0 deletions
@@ -0,0 +1,324 @@
Writing Your Own AI
===================

This guide covers py-draughts features designed for AI developers building
custom agents, neural networks, or reinforcement learning systems.

Quick Example
-------------

Here's a minimal neural network agent using PyTorch:

.. code-block:: python

    import torch

    from draughts import Board, Agent

    class NeuralAgent:
        def __init__(self, model):
            self.model = model

        def select_move(self, board: Board):
            # Convert board to tensor (4 channels, 50 squares)
            x = torch.from_numpy(board.to_tensor()).unsqueeze(0)

            # Get policy logits from your network
            with torch.no_grad():
                logits = self.model(x)[0]

            # Mask illegal moves
            mask = board.legal_moves_mask()
            logits[~mask] = float('-inf')

            # Sample or take argmax
            idx = logits.argmax().item()
            return board.index_to_move(idx)

    # Usage
    board = Board()
    agent = NeuralAgent(your_trained_model)
    move = agent.select_move(board)
Agent Interface
---------------

The :class:`~draughts.Agent` protocol defines the minimal interface for AI agents:

.. code-block:: python

    from draughts import Agent, Board, Move

    class MyAgent:  # Implicitly implements the Agent protocol
        def select_move(self, board: Board) -> Move:
            # Your logic here
            return board.legal_moves[0]

    # Type checking confirms protocol compliance
    agent: Agent = MyAgent()

For agents needing configuration, extend :class:`~draughts.BaseAgent`:

.. code-block:: python

    from draughts import BaseAgent, Board, Move

    class ConfigurableAgent(BaseAgent):
        def __init__(self, temperature: float = 1.0):
            super().__init__(name="SoftmaxBot")
            self.temperature = temperature

        def select_move(self, board: Board) -> Move:
            # Use self.temperature for sampling
            ...
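As a sketch of what the elided sampling step could do, here is a plain-Python
softmax-with-temperature helper. It is illustrative only: it assumes you have
already scored each legal move yourself, which py-draughts does not do for you.

```python
import math
import random

def sample_with_temperature(scores: list[float], temperature: float = 1.0) -> int:
    """Pick an index from softmax(scores / temperature).

    Higher temperature flattens the distribution (more exploration);
    lower temperature concentrates it on the best-scoring move.
    """
    scaled = [s / temperature for s in scores]
    peak = max(scaled)  # subtract the max for numerical stability
    weights = [math.exp(s - peak) for s in scaled]
    return random.choices(range(len(weights)), weights=weights, k=1)[0]

random.seed(0)
print(sample_with_temperature([2.0, 0.5, 0.1], temperature=0.5))  # 0 with this seed
```

Inside ``select_move`` you would score each legal move somehow (a hypothetical
step), then return ``board.legal_moves[sample_with_temperature(scores, self.temperature)]``.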
.. autoclass:: draughts.Agent
   :members:

.. autoclass:: draughts.BaseAgent
   :members:

Using Agents with Benchmark
~~~~~~~~~~~~~~~~~~~~~~~~~~~

To use agents with :class:`~draughts.Benchmark`, wrap them as engines:

.. code-block:: python

    import random

    from draughts import AgentEngine, Benchmark, BaseAgent

    class GreedyAgent(BaseAgent):
        def select_move(self, board):
            return max(board.legal_moves, key=lambda m: len(m.captured_list))

    # Method 1: use as_engine() on a BaseAgent subclass
    engine1 = GreedyAgent().as_engine()

    # Method 2: wrap any Agent with AgentEngine
    class RandomAgent:
        def select_move(self, board):
            return random.choice(board.legal_moves)

    engine2 = AgentEngine(RandomAgent(), name="Random")

    # Now benchmark them
    stats = Benchmark(engine1, engine2, games=10).run()

.. autoclass:: draughts.AgentEngine
   :members:

Board Tensor Representation
---------------------------

Use :meth:`~draughts.BaseBoard.to_tensor` to get a neural-network-ready representation:

.. code-block:: python

    from draughts import Board

    board = Board()
    tensor = board.to_tensor()

    print(tensor.shape)  # (4, 50) for 10x10 board

The 4 channels are:

======= ====================================
Channel Description
======= ====================================
0       Own men (1.0 where present)
1       Own kings (1.0 where present)
2       Opponent men (1.0 where present)
3       Opponent kings (1.0 where present)
======= ====================================

By default, "own" is relative to the current turn. Override with ``perspective``:

.. code-block:: python

    from draughts import Color

    # Always from white's perspective (useful for training)
    tensor = board.to_tensor(perspective=Color.WHITE)
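To make the channel layout concrete, here is a self-contained NumPy sketch that
builds the same four one-hot channels from a signed 1-D position array. The
±1/±2 encoding is an assumption for illustration only; py-draughts' internal
values may differ:

```python
import numpy as np

def to_channels(position: np.ndarray) -> np.ndarray:
    """Build (4, N) one-hot channels from a signed 1-D position array.

    Assumed encoding (illustrative): +1 own man, +2 own king,
    -1 opponent man, -2 opponent king, 0 empty square.
    """
    channels = np.zeros((4, position.size), dtype=np.float32)
    channels[0] = position == 1    # own men
    channels[1] = position == 2    # own kings
    channels[2] = position == -1   # opponent men
    channels[3] = position == -2   # opponent kings
    return channels

pos = np.zeros(50, dtype=np.int8)
pos[:5] = -1    # five opponent men
pos[45:] = 1    # five own men
pos[20] = 2     # one own king

t = to_channels(pos)
print(t.shape)        # (4, 50)
print(t.sum(axis=1))  # [5. 1. 5. 0.]
```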
Feature Extraction
------------------

For classical ML or analysis, use :meth:`~draughts.BaseBoard.features`:

.. code-block:: python

    from draughts import Board

    board = Board()
    board.push_uci("31-27")
    board.push_uci("18-22")

    f = board.features()
    print(f.white_men)         # 20
    print(f.black_men)         # 20
    print(f.mobility)          # Number of legal moves
    print(f.material_balance)  # (white_men + 2*kings) - (black_men + 2*kings)
    print(f.phase)             # 'opening', 'midgame', or 'endgame'
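The ``material_balance`` formula above can be checked by hand with a tiny
standalone function (not part of py-draughts):

```python
def material_balance(white_men: int, white_kings: int,
                     black_men: int, black_kings: int) -> int:
    # Kings count double, per the formula in the comment above
    return (white_men + 2 * white_kings) - (black_men + 2 * black_kings)

print(material_balance(20, 0, 20, 0))  # 0 -- the starting position is balanced
print(material_balance(8, 2, 9, 1))    # 1 -- (8 + 4) - (9 + 2)
```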
.. autoclass:: draughts.BoardFeatures
   :members:

Move Indexing for Policy Networks
---------------------------------

Policy networks typically output a fixed-size vector over all possible moves.
py-draughts provides tools to convert between moves and indices:

.. code-block:: python

    board = Board()

    # Get the legal move mask (shape: SQUARES^2 = 2500 for 10x10)
    mask = board.legal_moves_mask()

    # Your network outputs logits of shape (2500,)
    logits = model(board.to_tensor())

    # Mask illegal moves
    logits[~mask] = float('-inf')

    # Convert the winning index back to a move
    best_idx = logits.argmax()
    move = board.index_to_move(best_idx)

    # Or convert a move to an index (for training targets)
    target_idx = board.move_to_index(move)

**Index encoding**: ``from_square * SQUARES_COUNT + to_square``

For a 10x10 board (50 squares), indices range from 0 to 2499.
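The encoding and its inverse are simple enough to write down as standalone
functions (illustrative only; use the board's own ``move_to_index`` and
``index_to_move`` in real code):

```python
SQUARES_COUNT = 50  # playable squares on a 10x10 board

def move_to_index(from_square: int, to_square: int) -> int:
    # The documented encoding: from_square * SQUARES_COUNT + to_square
    return from_square * SQUARES_COUNT + to_square

def index_to_move(index: int) -> tuple[int, int]:
    # Inverse mapping via divmod
    return divmod(index, SQUARES_COUNT)

idx = move_to_index(30, 26)    # 0-based square indices (assumed)
print(idx)                     # 1526
print(index_to_move(idx))      # (30, 26)
print(SQUARES_COUNT ** 2 - 1)  # 2499, the largest possible index
```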
Cheap Position Cloning
----------------------

Tree search and simulation require copying positions. Use :meth:`~draughts.BaseBoard.copy`
for efficient cloning:

.. code-block:: python

    board = Board()

    # Fast copy - only bitboards, no move history
    clone = board.copy()

    # Explore a line
    for move in some_variation:
        clone.push(move)

    # Original unchanged
    assert board.position.tolist() != clone.position.tolist()

The ``copy()`` method is optimized:

- Copies only essential state (bitboards, turn, halfmove clock)
- New board has an empty move stack
- ~10x faster than ``copy.deepcopy``

For full state preservation (including move history), use:

.. code-block:: python

    import copy

    full_clone = copy.deepcopy(board)
MCTS Example
------------

Here's a Monte Carlo Tree Search skeleton:

.. code-block:: python

    import random

    from draughts import Board, BaseAgent, Move

    class MCTSAgent(BaseAgent):
        def __init__(self, simulations: int = 1000):
            super().__init__(name=f"MCTS-{simulations}")
            self.simulations = simulations

        def select_move(self, board: Board) -> Move:
            root = Node(board, None)

            for _ in range(self.simulations):
                node = root
                sim_board = board.copy()  # Cheap copy!

                # Selection: walk to a leaf
                while node.children and not sim_board.game_over:
                    node = node.select_child()
                    sim_board.push(node.move)

                # Expansion
                if not sim_board.game_over and not node.children:
                    for move in sim_board.legal_moves:
                        node.children.append(Node(sim_board, move))

                # Simulation: random playout to the end of the game
                while not sim_board.game_over:
                    sim_board.push(random.choice(sim_board.legal_moves))

                # Backpropagation
                result = sim_board.result
                while node:
                    node.update(result)
                    node = node.parent

            return max(root.children, key=lambda n: n.visits).move

The ``Node`` helper class (holding ``children``, ``visits``, ``parent``, a
``select_child`` policy, and an ``update`` method) is left for you to implement.
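One possible shape for that ``Node`` class, using the standard UCB1 selection
rule, is sketched below. It is self-contained and hypothetical; only the
attribute names are chosen to match the skeleton above:

```python
import math

class Node:
    """Minimal MCTS node (an illustrative sketch, not part of py-draughts)."""

    def __init__(self, board, move, parent=None):
        self.board = board
        self.move = move
        self.parent = parent
        self.children = []
        self.visits = 0
        self.wins = 0.0

    def select_child(self, c: float = 1.4):
        # UCB1: trade off exploitation (win rate) against exploration.
        def ucb1(child):
            if child.visits == 0:
                return float('inf')  # always try unvisited children first
            exploit = child.wins / child.visits
            explore = c * math.sqrt(math.log(self.visits) / child.visits)
            return exploit + explore
        return max(self.children, key=ucb1)

    def update(self, result: float):
        # result is assumed to be a score in [0, 1] from this node's perspective
        self.visits += 1
        self.wins += result

# Tiny demo without a real board
root = Node(board=None, move=None)
root.children = [Node(None, "a", parent=root), Node(None, "b", parent=root)]
root.children[0].update(0.0)
root.children[1].update(1.0)
root.update(1.0)
root.update(0.0)
print(root.select_child().move)  # "b" -- higher win rate at equal visit counts
```

When expanding, pass the current node as ``parent`` (e.g.
``Node(sim_board, move, parent=node)``) so backpropagation can walk up the tree.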
Training Tips
-------------

**State representation**:

.. code-block:: python

    # For a CNN: channels-first (4, 50); only the 50 playable squares exist
    # on a 10x10 board, so reshape to a grid if your net needs 2D input
    tensor = board.to_tensor()  # (4, 50)

    # For an MLP: flatten
    flat = tensor.flatten()  # (200,)

**Data augmentation**: Draughts boards have rotational symmetry. A position
and its 180° rotation are strategically equivalent (with colors swapped):

.. code-block:: python

    # The position array is already 1D over the playable squares.
    # Reverse it and negate to get the symmetric position.
    symmetric_pos = -board.position[::-1]
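On a toy array the transform looks like this (the ±1/±2 values are an assumed
encoding, positive for white and negative for black):

```python
import numpy as np

# Toy signed position: positive = white, negative = black (assumed values)
position = np.array([1, 0, 2, 0, -1, -2], dtype=np.int8)

# 180-degree rotation: reverse the square order and negate to swap colors
symmetric = -position[::-1]
print(symmetric.tolist())  # [2, 1, 0, -2, 0, -1]
```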
**Reward shaping**: Use ``features()`` for intermediate rewards:

.. code-block:: python

    f = board.features()
    reward = f.material_balance * 0.01  # Small material reward

**Board variants**: All methods work on any board variant:

.. code-block:: python

    from draughts import AmericanBoard, FrisianBoard

    board = AmericanBoard()  # 8x8, 32 squares
    tensor = board.to_tensor()  # (4, 32)
    mask = board.legal_moves_mask()  # (1024,)

API Reference
-------------

.. automethod:: draughts.BaseBoard.copy
.. automethod:: draughts.BaseBoard.to_tensor
.. automethod:: draughts.BaseBoard.features
.. automethod:: draughts.BaseBoard.legal_moves_mask
.. automethod:: draughts.BaseBoard.move_to_index
.. automethod:: draughts.BaseBoard.index_to_move

docs/source/index.rst

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ API Reference

    core
    engine
+   ai
    svg
    server

draughts/__init__.py

Lines changed: 10 additions & 0 deletions
@@ -52,10 +52,14 @@
 from draughts.engines.hub import HubEngine
 from draughts.engines.alpha_beta import AlphaBetaEngine
 from draughts.engines.engine import Engine
+from draughts.engines.agent import Agent, AgentEngine, BaseAgent

 # Benchmarking
 from draughts.benchmark import Benchmark, BenchmarkStats, GameResult, STANDARD_OPENINGS

+# AI/ML Support
+from draughts.boards.base import BoardFeatures
+
 __all__ = [
     # Boards
     'BaseBoard',
@@ -68,6 +72,12 @@
     'Engine',
     'AlphaBetaEngine',
     'HubEngine',
+    # Agents (AI interface)
+    'Agent',
+    'AgentEngine',
+    'BaseAgent',
+    # AI/ML Support
+    'BoardFeatures',
     # Benchmarking
     'Benchmark',
     'BenchmarkStats',
