Skip to content

Commit 9f11522

Browse files
committed
Add initial code
1 parent 3a1e7e6 commit 9f11522

31 files changed

+4711
-0
lines changed

NOTICE

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
LiveChess2FEN
2+
Copyright (c) 2020 David Mallasén Quintana
3+
4+
See also the LICENSE file.
5+
6+
The purpose of this NOTICE file is to contain notices that are required by
7+
the copyright owner of software included in this project and their license.
8+
Some of the accompanying products have an attribution requirement, so see below.
9+
10+
LiveChess2FEN includes software from the following projects:
11+
12+
---
13+
neural-chessboard
14+
15+
HOMEPAGE: https://github.com/maciejczyzewski/neural-chessboard
16+
LICENSE (full text below): MIT License
17+
COPYRIGHT: Copyright (c) 2017-present Maciej A. Czyzewski and other contributors
18+
---
19+
keras-squeezenet
20+
21+
HOMEPAGE: https://github.com/rcmalli/keras-squeezenet
22+
LICENSE (full text below): MIT License
23+
COPYRIGHT: Copyright (c) 2016 Refikcanmalli
24+
---
25+
bentley_ottmann
26+
27+
HOMEPAGE: https://github.com/lycantropos/bentley_ottmann
28+
LICENSE (full text below): MIT License
29+
COPYRIGHT: Copyright (c) 2020 Azat Ibrakov
30+
---
31+
bintrees 2.0.2
32+
33+
HOMEPAGE: https://github.com/mozman/bintrees
34+
LICENSE (full text below): MIT License
35+
COPYRIGHT: Copyright (c) 2012, Manfred Moitzi
36+
---
37+
38+
==========
39+
40+
MIT License
41+
42+
Permission is hereby granted, free of charge, to any person obtaining a copy
43+
of this software and associated documentation files (the "Software"), to deal
44+
in the Software without restriction, including without limitation the rights
45+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
46+
copies of the Software, and to permit persons to whom the Software is
47+
furnished to do so, subject to the following conditions:
48+
49+
The above copyright notice and this permission notice shall be included in all
50+
copies or substantial portions of the Software.
51+
52+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
53+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
54+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
55+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
56+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
57+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
58+
SOFTWARE.
59+
60+
==========

board_detection.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
"""
2+
Executes the detection of a chessboard.
3+
"""
4+
from lc2fen.board2data import regenerate_data_state, process_input_boards
5+
6+
7+
def main():
8+
regenerate_data_state("data")
9+
process_input_boards("data")
10+
11+
12+
if __name__ == "__main__":
13+
main()

cpmodels/__init__.py

