-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathtrain.py
More file actions
31 lines (27 loc) · 722 Bytes
/
train.py
File metadata and controls
31 lines (27 loc) · 722 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from chess import Chess
from agents import SingleAgentChess, DoubleAgentsChess
from learnings.ppo import PPO
# Passed to DoubleAgentsChess as ``train_on``; the PPO learner's own buffer is
# sized at twice this value (see the note inside ``main``).
buffer_size = 32


def main() -> None:
    """Build the Chess environment and PPO learner, train both agents, then save.

    Creates a ``Chess`` environment rendered as RGB arrays, a PPO learner on
    top of it, and a ``DoubleAgentsChess`` driver that trains for 2000
    episodes, writing results under ``results/DoubleAgents``.

    The environment's ``close()`` is called in a ``finally`` clause, so it
    still runs if construction, training or saving raises (including a
    ``KeyboardInterrupt`` from Ctrl-C during a long training run).
    """
    chess = Chess(window_size=512, max_steps=128, render_mode="rgb_array")
    try:
        chess.reset()
        ppo = PPO(
            chess,
            hidden_layers=(2048,) * 4,
            epochs=100,
            # NOTE(review): twice ``train_on`` — presumably because both
            # agents feed the same learner; confirm against DoubleAgentsChess.
            buffer_size=buffer_size * 2,
            batch_size=128,
        )
        print(ppo.device)
        print(ppo)
        print("-" * 64)
        agent = DoubleAgentsChess(
            env=chess,
            learner=ppo,
            episodes=2000,
            train_on=buffer_size,
            result_folder="results/DoubleAgents",
        )
        agent.train(render_each=20, save_on_learn=True)
        agent.save()
    finally:
        # Release the environment even when training fails or is interrupted;
        # previously close() was skipped on any exception.
        chess.close()


if __name__ == "__main__":
    main()