from collections import deque
from Arena import Arena
from MCTS import MCTS
import numpy as np
from pytorch_classification.utils import Bar, AverageMeter
import time, os, sys
from pickle import Pickler, Unpickler
from random import shuffle


class Coach():
    """
    This class executes the self-play + learning. It uses the functions defined
    in Game and NeuralNet. args are specified in main.py.
    """
    def __init__(self, game, nnet, args):
        self.game = game
        self.nnet = nnet
        self.pnet = self.nnet.__class__(self.game)  # the competitor network
        self.args = args
        self.mcts = MCTS(self.game, self.nnet, self.args)
        self.trainExamplesHistory = []    # history of examples from args.numItersForTrainExamplesHistory latest iterations
        self.skipFirstSelfPlay = False    # can be overridden in loadTrainExamples()
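        # The args object (specified in main.py) is expected to provide at least the
        # fields referenced in this class: numIters, numEps, tempThreshold,
        # maxlenOfQueue, numItersForTrainExamplesHistory, arenaCompare, checkpoint,
        # and load_folder_file.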

    def executeEpisode(self):
        """
        This function executes one episode of self-play, starting with player 1.
        As the game is played, each turn is added as a training example to
        trainExamples. The game is played until it ends; the outcome is then used
        to assign values to each example in trainExamples.

        It uses temp=1 if episodeStep < tempThreshold, and thereafter uses temp=0.

        Returns:
            trainExamples: a list of examples of the form (canonicalBoard, pi, v)
                           where pi is the MCTS-informed policy vector and v is +1
                           if the player eventually won the game, else -1.
        """
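        # Intermediate examples are collected as [board, curPlayer, pi, None]; the
        # value slot only becomes known (and is attached in the returned tuples)
        # once the game result r is available at the end of the episode.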
        trainExamples = []
        board, vis_state = self.game.getInitBoard()
        self.curPlayer = 1
        episodeStep = 0

        while True:
            episodeStep += 1
            canonicalBoard = self.game.getCanonicalForm(board, self.curPlayer)
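            # temperature schedule: sample moves in proportion to MCTS visit counts
            # (temp=1) while episodeStep < tempThreshold, then play greedily (temp=0)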
            temp = int(episodeStep < self.args.tempThreshold)

            pi = self.mcts.getActionProb(canonicalBoard, temp=temp)
            sym = self.game.getSymmetries(canonicalBoard, pi)
            for b, p in sym:
                trainExamples.append([b[0], self.curPlayer, p, None])

            action = np.random.choice(len(pi), p=pi)
            board, self.curPlayer, vis_state = self.game.getNextState(board, self.curPlayer, action)

            r = self.game.getGameEnded(board, self.curPlayer)

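            # A non-zero result from getGameEnded means the episode is over. The
            # perspective-flipping value assignment is left commented out below; the
            # active return stores the raw result r for every collected example.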
            if r != 0:
                # return [(x[0], x[2], r*((-1)**(x[1] != self.curPlayer))) for x in trainExamples]
                return [(x[0], x[2], r) for x in trainExamples]

    def learn(self):
        """
        Performs numIters iterations with numEps episodes of self-play in each
        iteration. After every iteration, it retrains the neural network with
        examples in trainExamples (which has a maximum length of maxlenOfQueue).
        It then pits the new neural network against the old one and accepts it
        only if it wins >= updateThreshold fraction of games.
        """

        for i in range(1, self.args.numIters + 1):
            # bookkeeping
            print('------ITER ' + str(i) + '------')
            # examples of the iteration
            if not self.skipFirstSelfPlay or i > 1:
                iterationTrainExamples = deque([], maxlen=self.args.maxlenOfQueue)

                eps_time = AverageMeter()
                bar = Bar('Self Play', max=self.args.numEps)
                end = time.time()

                for eps in range(self.args.numEps):
                    self.mcts = MCTS(self.game, self.nnet, self.args)   # reset search tree
                    iterationTrainExamples += self.executeEpisode()

                    # bookkeeping + plot progress
                    eps_time.update(time.time() - end)
                    end = time.time()
                    bar.suffix = '({eps}/{maxeps}) Eps Time: {et:.3f}s | Total: {total:} | ETA: {eta:}'.format(
                        eps=eps + 1, maxeps=self.args.numEps, et=eps_time.avg,
                        total=bar.elapsed_td, eta=bar.eta_td)
                    bar.next()
                bar.finish()

                # save the iteration examples to the history
                self.trainExamplesHistory.append(iterationTrainExamples)

            if len(self.trainExamplesHistory) > self.args.numItersForTrainExamplesHistory:
                print("len(trainExamplesHistory) =", len(self.trainExamplesHistory), " => remove the oldest trainExamples")
                self.trainExamplesHistory.pop(0)
            # backup history to a file
            # NB! the examples were collected using the model from the previous iteration, so (i-1)
            self.saveTrainExamples(i - 1)

            # shuffle examples before training
            trainExamples = []
            for e in self.trainExamplesHistory:
                trainExamples.extend(e)
            shuffle(trainExamples)

            # training new network, keeping a copy of the old one
            self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            self.pnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            pmcts = MCTS(self.game, self.pnet, self.args)

            self.nnet.train(trainExamples)
            nmcts = MCTS(self.game, self.nnet, self.args)

            print('PITTING AGAINST PREVIOUS VERSION')
            arena = Arena(lambda x: np.argmax(pmcts.getActionProb(x, temp=0)),
                          lambda x: np.argmax(nmcts.getActionProb(x, temp=0)), self.game)
            pwins, nwins, draws = arena.playGames(self.args.arenaCompare)

            print('NEW/PREV WINS : %d / %d ; DRAWS : %d' % (nwins, pwins, draws))
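            # NOTE: the updateThreshold acceptance test below is disabled (if False),
            # so the newly trained network is always accepted; restore the commented-out
            # condition to reject models that do not win often enough in the arena.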
            # if pwins + nwins == 0 or float(nwins) / (pwins + nwins) < self.args.updateThreshold:
            if False:
                print('REJECTING NEW MODEL')
                self.nnet.load_checkpoint(folder=self.args.checkpoint, filename='temp.pth.tar')
            else:
                print('ACCEPTING NEW MODEL')
                self.nnet.save_checkpoint(folder=self.args.checkpoint, filename=self.getCheckpointFile(i))
                self.nnet.save_checkpoint(folder=self.args.checkpoint, filename='best.pth.tar')

    def getCheckpointFile(self, iteration):
        return 'checkpoint_' + str(iteration) + '.pth.tar'

    def saveTrainExamples(self, iteration):
        folder = self.args.checkpoint
        if not os.path.exists(folder):
            os.makedirs(folder)
        filename = os.path.join(folder, self.getCheckpointFile(iteration) + ".examples")
        with open(filename, "wb+") as f:
            Pickler(f).dump(self.trainExamplesHistory)

    def loadTrainExamples(self):
        modelFile = os.path.join(self.args.load_folder_file[0], self.args.load_folder_file[1])
        examplesFile = modelFile + ".examples"
        if not os.path.isfile(examplesFile):
            print(examplesFile)
            r = input("File with trainExamples not found. Continue? [y|n]")
            if r != "y":
                sys.exit()
        else:
            print("File with trainExamples found. Reading it.")
            with open(examplesFile, "rb") as f:
                self.trainExamplesHistory = Unpickler(f).load()
            # examples based on the model were already collected (loaded)
            self.skipFirstSelfPlay = True
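
# Typical driver flow (a minimal sketch, not part of this module): the concrete Game
# and NeuralNet implementations and the args object come from main.py; MyGame and
# NNetWrapper below are placeholder names, substitute the implementations main.py uses.
#
#     game = MyGame()
#     nnet = NNetWrapper(game)
#     coach = Coach(game, nnet, args)
#     # coach.loadTrainExamples()   # only when resuming from previously saved examples
#     coach.learn()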