Commit 361a2a8

add alpha repo for easier db management

1 parent: 644cc8b
28 files changed: +2160 -0 lines changed

.gitignore

Lines changed: 14 additions & 0 deletions
@@ -8,3 +8,17 @@ puzzles.csv
 secrets.json*
 static/
 .ipynb_checkpoints
+
+*.pyc
+alphago/.DS_Store
+
+alphago/temp/
+alphago/.project
+alphago/.pydevproject
+
+alphago/checkpoints/
+alphago/# For PyCharm users
+alphago/.idea/
+alphago/*.swp
+alphago/puzzles.csv
+alphago/results/

alpha/.gitignore

Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
+*.pyc
+.DS_Store
+
+/temp/
+/.project
+/.pydevproject
+
+# checkpoint
+checkpoints/
+
+# For PyCharm users
+.idea/
+
+*.swp
+puzzles.csv
+results/

alpha/Arena.py

Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
import numpy as np
from pytorch_classification.utils import Bar, AverageMeter
import time

class Arena():
    """
    An Arena class where any 2 agents can be pit against each other.
    """
    def __init__(self, player1, player2, game, display=None):
        """
        Input:
            player 1,2: two functions that take a board as input and return an action
            game: Game object
            display: a function that takes a board as input and prints it (e.g.
                     display in othello/OthelloGame). Necessary for verbose
                     mode.

        See othello/OthelloPlayers.py for an example. See pit.py for pitting
        human players/other baselines against each other.
        """
        self.player1 = player1
        self.player2 = player2
        self.game = game
        self.display = display

    def playGame(self, verbose=False):
        """
        Executes one episode of a game.

        Returns:
            either
                winner: player who won the game (1 if player1, -1 if player2)
            or
                draw result returned from the game that is neither 1, -1, nor 0.
        """
        players = [self.player1]
        curPlayer = 1
        board, vis_state = self.game.getInitBoard()
        it = 0
        game_ended = self.game.getGameEnded(board, curPlayer)
        while game_ended == 0:
            it += 1
            if verbose:
                assert(self.display)
                print("Turn ", str(it), "Player ", str(curPlayer))
                self.display(board)
            action = players[0](self.game.getCanonicalForm(board, curPlayer))

            valids = self.game.getValidMoves(self.game.getCanonicalForm(board, curPlayer), 1)

            if valids[action] == 0:
                print(action)
                assert valids[action] > 0
            board, curPlayer, vis_state = self.game.getNextState(board, curPlayer, action, vis_state)
            game_ended = self.game.getGameEnded(board, curPlayer)

        print(f'Game ended score {game_ended}')
        print('Board')
        print(vis_state)

        if verbose:
            assert(self.display)
            print("Game over: Turn ", str(it), "Result ", str(self.game.getGameEnded(board, 1)))
            self.display(board)
        return self.game.getGameEnded(board, 1)

    def playGames(self, num, verbose=False):
        """
        Plays num games in which player1 starts num/2 games and player2 starts
        num/2 games.

        Returns:
            oneWon: games won by player1
            twoWon: games won by player2
            draws:  games won by nobody
        """
        eps_time = AverageMeter()
        bar = Bar('Arena.playGames', max=num)
        end = time.time()
        eps = 0
        maxeps = int(num)

        num = int(num/2)
        oneWon = 0
        twoWon = 0
        draws = 0
        for _ in range(num):
            gameResult = self.playGame(verbose=verbose)
            if gameResult == 1:
                oneWon += 1
            elif gameResult == -1:
                twoWon += 1
            else:
                draws += 1
            # bookkeeping + plot progress
            eps += 1
            eps_time.update(time.time() - end)
            end = time.time()
            bar.suffix = '({eps}/{maxeps}) Eps Time: {et:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                eps=eps, maxeps=maxeps, et=eps_time.avg,
                total=bar.elapsed_td, eta=bar.eta_td)
            bar.next()

        self.player1, self.player2 = self.player2, self.player1

        for _ in range(num):
            gameResult = self.playGame(verbose=verbose)
            if gameResult == -1:
                oneWon += 1
            elif gameResult == 1:
                twoWon += 1
            else:
                draws += 1
            # bookkeeping + plot progress
            eps += 1
            eps_time.update(time.time() - end)
            end = time.time()
            bar.suffix = '({eps}/{maxeps}) Eps Time: {et:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                eps=eps, maxeps=maxeps, et=eps_time.avg,
                total=bar.elapsed_td, eta=bar.eta_td)
            bar.next()

        bar.finish()

        return oneWon, twoWon, draws
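
Usage note: below is a minimal sketch of how this Arena might be driven, in the spirit of the pit.py mentioned in the docstring. PuzzleGame, NNetWrapper, and args are hypothetical placeholders (not part of this commit); only the Arena and MCTS interfaces shown above are assumed.

import numpy as np
from Arena import Arena
from MCTS import MCTS

# Hypothetical placeholders -- not part of this commit.
game = PuzzleGame()              # any Game implementation exposing the methods Arena calls
nnet = NNetWrapper(game)         # a NeuralNet wrapper usable by MCTS
mcts = MCTS(game, nnet, args)    # args as configured in main.py

# A player is just a function: canonical board in, action index out.
mcts_player = lambda board: np.argmax(mcts.getActionProb(board, temp=0))
random_player = lambda board: np.random.choice(np.flatnonzero(game.getValidMoves(board, 1)))

arena = Arena(mcts_player, random_player, game)
oneWon, twoWon, draws = arena.playGames(20)
print('player1 / player2 / draws:', oneWon, twoWon, draws)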

alpha/Coach.py

Lines changed: 162 additions & 0 deletions
@@ -0,0 +1,162 @@
from collections import deque
from Arena import Arena
from MCTS import MCTS
import numpy as np
from pytorch_classification.utils import Bar, AverageMeter
import time, os, sys
from pickle import Pickler, Unpickler
from random import shuffle


class Coach():
    """
    This class executes the self-play + learning. It uses the functions defined
    in Game and NeuralNet. args are specified in main.py.
    """
    def __init__(self, game, nnet, args):
        self.game = game
        self.nnet = nnet
        self.pnet = self.nnet.__class__(self.game)  # the competitor network
        self.args = args
        self.mcts = MCTS(self.game, self.nnet, self.args)
        self.trainExamplesHistory = []    # history of examples from args.numItersForTrainExamplesHistory latest iterations
        self.skipFirstSelfPlay = False    # can be overridden in loadTrainExamples()

    def executeEpisode(self):
        """
        This function executes one episode of self-play, starting with player 1.
        As the game is played, each turn is added as a training example to
        trainExamples. The game is played until it ends. After the game
        ends, the outcome of the game is used to assign values to each example
        in trainExamples.

        It uses temp=1 if episodeStep < tempThreshold, and thereafter
        uses temp=0.

        Returns:
            trainExamples: a list of examples of the form (canonicalBoard, pi, v)
                           pi is the MCTS-informed policy vector, v is +1 if
                           the player eventually won the game, else -1.
        """
        trainExamples = []
        board, vis_state = self.game.getInitBoard()
        self.curPlayer = 1
        episodeStep = 0

        while True:
            episodeStep += 1
            canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
            temp = int(episodeStep < self.args.tempThreshold)

            pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
            sym = self.game.getSymmetries(canonicalBoard, pi)
            for b, p in sym:
                trainExamples.append([b[0], self.curPlayer, p, None])

            action = np.random.choice(len(pi), p=pi)
            board, self.curPlayer, vis_state = self.game.getNextState(board, self.curPlayer, action)

            r = self.game.getGameEnded(board, self.curPlayer)

            if r != 0:
                # return [(x[0],x[2],r*((-1)**(x[1]!=self.curPlayer))) for x in trainExamples]
                return [(x[0], x[2], r) for x in trainExamples]

    def learn(self):
        """
        Performs numIters iterations with numEps episodes of self-play in each
        iteration. After every iteration, it retrains the neural network with
        the examples in trainExamples (which has a maximum length of maxlenOfQueue).
        It then pits the new neural network against the old one and accepts it
        only if it wins >= updateThreshold fraction of games.
        """

        for i in range(1, self.args.numIters + 1):
            # bookkeeping
            print('------ITER ' + str(i) + '------')
            # examples of the iteration
            if not self.skipFirstSelfPlay or i > 1:
                iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

                eps_time = AverageMeter()
                bar = Bar('Self Play', max=self.args.numEps)
                end = time.time()

                for eps in range(self.args.numEps):
                    self.mcts = MCTS(self.game, self.nnet, self.args)   # reset search tree
                    iterationTrainExamples += self.executeEpisode()

                    # bookkeeping + plot progress
                    eps_time.update(time.time() - end)
                    end = time.time()
                    bar.suffix = '({eps}/{maxeps}) Eps Time: {et:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                        eps=eps + 1, maxeps=self.args.numEps, et=eps_time.avg,
                        total=bar.elapsed_td, eta=bar.eta_td)
                    bar.next()
                bar.finish()

                # save the iteration examples to the history
                self.trainExamplesHistory.append(iterationTrainExamples)

            if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
                print("len(trainExamplesHistory) =", len(self.trainExamplesHistory), " => remove the oldest trainExamples")
                self.trainExamplesHistory.pop(0)
            # backup history to a file
            # NB! the examples were collected using the model from the previous iteration, so (i-1)
            self.saveTrainExamples(i - 1)

            # shuffle examples before training
            trainExamples = []
            for e in self.trainExamplesHistory:
                trainExamples.extend(e)
            shuffle(trainExamples)

            # training new network, keeping a copy of the old one
            self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            self.pnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            pmcts = MCTS(self.game, self.pnet, self.args)

            self.nnet.train(trainExamples)
            nmcts = MCTS(self.game, self.nnet, self.args)

            print('PITTING AGAINST PREVIOUS VERSION')
            arena = Arena(lambda x: np.argmax(pmcts.getActionProb(x, temp=0)),
                          lambda x: np.argmax(nmcts.getActionProb(x, temp=0)), self.game)
            pwins, nwins, draws = arena.playGames(self.args.arenaCompare)

            print('NEW/PREV WINS : %d / %d ; DRAWS : %d' % (nwins, pwins, draws))
            # if pwins+nwins == 0 or float(nwins)/(pwins+nwins) < self.args.updateThreshold:
            if False:
                print('REJECTING NEW MODEL')
                self.nnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            else:
                print('ACCEPTING NEW MODEL')
                self.nnet.save_checkpoint(folder=self.args.checkpoint, filename=self.getCheckpointFile(i))
                self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='best.pth.tar')

    def getCheckpointFile(self, iteration):
        return 'checkpoint_' + str(iteration) + '.pth.tar'

    def saveTrainExamples(self, iteration):
        folder = self.args.checkpoint
        if not os.path.exists(folder):
            os.makedirs(folder)
        filename = os.path.join(folder, self.getCheckpointFile(iteration) + ".examples")
        with open(filename, "wb+") as f:
            Pickler(f).dump(self.trainExamplesHistory)
        # file is closed automatically by the with-block

    def loadTrainExamples(self):
        modelFile = os.path.join(self.args.load_folder_file[0], self.args.load_folder_file[1])
        examplesFile = modelFile + ".examples"
        if not os.path.isfile(examplesFile):
            print(examplesFile)
            r = input("File with trainExamples not found. Continue? [y|n]")
            if r != "y":
                sys.exit()
        else:
            print("File with trainExamples found. Read it.")
            with open(examplesFile, "rb") as f:
                self.trainExamplesHistory = Unpickler(f).load()
            # file is closed automatically by the with-block

            # examples based on the model were already collected (loaded)
            self.skipFirstSelfPlay = True
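
Usage note: below is a minimal sketch of a main.py-style driver for this Coach. It assumes only the args fields referenced above (numIters, numEps, tempThreshold, updateThreshold, maxlenOfQueue, arenaCompare, checkpoint, load_folder_file, numItersForTrainExamplesHistory); numMCTSSims and cpuct are assumed MCTS parameters, and PuzzleGame / NNetWrapper are hypothetical placeholders, not part of this commit.

from Coach import Coach

class dotdict(dict):
    # attribute-style access for args, as commonly used in alpha-zero-style repos
    def __getattr__(self, name):
        return self[name]

args = dotdict({
    'numIters': 10,
    'numEps': 25,              # self-play episodes per iteration
    'tempThreshold': 15,       # switch from temp=1 to temp=0 after this many moves
    'updateThreshold': 0.6,    # win fraction to accept the new net (currently bypassed by `if False:` above)
    'maxlenOfQueue': 200000,
    'numMCTSSims': 25,         # assumed MCTS parameter
    'arenaCompare': 40,
    'cpuct': 1,                # assumed MCTS parameter
    'checkpoint': './checkpoints/',
    'load_folder_file': ('./checkpoints/', 'best.pth.tar'),
    'numItersForTrainExamplesHistory': 20,
})

game = PuzzleGame()            # hypothetical Game implementation
nnet = NNetWrapper(game)       # hypothetical NeuralNet wrapper with train/save_checkpoint/load_checkpoint

coach = Coach(game, nnet, args)
# coach.loadTrainExamples()    # optionally resume from saved examples
coach.learn()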