Whitespace-only changes.
Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
"""
2+
Common functions to train the chess piece models.
3+
"""
4+
import matplotlib
5+
6+
from keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau
7+
from keras.layers import Dense, GlobalAveragePooling2D
8+
from keras.models import Model
9+
from keras.preprocessing.image import ImageDataGenerator
10+
11+
matplotlib.use('agg')
12+
import matplotlib.pyplot as plt
13+
14+
15+
def build_model(base_model):
    """
    Builds the model from a pretrained base model.

    A global-average-pooling layer, a 1024-unit relu layer and a
    13-class softmax head are stacked on top of the base model, and the
    result is compiled with Adam and categorical cross-entropy.

    :param base_model: Base model from keras applications.
        Example: MobileNetV2(input_shape=(224, 224, 3),
                             include_top=False,
                             weights='imagenet')
    :return: The compiled model to train.
    """
    x = base_model.output
    x = GlobalAveragePooling2D()(x)
    x = Dense(1024, activation='relu')(x)
    predictions = Dense(13, activation='softmax')(x)

    model = Model(inputs=base_model.input, outputs=predictions)
    model.compile(optimizer='Adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    return model
def data_generators(preprocessing_function, target_size, batch_size,
                    train_path='../data/train/',
                    validation_path='../data/validation/'):
    """
    Returns the train and validation generators.

    :param preprocessing_function: Corresponding preprocessing function
        for the pretrained base model.
        Example: from keras.applications.mobilenet_v2
                 import preprocess_input
    :param target_size: The dimensions to which all images found will be
        resized. Example: (224, 224)
    :param batch_size: Size of the batches of data.
    :param train_path: Path to the train folder.
    :param validation_path: Path to the validation folder.
    :return: The train and validation generators.
    """
    datagen = ImageDataGenerator(
        preprocessing_function=preprocessing_function, dtype='float16')

    def _directory_flow(path):
        # Both generators share every setting apart from the folder.
        return datagen.flow_from_directory(path,
                                           target_size=target_size,
                                           color_mode='rgb',
                                           batch_size=batch_size,
                                           class_mode='categorical',
                                           shuffle=True)

    return _directory_flow(train_path), _directory_flow(validation_path)
def train_model(model, epochs, train_generator, val_generator, callbacks,
                use_weights, workers):
    """Trains the input model and returns its training history.

    :param model: Compiled model to train.
    :param epochs: Number of epochs to train for.
    :param train_generator: Generator yielding the training batches.
    :param val_generator: Generator yielding the validation batches.
    :param callbacks: List of keras callbacks to apply.
    :param use_weights: Whether to apply the fixed class weights below.
    :param workers: Number of workers for multiprocessing data loading.
    """
    if use_weights:
        # Classes 3, 10 (knights) and 6 (rooks) are down-weighted --
        # presumably because they dominate the dataset; verify against
        # the class indices in PIECES_TO_CLASSNUM.
        class_weights = dict.fromkeys(range(13), 1.)
        class_weights.update({3: 0.125, 6: 0.05, 10: 0.125})
    else:
        class_weights = None

    # NOTE(review): fit_generator is a legacy keras API (replaced by
    # model.fit in tf.keras 2.x) -- kept for compatibility with the
    # keras version this project pins.
    return model.fit_generator(
        generator=train_generator,
        steps_per_epoch=train_generator.n // train_generator.batch_size,
        epochs=epochs,
        validation_data=val_generator,
        validation_steps=val_generator.n // val_generator.batch_size,
        callbacks=callbacks,
        verbose=2,
        class_weight=class_weights,
        use_multiprocessing=True,
        workers=workers)
def model_callbacks(early_stopping_patience, model_checkpoint_dir,
                    reducelr_factor, reducelr_patience):
    """
    Initializes the model callbacks.

    All three callbacks track the maximum of the validation accuracy.

    :param early_stopping_patience: Epochs without improvement before
        training is stopped.
    :param model_checkpoint_dir: Filepath where the best model is saved.
    :param reducelr_factor: Factor by which the learning rate is reduced.
    :param reducelr_patience: Epochs without improvement before the
        learning rate is reduced.
    :return: List with the configured callbacks.
    """
    # Settings shared by all three callbacks.
    monitor_args = {'monitor': 'val_accuracy', 'mode': 'max', 'verbose': 1}

    return [
        EarlyStopping(patience=early_stopping_patience,
                      restore_best_weights=True,
                      min_delta=0.002,
                      **monitor_args),
        ModelCheckpoint(filepath=model_checkpoint_dir,
                        save_best_only=True,
                        **monitor_args),
        ReduceLROnPlateau(factor=reducelr_factor,
                          patience=reducelr_patience,
                          **monitor_args),
    ]
def plot_model_history(history, accuracy_savedir, loss_savedir):
    """Plots the model history (accuracy and loss).

    :param history: History object returned by model training; its
        ``history`` dict must contain 'accuracy', 'val_accuracy',
        'loss' and 'val_loss'.
    :param accuracy_savedir: File path for the accuracy plot.
    :param loss_savedir: File path for the loss plot.
    """
    _plot_metric(history, 'accuracy', 'Model accuracy', accuracy_savedir)
    _plot_metric(history, 'loss', 'Model loss', loss_savedir)


def _plot_metric(history, metric, title, savepath):
    """Plots the train/validation curves of one metric and saves the figure.

    :param history: History object with `metric` and 'val_' + `metric`.
    :param metric: Metric name; also used as the y-axis label.
    :param title: Plot title.
    :param savepath: File path where the figure is saved.
    """
    plt.plot(history.history[metric])
    plt.plot(history.history['val_' + metric])
    plt.title(title)
    plt.ylabel(metric)
    plt.xlabel('epoch')
    plt.legend(['train', 'test'])
    plt.savefig(savepath)
    # Close so consecutive calls do not draw onto the same figure.
    plt.close()
def evaluate_model(model, test_generator):
    """
    Prints the test loss and accuracy of the model.

    :param model: Model to evaluate.
    :param test_generator: Generator with which to test the model.
    """
    results = model.evaluate_generator(test_generator, verbose=1)
    # results[0] is the loss; results[1] the accuracy metric.
    for label, value in zip(('Test loss:', 'Test accuracy:'), results):
        print(label, value)

cpmodels/dataset.py

Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
"""
2+
Works with the chess pieces dataset.
3+
"""
4+
import csv
5+
import functools
6+
import os
7+
import shutil
8+
from random import shuffle
9+
10+
import pandas as pd
11+
12+
from lc2fen.fen import PIECE_TYPES
13+
14+
# Maps each piece symbol (lowercase = black, uppercase = white,
# '_' = empty square) to its classifier output index, 0-12.
PIECES_TO_CLASSNUM = {
    piece: classnum for classnum, piece in enumerate('_bknpqrBKNPQR')
}
def create_dataset_csv(dataset_dir, csv_name, frac=1, validate=0.2, test=0.1):
    """
    Deprecated, not currently in use.

    Creates the csv for the dataset.

    :param dataset_dir: Directory of the dataset.
    :param csv_name: Name of the output csv.
    :param frac: Fraction of images to load. Default 1.
    :param validate: Fraction of images to label as VAL. Default 0.2.
    :param test: Fraction of images to label as TEST. Default 0.1.
    :return: Number of loaded images.
    """
    def load_dataset_images(dataset_dir, frac):
        """
        Returns a shuffled DataFrame with the loaded dataset images.

        :param dataset_dir: Directory of the dataset.
        :param frac: Fraction of images to load.
        """
        labelled_images = []
        for piece_type in PIECE_TYPES:
            piece_dir = dataset_dir + piece_type
            for image in os.listdir(piece_dir):
                labelled_images.append((piece_dir + '/' + str(image),
                                        piece_type))

        data_frame = pd.DataFrame(data=labelled_images)
        data_frame = data_frame.rename(columns={0: "image_name", 1: "label"})
        # Shuffle rows
        return data_frame.sample(frac=frac).reset_index(drop=True)

    data_frame = load_dataset_images(dataset_dir, frac)
    total_rows = len(data_frame.index)

    # Row order after shuffling decides the split: the earliest rows are
    # TRAIN, then VAL, and the final `test` fraction is TEST.
    start_test = 1.0 - test
    start_validate = start_test - validate

    with open(dataset_dir + csv_name, 'w', newline='',
              encoding='utf-8') as csvfile:
        csvwriter = csv.writer(csvfile, delimiter=',', quotechar='|',
                               quoting=csv.QUOTE_MINIMAL)
        for i, row in data_frame.iterrows():
            fraction_done = i / total_rows
            if fraction_done >= start_test:
                set_str = 'TEST'
            elif fraction_done >= start_validate:
                set_str = 'VAL'
            else:
                set_str = 'TRAIN'

            filename, label = row
            csvwriter.writerow([set_str, filename,
                                PIECES_TO_CLASSNUM[label]])
    return total_rows
def randomize_dataset(dataset_dir):
    """
    Randomizes the order of the images in the subdirectories of
    dataset_dir. Renames them to <number>.jpg.

    Renaming is done in two phases (through unique temporary names)
    because a direct rename to str(i) + ".jpg" can overwrite a file
    that has not been processed yet when the files are already named
    0.jpg, 1.jpg, ... -- on POSIX systems that silently destroys the
    overwritten image.

    :param dataset_dir: Directory of the dataset.
    """
    subdirs = [d for d in os.listdir(dataset_dir)
               if os.path.isdir(os.path.join(dataset_dir, d))]
    for subdir in subdirs:
        files = os.listdir(os.path.join(dataset_dir, subdir))
        shuffle(files)

        # Phase 1: move every file to a collision-free temporary name.
        tmp_paths = []
        for i, file in enumerate(files):
            path = os.path.join(dataset_dir, subdir, file)
            if os.path.isfile(path):
                tmp_path = os.path.join(dataset_dir, subdir,
                                        "__randomize_tmp_" + str(i))
                os.rename(path, tmp_path)
                tmp_paths.append(tmp_path)

        # Phase 2: assign the final sequential names.
        for i, tmp_path in enumerate(tmp_paths):
            os.rename(tmp_path,
                      os.path.join(dataset_dir, subdir, str(i) + ".jpg"))
def split_dataset(dataset_dir, train_dir, validation_dir, train_perc=0.8):
    """
    Splits dataset_dir into train_dir and validation_dir given
    train_perc.

    Both output directories are wiped and recreated with one
    subdirectory per piece class; the first train_perc of each class's
    files is copied into train_dir and the remainder into
    validation_dir.

    :param dataset_dir: Directory of the whole dataset.
    :param train_dir: Train directory.
    :param validation_dir: Validation directory.
    :param train_perc: Percentage of training images. Default 0.8.
    """
    # One subdirectory per classifier class ('_' is the empty square).
    piece_classes = ('_', 'r', 'n', 'b', 'q', 'k', 'p',
                     'R', 'N', 'B', 'Q', 'K', 'P')

    for output_dir in (train_dir, validation_dir):
        # ignore_errors lets the first run work when the output
        # directory does not exist yet.
        shutil.rmtree(output_dir, ignore_errors=True)
        os.mkdir(output_dir)
        for piece in piece_classes:
            # exist_ok guards against case-insensitive filesystems
            # where e.g. 'r' and 'R' map to the same directory.
            os.makedirs(os.path.join(output_dir, piece), exist_ok=True)

    subdirs = [d for d in os.listdir(dataset_dir)
               if os.path.isdir(os.path.join(dataset_dir, d))]
    for subdir in subdirs:
        files = os.listdir(os.path.join(dataset_dir, subdir))
        num_train_files = len(files) * train_perc
        for i, file in enumerate(files):
            path = os.path.join(dataset_dir, subdir, file)
            if os.path.isfile(path):
                if i < num_train_files:
                    newpath = os.path.join(train_dir, subdir, file)
                else:
                    newpath = os.path.join(validation_dir, subdir, file)
                shutil.copy(path, newpath)

0 commit comments

Comments
 (0)