Skip to content

Commit 73a86f5

Browse files
committed
Expose device arg
1 parent 822d64a commit 73a86f5

File tree

2 files changed

+5
-1
lines changed

2 files changed

+5
-1
lines changed

squadro/agents/montecarlo_agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,7 @@ def __init__(
470470
self,
471471
model_path: str = None,
472472
model_config=None,
473+
device: str = None,
473474
mcts_kwargs: Optional[dict] = None,
474475
**kwargs,
475476
):
@@ -478,7 +479,8 @@ def __init__(
478479
if 'evaluator' not in kwargs:
479480
kwargs['evaluator'] = DeepQLearningEvaluator(
480481
model_path=model_path,
481-
model_config=model_config
482+
model_config=model_config,
483+
device=device,
482484
)
483485
super().__init__(mcts_kwargs=mcts_kwargs, **kwargs)
484486

squadro/training/deep_q_learning.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -815,6 +815,7 @@ def __init__(
815815
adaptive_sampling=True,
816816
freeze_backprop=None,
817817
display_plot=True,
818+
device=None,
818819
):
819820
"""
820821
Represents a class for initializing and training a deep reinforcement learning agent
@@ -854,6 +855,7 @@ def __init__(
854855
self.agent_kwargs = dict(
855856
mcts_kwargs=self.mcts_kwargs,
856857
max_time_per_move=inf,
858+
device=device,
857859
)
858860

859861
self.agent = MonteCarloDeepQLearningAgent(

0 commit comments

Comments
 (0)