-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathSelfPlay.py
More file actions
78 lines (62 loc) · 2.38 KB
/
SelfPlay.py
File metadata and controls
78 lines (62 loc) · 2.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import pickle
from datetime import datetime
import numpy as np
from keras.models import load_model
from tqdm import tqdm
import features
from go import Position
from MCTS import get_action_prob, \
get_legal_actions, \
flat_to_coord
from NNet import NETWORK_OUTPUT_SIZE
from keras.backend import clear_session
from Config import SELFPLAY_GAMES, \
SELFPLAY_TEMPERATURE_THRESHOLD, \
SELFPLAY_TEMPERATURE_EARLY, \
SELFPLAY_TEMPERATURE_TERMINAL
class SelfPlay:
    """Generate training samples by letting the stored model play against itself.

    Every sample is a three-element list ``[input_features, pi, z]``:

    * ``input_features`` -- output of ``features.extract_features`` for the
      position before the move was played,
    * ``pi`` -- list of length ``NETWORK_OUTPUT_SIZE`` holding the search
      score of each legal action (0 for every other index),
    * ``z`` -- outcome of the game for the player who made that move
      (1 win, -1 loss, 0 draw).
    """

    def __init__(self):
        # Network used for the search; only loaded while generate_data() runs.
        self.model = None
        # Samples produced by the most recent generate_data() call.
        self.history = []

    @staticmethod
    def _outcome_for(result, move_index):
        """Return the value target for the player who made move ``move_index``.

        ``result`` is the value of ``state.result()`` at the end of the game.
        Moves with an even index are credited with the win when ``result`` is
        1, moves with an odd index when it is -1.  A ``result`` of 0 is
        treated as a draw and yields 0 for both sides instead of counting as
        a loss for everybody.
        NOTE(review): presumably ``result()`` returns 1 / -1 / 0 for first
        player win / second player win / draw -- confirm against ``go.py``.
        """
        if result == 0:
            return 0
        even_move = move_index % 2 == 0
        if (result == 1 and even_move) or (result == -1 and not even_move):
            return 1
        return -1

    def _play_game(self):
        """Play one game from a fresh ``Position`` and return its samples.

        Returns:
            A list of ``[input_features, pi, z]`` entries, one per move, in
            the order the moves were played.
        """
        game_history = []
        state: Position = Position()
        number_of_moves = 0

        while not state.is_game_over():
            # Early moves use one temperature, the rest of the game another.
            if number_of_moves < SELFPLAY_TEMPERATURE_THRESHOLD:
                temperature = SELFPLAY_TEMPERATURE_EARLY
            else:
                temperature = SELFPLAY_TEMPERATURE_TERMINAL
            scores = get_action_prob(self.model, state, temperature)

            # Computed once per move; it is needed both for building pi and
            # for mapping the chosen score index back to an action.
            # Assumes scores[k] belongs to legal_actions[k] -- TODO confirm
            # against MCTS.get_action_prob.
            legal_actions = get_legal_actions(state.all_legal_moves())

            pi = [0] * NETWORK_OUTPUT_SIZE
            for action, score in zip(legal_actions, scores):
                pi[action] = score

            input_features = features.extract_features(state, features.AGZ_FEATURES)
            # The outcome is unknown until the game ends; filled in below.
            game_history.append([input_features, pi, None])

            # NOTE(review): the highest-scoring action is always played, so
            # the temperature only matters through the scores themselves.
            flat = legal_actions[np.argmax(scores)]
            state = state.play_move(flat_to_coord(flat))
            number_of_moves += 1

        v = state.result()
        print('')
        for move_index, sample in enumerate(game_history):
            sample[2] = self._outcome_for(v, move_index)
        return game_history

    def _save_history(self):
        """Pickle ``self.history`` to a timestamp-named file under ``./data``.

        The file keeps the ``.npy`` suffix but its content is written with
        ``pickle.dump``, so it has to be read back with ``pickle.load``.
        """
        import os

        # Create the target directory up front so a missing folder does not
        # throw away a finished self-play run.
        os.makedirs("./data", exist_ok=True)
        path = f"./data/{datetime.now().timestamp()}.npy"
        with open(path, mode='wb') as f:
            pickle.dump(self.history, f)

    def generate_data(self):
        """Play ``SELFPLAY_GAMES`` games, save the samples and return them.

        Loads ``./model/best.h5``, collects the samples of every game in
        ``self.history``, writes them to disk and finally releases the model.

        Returns:
            The list of samples gathered during this call.
        """
        self.history = []
        self.model = load_model('./model/best.h5')
        try:
            for _ in tqdm(range(SELFPLAY_GAMES), desc="Self play"):
                self.history.extend(self._play_game())
            print('')
            self._save_history()
        finally:
            # Release the Keras session even when a game fails, and keep the
            # attribute defined (as in __init__) rather than deleting it.
            clear_session()
            self.model = None
        return self.history