@@ -28,6 +28,7 @@ def __init__(
2828 activation = "elu" ,
2929 init_noise_std = 1.0 ,
3030 noise_std_type : str = "scalar" ,
31+ state_dependent_std = False ,
3132 rnn_type = "lstm" ,
3233 rnn_hidden_dim = 256 ,
3334 rnn_num_layers = 1 ,
@@ -58,9 +59,14 @@ def __init__(
5859 assert len (obs [obs_group ].shape ) == 2 , "The ActorCriticRecurrent module only supports 1D observations."
5960 num_critic_obs += obs [obs_group ].shape [- 1 ]
6061
62+ self .state_dependent_std = state_dependent_std
6163 # actor
6264 self .memory_a = Memory (num_actor_obs , type = rnn_type , num_layers = rnn_num_layers , hidden_size = rnn_hidden_dim )
63- self .actor = MLP (rnn_hidden_dim , num_actions , actor_hidden_dims , activation )
65+ if self .state_dependent_std :
66+ self .actor = MLP (rnn_hidden_dim , [2 , num_actions ], actor_hidden_dims , activation )
67+ else :
68+ self .actor = MLP (rnn_hidden_dim , num_actions , actor_hidden_dims , activation )
69+
6470 # actor observation normalization
6571 self .actor_obs_normalization = actor_obs_normalization
6672 if actor_obs_normalization :
@@ -84,12 +90,21 @@ def __init__(
8490
8591 # Action noise
8692 self .noise_std_type = noise_std_type
87- if self .noise_std_type == "scalar" :
88- self .std = nn .Parameter (init_noise_std * torch .ones (num_actions ))
89- elif self .noise_std_type == "log" :
90- self .log_std = nn .Parameter (torch .log (init_noise_std * torch .ones (num_actions )))
93+ if self .state_dependent_std :
94+ torch .nn .init .zeros_ (self .actor [- 2 ].weight [num_actions :])
95+ if self .noise_std_type == "scalar" :
96+ torch .nn .init .constant_ (self .actor [- 2 ].bias [num_actions :], init_noise_std )
97+ elif self .noise_std_type == "log" :
98+ torch .nn .init .constant_ (self .actor [- 2 ].bias [num_actions :], torch .log (torch .tensor (init_noise_std + 1e-7 )))
99+ else :
100+ raise ValueError (f"Unknown standard deviation type: { self .noise_std_type } . Should be 'scalar' or 'log'" )
91101 else :
92- raise ValueError (f"Unknown standard deviation type: { self .noise_std_type } . Should be 'scalar' or 'log'" )
102+ if self .noise_std_type == "scalar" :
103+ self .std = nn .Parameter (init_noise_std * torch .ones (num_actions ))
104+ elif self .noise_std_type == "log" :
105+ self .log_std = nn .Parameter (torch .log (init_noise_std * torch .ones (num_actions )))
106+ else :
107+ raise ValueError (f"Unknown standard deviation type: { self .noise_std_type } . Should be 'scalar' or 'log'" )
93108
94109 # Action distribution (populated in update_distribution)
95110 self .distribution = None
@@ -116,15 +131,26 @@ def forward(self):
116131 raise NotImplementedError
117132
def update_distribution(self, obs):
    """Refresh ``self.distribution`` with a Normal built from the actor output.

    The actor is evaluated on ``obs`` first. How mean and standard deviation
    are obtained then depends on two attributes of the instance:

    * ``state_dependent_std`` -- when truthy, the actor output is split in two
      along its second-to-last dimension; the first slice is the mean and the
      second is the (log-)standard deviation. When falsy, the actor output is
      the mean and the standard deviation comes from the ``std`` / ``log_std``
      attribute, expanded to the shape of the mean.
    * ``noise_std_type`` -- ``"scalar"`` uses the value as-is, ``"log"``
      exponentiates it.

    Raises:
        ValueError: if ``noise_std_type`` is neither ``"scalar"`` nor ``"log"``.
    """
    predicts_std = self.state_dependent_std
    actor_out = self.actor(obs)

    # Validate the noise parameterization once, after the forward pass.
    in_log_space = self.noise_std_type == "log"
    if not in_log_space and self.noise_std_type != "scalar":
        raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")

    if predicts_std:
        # The two-way unpacking requires the second-to-last dimension to have size 2.
        mean, spread = torch.unbind(actor_out, dim=-2)
        std = torch.exp(spread) if in_log_space else spread
    else:
        mean = actor_out
        learned = torch.exp(self.log_std) if in_log_space else self.std
        std = learned.expand_as(mean)

    # create distribution
    self.distribution = Normal(mean, std)
130156
0 commit comments