Skip to content

Commit ec6edd4

Browse files
committed
Add benchmark info
1 parent 88edcf5 commit ec6edd4

File tree

1 file changed

+49
-1
lines changed

1 file changed

+49
-1
lines changed

squadro/core/benchmarking.py

Lines changed: 49 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1+
from squadro import MonteCarloDeepQLearningAgent
2+
from squadro import logger
3+
from squadro.agents.montecarlo_agent import MonteCarloRolloutAgent, MonteCarloAdvancementAgent
14
from squadro.core.game import Game
25
from squadro.state.evaluators.base import Evaluator
3-
from squadro.tools.constants import DATA_PATH
6+
from squadro.tools.basic import PrettyDict
7+
from squadro.tools.constants import DATA_PATH, DefaultParams
48
from squadro.tools.dates import get_now
59
from squadro.tools.disk import mkdir
610
from squadro.tools.logs import benchmark_logger as logger
@@ -80,3 +84,47 @@ def benchmark(*args, **kwargs) -> float:
8084
Benchmark evaluation between two agents.
8185
"""
8286
return Benchmark(*args, **kwargs).run()
87+
88+
89+
def benchmark_agents(names, n_games=100, n_pawns=5, max_time_per_move=3, results=None):
90+
if logger.client is None:
91+
logger.setup(section=['benchmark', 'main'])
92+
93+
DefaultParams.max_time_per_move = max_time_per_move
94+
mcts_agent_kwargs = dict(
95+
is_training=True,
96+
mcts_kwargs=PrettyDict(tau=.5, p_mix=0, a_dirichlet=0),
97+
)
98+
99+
agents = {}
100+
for name in names:
101+
agents[name] = name
102+
if 'mcts_rollout' in names:
103+
agents['mcts_rollout'] = MonteCarloRolloutAgent(**mcts_agent_kwargs)
104+
if 'mcts_advancement' in names:
105+
agents['mcts_advancement'] = MonteCarloAdvancementAgent(**mcts_agent_kwargs)
106+
if 'mcts_deep_q_learning' in names:
107+
agents['mcts_deep_q_learning'] = MonteCarloDeepQLearningAgent(**mcts_agent_kwargs)
108+
109+
names = list(agents.keys())
110+
logger.info(names)
111+
112+
results = results or {}
113+
for i in range(len(agents)):
114+
name_i = names[i]
115+
if name_i not in results:
116+
results[name_i] = {}
117+
for j in range(i + 1, len(agents)):
118+
name_j = names[j]
119+
if results[name_i].get(name_j) is not None:
120+
continue
121+
logger.info(f"{name_i} vs {name_j}")
122+
results[name_i][name_j] = result = benchmark(
123+
agent_0=agents[name_i],
124+
agent_1=agents[name_j],
125+
n_pawns=n_pawns,
126+
n_games=n_games,
127+
)
128+
logger.info(f"{name_i} vs {name_j}: {result}\n")
129+
130+
return results

0 commit comments

Comments
 (0)