Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 10 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,19 @@ Stream

https://www.twitch.tv/tomcr00s3

Usage
Setup
-----

```
pip3 install python-chess torch torchvision numpy flask
# then...
python3 -m venv venv
source venv/bin/activate
pip install pip --upgrade
pip install -r requirements.txt
```

Usage
-----
```
./play.py # runs webserver on localhost:5000
```

Expand Down
70 changes: 53 additions & 17 deletions generate_training_set.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,63 @@
#!/usr/bin/env python3
import os
import chess.pgn
import numpy as np
import pathlib
from state import State

def get_dataset(num_samples=None):
def extract_boards(instance_game):
"""
Extracts serialized board from game
"""
board = instance_game.board()
list_out = []
for move in instance_game.mainline_moves():
board.push(move)
board_serialized = State(board).serialize()
list_out.append(board_serialized)

return list_out

def get_dataset_via_pgn(path_dir_pgn, num_samples=None):
"""
Reads and iterates over pgn files
chess.pgn.read_game works as a generator
no way to tell how large the file is aside from counting Results in file.
For 1M games:
1:47:27
NOTE: A lot of games are ended by giving up

NOTE: Hard-coded labels
1/2-1/2 : tie
0-1 : R player wins
1-0 : L player wins

We read the games and process one at a time

TODO: Periodically writes to out file to avoid out-of-memory (OOM)
"""
X,Y = [], []
gn = 0
values = {'1/2-1/2':0, '0-1':-1, '1-0':1}
# pgn files in the data folder
for fn in os.listdir("data"):
pgn = open(os.path.join("data", fn))
dict_map_result = {'1/2-1/2':0, '0-1':-1, '1-0':1}

for fname in path_dir_pgn.glob("*.pgn"):
pgn = open(fname)
while 1:
# read & parse
game = chess.pgn.read_game(pgn)
if game is None:
break
res = game.headers['Result']
if res not in values:
result = game.headers['Result']
if result not in dict_map_result:
continue
value = values[res]
board = game.board()
for i, move in enumerate(game.mainline_moves()):
board.push(move)
ser = State(board).serialize()
X.append(ser)
Y.append(value)
value = dict_map_result[result]

# extract
list_boards_serialized = extract_boards(game)
X.extend(list_boards_serialized)
Y.extend([value] * len(list_data))
print("parsing game %d, got %d examples" % (gn, len(X)))

# break early
if num_samples is not None and len(X) > num_samples:
return X,Y
gn += 1
Expand All @@ -34,6 +66,10 @@ def get_dataset(num_samples=None):
return X,Y

if __name__ == "__main__":
X,Y = get_dataset(25000000)
np.savez("processed/dataset_25M.npz", X, Y)
DIR_DATA = pathtlib.Path(__file__).parent / "data"
DIR_OUT = pathlib.Path(__file__).parent / "processed"
DIR_OUT.mkdir(exist_ok=True, parents=True)

X,Y = get_dataset_via_pgn(DIR_DATA, 25e6)
np.savez(DIR_OUT / "dataset_25M.npz", X, Y)

10 changes: 8 additions & 2 deletions play.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,22 @@
import chess.svg
import traceback
import base64
import torch

from state import State
from train import Net

class Valuator(object):
"""
Call the model for inference
"""
def __init__(self):
import torch
from train import Net
vals = torch.load("nets/value.pth", map_location=lambda storage, loc: storage)
self.model = Net()
self.model.load_state_dict(vals)
self.model.eval() # turn on evaluation mode

@torch.no_grad()
def __call__(self, s):
brd = s.serialize()[None]
output = self.model(torch.tensor(brd).float())
Expand Down
5 changes: 5 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
flask
numpy
python-chess
torch
torchvision
73 changes: 55 additions & 18 deletions state.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,32 @@
#!/usr/bin/env python3
import chess
import numpy as np

def construct_dict_map_pieces(value_shift = 8):
"""
https://python-chess.readthedocs.io/en/latest/core.html#chess.Piece
gives the default values
Gets the symbol P, N, B, R, Q or K for white pieces or the lower-case variants for the black pieces.
We'll assign values
f(piece_white) = 1 through N for the white pieces, and
f(piece_black) = f(piece_white) + 8 for the black pieces

Returns
P = 1, p = 1+8
N = 2, n = 2+8
...
K = 6, k = 6+8

which leaves (7,13), (15, infty) available
"""
list_pieces = ['P', 'N', 'B', 'R', 'Q', 'K']
dict_out = {}
for i, piece in enumerate(list_pieces):
value = i+1
dict_out[piece] = value
dict_out[piece.lower()] = value+value_shift

return dict_out

class State(object):
def __init__(self, board=None):
Expand All @@ -8,49 +35,59 @@ def __init__(self, board=None):
else:
self.board = board

self.dict_map_pieces = construct_dict_map_pieces()

def key(self):
return (self.board.board_fen(), self.board.turn, self.board.castling_rights, self.board.ep_square)

def serialize(self):
import numpy as np
"""
Store the board piece value in an 8x8 matrix with hard-coded values.
Convert the board to a bit matrix s.t.,
state[i,j] are the extracted bits containing four bits per integer
then keep the rightmost bit

State is 256 bits according to readme
"""
assert self.board.is_valid()

bstate = np.zeros(64, np.uint8)
bstate = np.zeros(64, np.uint8) # board state
for i in range(64):
pp = self.board.piece_at(i)
if pp is not None:
#print(i, pp.symbol())
bstate[i] = {"P": 1, "N": 2, "B": 3, "R": 4, "Q": 5, "K": 6, \
"p": 9, "n":10, "b":11, "r":12, "q":13, "k": 14}[pp.symbol()]
bstate[i] = self.dict_map_pieces[pp.symbol()]

piece_rook = "R"
value_rook_castled = 7 # hard-coded
if self.board.has_queenside_castling_rights(chess.WHITE):
assert bstate[0] == 4
bstate[0] = 7
assert bstate[0] == self.dict_map_pieces[piece_rook]
bstate[0] = value_rook_castled
if self.board.has_kingside_castling_rights(chess.WHITE):
assert bstate[7] == 4
bstate[7] = 7
assert bstate[7] == self.dict_map_pieces[piece_rook]
bstate[7] = value_rook_castled
if self.board.has_queenside_castling_rights(chess.BLACK):
assert bstate[56] == 8+4
bstate[56] = 8+7
assert bstate[56] == self.dict_map_pieces[piece_rook.lower()]
bstate[56] = value_rook_castled+8
if self.board.has_kingside_castling_rights(chess.BLACK):
assert bstate[63] == 8+4
bstate[63] = 8+7
assert bstate[63] == self.dict_map_pieces[piece_rook.lower()]
bstate[63] = value_rook_castled+8

value_ep_square = 8 # hard-coded
if self.board.ep_square is not None:
assert bstate[self.board.ep_square] == 0
bstate[self.board.ep_square] = 8
bstate[self.board.ep_square] = value_ep_square
bstate = bstate.reshape(8,8)

# binary state
state = np.zeros((5,8,8), np.uint8)

# 0-3 columns to binary
state[0] = (bstate>>3)&1
state[1] = (bstate>>2)&1
state[2] = (bstate>>1)&1
state[3] = (bstate>>0)&1
for i in range(4):
state[i] = (bstate>>4-1-i)&1 # convert to binary, then keep rightmost bit

# 4th column is who's turn it is
state[4] = (self.board.turn*1.0)
state[4] = self.board.turn*1.0

# 257 bits according to readme
return state
Expand Down
Loading