import torch
import torch.nn as nn
from tensordict import TensorDict
from torch.distributions import Normal
from typing import Any

from rsl_rl.networks import CNN, MLP, EmpiricalNormalization

from .actor_critic import ActorCritic  # base class; relative import path assumed (elided in the source diff)

class ActorCriticPerceptive(ActorCritic):
    def __init__(
        self,
        obs: TensorDict,
        obs_groups: dict[str, list[str]],
        num_actions: int,
        actor_obs_normalization: bool = False,
        critic_obs_normalization: bool = False,
        actor_hidden_dims: list[int] = [256, 256, 256],
        critic_hidden_dims: list[int] = [256, 256, 256],
        actor_cnn_cfg: dict[str, dict] | dict | None = None,
        critic_cnn_cfg: dict[str, dict] | dict | None = None,
        activation: str = "elu",
        init_noise_std: float = 1.0,
        noise_std_type: str = "scalar",
        state_dependent_std: bool = False,
        **kwargs: dict[str, Any],
    ) -> None:
        if kwargs:
            print(
                ...  # warning message elided between diff hunks in the source
            )
        nn.Module.__init__(self)

        # Get the observation dimensions
        self.obs_groups = obs_groups
        num_actor_obs = 0
        num_actor_in_channels = []
        self.actor_obs_groups_1d = []
        self.actor_obs_groups_2d = []
        for obs_group in obs_groups["policy"]:
            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                self.actor_obs_groups_2d.append(obs_group)
                num_actor_in_channels.append(obs[obs_group].shape[1])
            elif len(obs[obs_group].shape) == 2:  # B, C
                self.actor_obs_groups_1d.append(obs_group)
                num_actor_obs += obs[obs_group].shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
        num_critic_obs = 0
        num_critic_in_channels = []
        self.critic_obs_groups_1d = []
        self.critic_obs_groups_2d = []
        for obs_group in obs_groups["critic"]:
            if len(obs[obs_group].shape) == 4:  # B, C, H, W
                self.critic_obs_groups_2d.append(obs_group)
                num_critic_in_channels.append(obs[obs_group].shape[1])
            elif len(obs[obs_group].shape) == 2:  # B, C
                self.critic_obs_groups_1d.append(obs_group)
                num_critic_obs += obs[obs_group].shape[-1]
            else:
                raise ValueError(f"Invalid observation shape for {obs_group}: {obs[obs_group].shape}")
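
        # Illustration (group names depend on the environment configuration): a
        # proprioceptive group of shape (B, C) lands in *_obs_groups_1d and is
        # concatenated for the MLP, while an image-like group of shape (B, C, H, W),
        # e.g. a depth image, lands in *_obs_groups_2d and gets its own CNN below.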

        # Actor CNN
        if self.actor_obs_groups_2d:
            assert actor_cnn_cfg is not None, "An actor CNN configuration is required for 2D actor observations."

            # Check if multiple 2D actor observations are provided
            if len(self.actor_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in actor_cnn_cfg.values()):
                assert len(actor_cnn_cfg) == len(self.actor_obs_groups_2d), (
                    "The number of CNN configurations must match the number of 2D actor observations."
                )
            elif len(self.actor_obs_groups_2d) > 1:
                print(
                    "Only one CNN configuration for multiple 2D actor observations given, using the same configuration "
                    "for all groups."
                )
                actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg] * len(self.actor_obs_groups_2d)))
            else:
                actor_cnn_cfg = dict(zip(self.actor_obs_groups_2d, [actor_cnn_cfg]))

            # Create CNNs for each 2D actor observation
            self.actor_cnns = nn.ModuleDict()
            encoding_dims = []
            for idx, obs_group in enumerate(self.actor_obs_groups_2d):
                self.actor_cnns[obs_group] = CNN(num_actor_in_channels[idx], activation, **actor_cnn_cfg[obs_group])
                print(f"Actor CNN for {obs_group}: {self.actor_cnns[obs_group]}")

                # Compute the encoding dimension (cpu necessary as model not moved to device yet)
                encoding_dims.append(self.actor_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
            encoding_dim = sum(encoding_dims)
        else:
            self.actor_cnns = None
            encoding_dim = 0

        # Actor MLP
        self.state_dependent_std = state_dependent_std
        if self.state_dependent_std:
            self.actor = MLP(num_actor_obs + encoding_dim, [2, num_actions], actor_hidden_dims, activation)
        else:
            self.actor = MLP(num_actor_obs + encoding_dim, num_actions, actor_hidden_dims, activation)
        print(f"Actor MLP: {self.actor}")

        # Actor observation normalization (only for 1D actor observations)
        self.actor_obs_normalization = actor_obs_normalization
        if actor_obs_normalization:
            self.actor_obs_normalizer = EmpiricalNormalization(num_actor_obs)
        else:
            self.actor_obs_normalizer = torch.nn.Identity()

        # Critic CNN
        if self.critic_obs_groups_2d:
            assert critic_cnn_cfg is not None, "A critic CNN configuration is required for 2D critic observations."

            # Check if multiple 2D critic observations are provided
            if len(self.critic_obs_groups_2d) > 1 and all(isinstance(item, dict) for item in critic_cnn_cfg.values()):
                assert len(critic_cnn_cfg) == len(self.critic_obs_groups_2d), (
                    "The number of CNN configurations must match the number of 2D critic observations."
                )
            elif len(self.critic_obs_groups_2d) > 1:
                print(
                    "Only one CNN configuration for multiple 2D critic observations given, using the same configuration"
                    " for all groups."
                )
                critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg] * len(self.critic_obs_groups_2d)))
            else:
                critic_cnn_cfg = dict(zip(self.critic_obs_groups_2d, [critic_cnn_cfg]))

            # Create CNNs for each 2D critic observation
            self.critic_cnns = nn.ModuleDict()
            encoding_dims = []
            for idx, obs_group in enumerate(self.critic_obs_groups_2d):
                self.critic_cnns[obs_group] = CNN(num_critic_in_channels[idx], activation, **critic_cnn_cfg[obs_group])
                print(f"Critic CNN for {obs_group}: {self.critic_cnns[obs_group]}")

                # Compute the encoding dimension (cpu necessary as model not moved to device yet)
                encoding_dims.append(self.critic_cnns[obs_group](obs[obs_group].to("cpu")).shape[-1])
            encoding_dim = sum(encoding_dims)
        else:
            self.critic_cnns = None
            encoding_dim = 0

        # Critic MLP
        self.critic = MLP(num_critic_obs + encoding_dim, 1, critic_hidden_dims, activation)
        print(f"Critic MLP: {self.critic}")

        # Critic observation normalization (only for 1D critic observations)
        self.critic_obs_normalization = critic_obs_normalization
        if critic_obs_normalization:
            self.critic_obs_normalizer = EmpiricalNormalization(num_critic_obs)
        else:
            self.critic_obs_normalizer = torch.nn.Identity()

        # Action noise
        self.noise_std_type = noise_std_type
        if self.state_dependent_std:
            torch.nn.init.zeros_(self.actor[-2].weight[num_actions:])
            if self.noise_std_type == "scalar":
                torch.nn.init.constant_(self.actor[-2].bias[num_actions:], init_noise_std)
            elif self.noise_std_type == "log":
                torch.nn.init.constant_(
                    self.actor[-2].bias[num_actions:], torch.log(torch.tensor(init_noise_std + 1e-7))
                )
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
        else:
            if self.noise_std_type == "scalar":
                self.std = nn.Parameter(init_noise_std * torch.ones(num_actions))
            elif self.noise_std_type == "log":
                self.log_std = nn.Parameter(torch.log(init_noise_std * torch.ones(num_actions)))
            else:
                raise ValueError(f"Unknown standard deviation type: {self.noise_std_type}. Should be 'scalar' or 'log'")
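
        # Note on the state-dependent case (an inference from the code above, not from
        # the source text): with an output shape of [2, num_actions], the final linear
        # layer is assumed to be self.actor[-2], producing stacked mean and std rows;
        # rows [num_actions:] belong to the std head, which is why only that slice of
        # its weight and bias is re-initialized above.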

        # Action distribution
        # Note: Populated in _update_distribution
        self.distribution = None

        # Disable args validation for speedup
        Normal.set_default_validate_args(False)

    def _update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Tensor]) -> None:
        if self.actor_cnns is not None:
            # Encode the 2D actor observations
            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # Concatenate to the MLP observations
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        super()._update_distribution(mlp_obs)

    def act(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)
        self._update_distribution(mlp_obs, cnn_obs)
        return self.distribution.sample()

    def act_inference(self, obs: TensorDict) -> torch.Tensor:
        mlp_obs, cnn_obs = self.get_actor_obs(obs)
        mlp_obs = self.actor_obs_normalizer(mlp_obs)

        if self.actor_cnns is not None:
            # Encode the 2D actor observations
            cnn_enc_list = [self.actor_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.actor_obs_groups_2d]
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # Concatenate to the MLP observations
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        if self.state_dependent_std:
            # Return the mean row of the stacked [mean, std] output
            return self.actor(mlp_obs)[..., 0, :]
        else:
            return self.actor(mlp_obs)

    def evaluate(self, obs: TensorDict, **kwargs: dict[str, Any]) -> torch.Tensor:
        mlp_obs, cnn_obs = self.get_critic_obs(obs)
        mlp_obs = self.critic_obs_normalizer(mlp_obs)

        if self.critic_cnns is not None:
            # Encode the 2D critic observations
            cnn_enc_list = [self.critic_cnns[obs_group](cnn_obs[obs_group]) for obs_group in self.critic_obs_groups_2d]
            cnn_enc = torch.cat(cnn_enc_list, dim=-1)
            # Concatenate to the MLP observations
            mlp_obs = torch.cat([mlp_obs, cnn_enc], dim=-1)

        return self.critic(mlp_obs)

    def get_actor_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        obs_list_1d = [obs[obs_group] for obs_group in self.actor_obs_groups_1d]
        obs_dict_2d = {}
        for obs_group in self.actor_obs_groups_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def get_critic_obs(self, obs: TensorDict) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        obs_list_1d = [obs[obs_group] for obs_group in self.critic_obs_groups_1d]
        obs_dict_2d = {}
        for obs_group in self.critic_obs_groups_2d:
            obs_dict_2d[obs_group] = obs[obs_group]
        return torch.cat(obs_list_1d, dim=-1), obs_dict_2d

    def update_normalization(self, obs: TensorDict) -> None:
        if self.actor_obs_normalization:
            actor_obs, _ = self.get_actor_obs(obs)
            self.actor_obs_normalizer.update(actor_obs)
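
# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module above; group names, shapes, and
# sizes below are illustrative only). A (B, C) observation group is routed to
# the actor/critic MLP, while a (B, C, H, W) group such as a depth image gets
# its own CNN configured via `actor_cnn_cfg` / `critic_cnn_cfg`. The exact CNN
# keyword arguments depend on rsl_rl's CNN class and are left as placeholders.
#
#   obs = TensorDict(
#       {
#           "proprio": torch.zeros(4, 48),       # 1D group -> MLP input
#           "depth": torch.zeros(4, 1, 32, 32),  # 2D group -> CNN input
#       },
#       batch_size=[4],
#   )
#   obs_groups = {"policy": ["proprio", "depth"], "critic": ["proprio"]}
#   policy = ActorCriticPerceptive(
#       obs,
#       obs_groups,
#       num_actions=12,
#       actor_cnn_cfg={...},  # CNN constructor kwargs (implementation-specific)
#   )
#   actions = policy.act(obs)  # expected shape: (4, 12)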