File tree Expand file tree Collapse file tree 2 files changed +5
-1
lines changed
Expand file tree Collapse file tree 2 files changed +5
-1
lines changed Original file line number Diff line number Diff line change @@ -470,6 +470,7 @@ def __init__(
470470 self ,
471471 model_path : str = None ,
472472 model_config = None ,
473+ device : str = None ,
473474 mcts_kwargs : Optional [dict ] = None ,
474475 ** kwargs ,
475476 ):
@@ -478,7 +479,8 @@ def __init__(
478479 if 'evaluator' not in kwargs :
479480 kwargs ['evaluator' ] = DeepQLearningEvaluator (
480481 model_path = model_path ,
481- model_config = model_config
482+ model_config = model_config ,
483+ device = device ,
482484 )
483485 super ().__init__ (mcts_kwargs = mcts_kwargs , ** kwargs )
484486
Original file line number Diff line number Diff line change @@ -815,6 +815,7 @@ def __init__(
815815 adaptive_sampling = True ,
816816 freeze_backprop = None ,
817817 display_plot = True ,
818+ device = None ,
818819 ):
819820 """
820821 Represents a class for initializing and training a deep reinforcement learning agent
@@ -854,6 +855,7 @@ def __init__(
854855 self .agent_kwargs = dict (
855856 mcts_kwargs = self .mcts_kwargs ,
856857 max_time_per_move = inf ,
858+ device = device ,
857859 )
858860
859861 self .agent = MonteCarloDeepQLearningAgent (
You can’t perform that action at this time.
0 commit comments