|
| 1 | +from squadro import MonteCarloDeepQLearningAgent |
| 2 | +from squadro import logger |
| 3 | +from squadro.agents.montecarlo_agent import MonteCarloRolloutAgent, MonteCarloAdvancementAgent |
1 | 4 | from squadro.core.game import Game |
2 | 5 | from squadro.state.evaluators.base import Evaluator |
3 | | -from squadro.tools.constants import DATA_PATH |
| 6 | +from squadro.tools.basic import PrettyDict |
| 7 | +from squadro.tools.constants import DATA_PATH, DefaultParams |
4 | 8 | from squadro.tools.dates import get_now |
5 | 9 | from squadro.tools.disk import mkdir |
6 | 10 | from squadro.tools.logs import benchmark_logger as logger |
@@ -80,3 +84,47 @@ def benchmark(*args, **kwargs) -> float: |
80 | 84 | Benchmark evaluation between two agents. |
81 | 85 | """ |
82 | 86 | return Benchmark(*args, **kwargs).run() |
| 87 | + |
| 88 | + |
def benchmark_agents(names, n_games=100, n_pawns=5, max_time_per_move=3, results=None):
    """
    Run a round-robin benchmark between every pair of named agents.

    Args:
        names: iterable of agent names. The three 'mcts_*' names are
            instantiated as Monte-Carlo agents in training mode; any other
            name is passed through to `benchmark` as-is (string agent id).
        n_games: number of games played per pairing.
        n_pawns: number of pawns per game.
        max_time_per_move: per-move time budget (seconds), applied globally
            via `DefaultParams`.
        results: optional mapping of previously collected results
            (``{name_i: {name_j: score}}``). Pairings already present are
            skipped, and the mapping is updated in place, so a partially
            completed run can be resumed.

    Returns:
        dict: ``{name_i: {name_j: score}}`` for every pair with i < j in
        the order of `names`.
    """
    if logger.client is None:
        logger.setup(section=['benchmark', 'main'])

    DefaultParams.max_time_per_move = max_time_per_move
    mcts_agent_kwargs = dict(
        is_training=True,
        mcts_kwargs=PrettyDict(tau=.5, p_mix=0, a_dirichlet=0),
    )

    # Names with a dedicated Monte-Carlo implementation get a configured
    # agent instance; every other name is handed to `benchmark` as a string.
    mcts_classes = {
        'mcts_rollout': MonteCarloRolloutAgent,
        'mcts_advancement': MonteCarloAdvancementAgent,
        'mcts_deep_q_learning': MonteCarloDeepQLearningAgent,
    }
    agents = {
        name: mcts_classes[name](**mcts_agent_kwargs) if name in mcts_classes else name
        for name in names
    }

    names = list(agents)
    logger.info(names)

    # `results or {}` would discard a caller-supplied *empty* dict; test for
    # None explicitly so resumed runs keep mutating the caller's mapping.
    results = {} if results is None else results
    for i, name_i in enumerate(names):
        results.setdefault(name_i, {})
        for name_j in names[i + 1:]:
            if results[name_i].get(name_j) is not None:
                continue  # pairing already benchmarked in a previous run
            logger.info(f"{name_i} vs {name_j}")
            results[name_i][name_j] = result = benchmark(
                agent_0=agents[name_i],
                agent_1=agents[name_j],
                n_pawns=n_pawns,
                n_games=n_games,
            )
            logger.info(f"{name_i} vs {name_j}: {result}\n")

    return results
0 commit comments