 
 from maro.rl.model import DiscretePolicyNet, FullyConnected, VNet
 from maro.rl.policy import DiscretePolicyGradient
-from maro.rl.training.algorithms import (
-    DiscreteActorCriticTrainer, DiscreteActorCriticParams, DiscretePPOParams, DiscretePPOTrainer,
-)
+from maro.rl.training.algorithms import DiscreteActorCriticTrainer, DiscreteActorCriticParams
 
 actor_net_conf = {
     "hidden_dims": [256, 128, 64],
     "activation": torch.nn.Tanh,
     "softmax": True,
     "batch_norm": False,
-    "head": True
+    "head": True,
 }
 critic_net_conf = {
     "hidden_dims": [256, 128, 64],
     "output_dim": 1,
     "activation": torch.nn.LeakyReLU,
     "softmax": False,
     "batch_norm": True,
-    "head": True
+    "head": True,
 }
 actor_learning_rate = 0.001
 critic_learning_rate = 0.001
@@ -64,7 +62,7 @@ def apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None:
     def get_state(self) -> dict:
         return {
             "network": self.state_dict(),
-            "optim": self._optim.state_dict()
+            "optim": self._optim.state_dict(),
         }
 
     def set_state(self, net_state: dict) -> None:
@@ -99,7 +97,7 @@ def apply_gradients(self, grad: Dict[str, torch.Tensor]) -> None:
     def get_state(self) -> dict:
         return {
             "network": self.state_dict(),
-            "optim": self._optim.state_dict()
+            "optim": self._optim.state_dict(),
         }
 
     def set_state(self, net_state: dict) -> None:
@@ -121,7 +119,6 @@ def get_ac(state_dim: int, name: str) -> DiscreteActorCriticTrainer:
     return DiscreteActorCriticTrainer(
         name=name,
         params=DiscreteActorCriticParams(
-            device="cpu",
             get_v_critic_net_func=lambda: MyCriticNet(state_dim),
             reward_discount=.0,
             grad_iters=10,
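
Note: the get_state/set_state pair touched in this diff bundles the network weights and the optimizer state into a single dict, so checkpointing is one save/load call. A minimal sketch, assuming only the method signatures shown in the hunks above (the state_dim value and the file name are illustrative, not from the source):

# Persist the combined {"network": ..., "optim": ...} dict returned by get_state(),
# then restore it into a fresh instance with set_state().
critic = MyCriticNet(state_dim=10)
torch.save(critic.get_state(), "critic_checkpoint.pt")

restored = MyCriticNet(state_dim=10)
restored.set_state(torch.load("critic_checkpoint.pt"))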