1+ # Copyright (c) 2021-2025, ETH Zurich and NVIDIA CORPORATION
2+ # All rights reserved.
3+ #
4+ # SPDX-License-Identifier: BSD-3-Clause
5+
6+ from __future__ import annotations
7+
8+ import torch
9+ import torch .nn as nn
10+ from torch .distributions import Normal
11+
12+ from .actor_critic import ActorCritic
13+
14+ from rsl_rl .networks import MLP , CNN , CNNConfig , EmpiricalNormalization
15+
16+
17+ class PerceptiveActorCritic (ActorCritic ):
18+ def __init__ (
19+ self ,
20+ obs ,
21+ obs_groups ,
22+ num_actions ,
23+ actor_obs_normalization : bool = False ,
24+ critic_obs_normalization : bool = False ,
25+ actor_hidden_dims : list [int ] = [256 , 256 , 256 ],
26+ critic_hidden_dims : list [int ] = [256 , 256 , 256 ],
27+ actor_cnn_config : dict [str , CNNConfig ] | CNNConfig | None = None ,
28+ critic_cnn_config : dict [str , CNNConfig ] | CNNConfig | None = None ,
29+ activation : str = "elu" ,
30+ init_noise_std : float = 1.0 ,
31+ noise_std_type : str = "scalar" ,
32+ ** kwargs ,
33+ ):
34+ if kwargs :
35+ print (
36+ "PerceptiveActorCritic.__init__ got unexpected arguments, which will be ignored: "
37+ + str ([key for key in kwargs .keys ()])
38+ )
39+ nn .Module .__init__ (self )
40+
41+ # get the observation dimensions
42+ self .obs_groups = obs_groups
43+ num_actor_obs = 0
44+ num_actor_in_channels = []
45+ self .actor_obs_group_1d = []
46+ self .actor_obs_group_2d = []
47+ for obs_group in obs_groups ["policy" ]:
48+ if len (obs [obs_group ].shape ) == 2 : # FIXME: should be 3???
49+ self .actor_obs_group_2d .append (obs_group )
50+ num_actor_in_channels .append (obs [obs_group ].shape [0 ])
51+ elif len (obs [obs_group ].shape ) == 1 :
52+ self .actor_obs_group_1d .append (obs_group )
53+ num_actor_obs += obs [obs_group ].shape [- 1 ]
54+ else :
55+ raise ValueError (f"Invalid observation shape for { obs_group } : { obs [obs_group ].shape } " )
56+
57+ self .critic_obs_group_1d = []
58+ self .critic_obs_group_2d = []
59+ num_critic_obs = 0
60+ num_critic_in_channels = []
61+ for obs_group in obs_groups ["critic" ]:
62+ if len (obs [obs_group ].shape ) == 2 : # FIXME: should be 3???
63+ self .critic_obs_group_2d .append (obs_group )
64+ num_critic_in_channels .append (obs [obs_group ].shape [0 ])
65+ else :
66+ self .critic_obs_group_1d .append (obs_group )
67+ num_critic_obs += obs [obs_group ].shape [- 1 ]
68+
69+ # actor cnn
70+ if self .actor_obs_group_2d :
71+ assert actor_cnn_config is not None , "Actor CNN config is required for 2D actor observations."
72+
73+ # check if multiple 2D actor observations are provided
74+ if len (self .actor_obs_group_2d ) > 1 and isinstance (actor_cnn_config , CNNConfig ):
75+ print (f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups." )
76+ actor_cnn_config = dict (zip (self .actor_obs_group_2d , [actor_cnn_config ] * len (self .actor_obs_group_2d )))
77+ elif len (self .actor_obs_group_2d ) > 1 and isinstance (actor_cnn_config , dict ):
78+ assert len (actor_cnn_config ) == len (self .actor_obs_group_2d ), "Number of CNN configs must match number of 2D actor observations."
79+ elif len (self .actor_obs_group_2d ) == 1 and isinstance (actor_cnn_config , CNNConfig ):
80+ actor_cnn_config = dict (zip (self .actor_obs_group_2d , [actor_cnn_config ]))
81+ else :
82+ raise ValueError (f"Invalid combination of 2D actor observations { self .actor_obs_group_2d } and actor CNN config { actor_cnn_config } ." )
83+
84+ self .actor_cnns = {}
85+ encoding_dims = []
86+ for idx , obs_group in enumerate (self .actor_obs_group_2d ):
87+ self .actor_cnns [obs_group ] = CNN (actor_cnn_config [obs_group ], num_actor_in_channels [idx ], activation )
88+ print (f"Actor CNN for { obs_group } : { self .actor_cnns [obs_group ]} " )
89+
90+ # compute the encoding dimension
91+ encoding_dims .append (self .actor_cnns [obs_group ](obs [obs_group ]).shape [- 1 ])
92+
93+ encoding_dim = sum (encoding_dims )
94+ else :
95+ self .actor_cnns = None
96+ encoding_dim = 0
97+
98+ # actor mlp
99+ self .actor = MLP (num_actor_obs + encoding_dim , num_actions , actor_hidden_dims , activation )
100+
101+ # actor observation normalization (only for 1D actor observations)
102+ self .actor_obs_normalization = actor_obs_normalization
103+ if actor_obs_normalization :
104+ self .actor_obs_normalizer = EmpiricalNormalization (num_actor_obs )
105+ else :
106+ self .actor_obs_normalizer = torch .nn .Identity ()
107+ print (f"Actor MLP: { self .actor } " )
108+
109+ # critic cnn
110+ if self .critic_obs_group_2d :
111+ assert critic_cnn_config is not None , "Critic CNN config is required for 2D critic observations."
112+
113+ # check if multiple 2D critic observations are provided
114+ if len (self .critic_obs_group_2d ) > 1 and isinstance (critic_cnn_config , CNNConfig ):
115+ print (f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups." )
116+ critic_cnn_config = dict (zip (self .critic_obs_group_2d , [critic_cnn_config ] * len (self .critic_obs_group_2d )))
117+ elif len (self .critic_obs_group_2d ) > 1 and isinstance (critic_cnn_config , dict ):
118+ assert len (critic_cnn_config ) == len (self .critic_obs_group_2d ), "Number of CNN configs must match number of 2D critic observations."
119+ elif len (self .critic_obs_group_2d ) == 1 and isinstance (critic_cnn_config , CNNConfig ):
120+ critic_cnn_config = dict (zip (self .critic_obs_group_2d , [critic_cnn_config ]))
121+ else :
122+ raise ValueError (f"Invalid combination of 2D critic observations { self .critic_obs_group_2d } and critic CNN config { critic_cnn_config } ." )
123+
124+ self .critic_cnns = {}
125+ encoding_dims = []
126+ for idx , obs_group in enumerate (self .critic_obs_group_2d ):
127+ self .critic_cnns [obs_group ] = CNN (critic_cnn_config [obs_group ], num_critic_in_channels [idx ], activation )
128+ print (f"Critic CNN for { obs_group } : { self .critic_cnns [obs_group ]} " )
129+
130+ # compute the encoding dimension
131+ encoding_dims .append (self .critic_cnns [obs_group ](obs [obs_group ]).shape [- 1 ])
132+
133+ encoding_dim = sum (encoding_dims )
134+ else :
135+ self .critic_cnns = None
136+ encoding_dim = 0
137+
138+ # critic mlp
139+ self .critic = MLP (num_critic_obs + encoding_dim , 1 , critic_hidden_dims , activation )
140+
141+ # critic observation normalization (only for 1D critic observations)
142+ self .critic_obs_normalization = critic_obs_normalization
143+ if critic_obs_normalization :
144+ self .critic_obs_normalizer = EmpiricalNormalization (num_critic_obs )
145+ else :
146+ self .critic_obs_normalizer = torch .nn .Identity ()
147+ print (f"Critic MLP: { self .critic } " )
148+
149+ # Action noise
150+ self .noise_std_type = noise_std_type
151+ if self .noise_std_type == "scalar" :
152+ self .std = nn .Parameter (init_noise_std * torch .ones (num_actions ))
153+ elif self .noise_std_type == "log" :
154+ self .log_std = nn .Parameter (torch .log (init_noise_std * torch .ones (num_actions )))
155+ else :
156+ raise ValueError (f"Unknown standard deviation type: { self .noise_std_type } . Should be 'scalar' or 'log'" )
157+
158+ # Action distribution (populated in update_distribution)
159+ self .distribution : Normal = None
160+ # disable args validation for speedup
161+ Normal .set_default_validate_args (False )
162+
163+ def update_distribution (self , mlp_obs : torch .Tensor , cnn_obs : dict [str , torch .Tensor ]):
164+
165+ if self .actor_cnns is not None :
166+ # encode the 2D actor observations
167+ cnn_enc_list = []
168+ for obs_group in self .actor_obs_group_2d :
169+ cnn_enc_list .append (self .actor_cnns [obs_group ](cnn_obs [obs_group ]))
170+ cnn_enc = torch .cat (cnn_enc_list , dim = - 1 )
171+ # update mlp obs
172+ mlp_obs = torch .cat ([mlp_obs , cnn_enc ], dim = - 1 )
173+
174+ super ().update_distribution (mlp_obs )
175+
176+ def act (self , obs , ** kwargs ):
177+ mlp_obs , cnn_obs = self .get_actor_obs (obs )
178+ mlp_obs = self .actor_obs_normalizer (mlp_obs )
179+ self .update_distribution (mlp_obs , cnn_obs )
180+ return self .distribution .sample ()
181+
182+ def act_inference (self , obs ):
183+ mlp_obs , cnn_obs = self .get_actor_obs (obs )
184+ mlp_obs = self .actor_obs_normalizer (mlp_obs )
185+
186+ if self .actor_cnns is not None :
187+ # encode the 2D actor observations
188+ cnn_enc_list = []
189+ for obs_group in self .actor_obs_group_2d :
190+ cnn_enc_list .append (self .actor_cnns [obs_group ](cnn_obs [obs_group ]))
191+ cnn_enc = torch .cat (cnn_enc_list , dim = - 1 )
192+ # update mlp obs
193+ mlp_obs = torch .cat ([mlp_obs , cnn_enc ], dim = - 1 )
194+
195+ return self .actor (mlp_obs )
196+
197+ def evaluate (self , obs , ** kwargs ):
198+ mlp_obs , cnn_obs = self .get_critic_obs (obs )
199+ mlp_obs = self .critic_obs_normalizer (mlp_obs )
200+
201+ if self .critic_cnns is not None :
202+ # encode the 2D critic observations
203+ cnn_enc_list = []
204+ for obs_group in self .critic_obs_group_2d :
205+ cnn_enc_list .append (self .critic_cnns [obs_group ](cnn_obs [obs_group ]))
206+ cnn_enc = torch .cat (cnn_enc_list , dim = - 1 )
207+ # update mlp obs
208+ mlp_obs = torch .cat ([mlp_obs , cnn_enc ], dim = - 1 )
209+
210+ return self .critic (mlp_obs )
211+
212+ def get_actor_obs (self , obs ):
213+ obs_list_1d = []
214+ obs_dict_2d = {}
215+ for obs_group in self .actor_obs_group_1d :
216+ obs_list_1d .append (obs [obs_group ])
217+ for obs_group in self .actor_obs_group_2d :
218+ obs_dict_2d [obs_group ] = obs [obs_group ]
219+ return torch .cat (obs_list_1d , dim = - 1 ), obs_dict_2d
220+
221+ def get_critic_obs (self , obs ):
222+ obs_list_1d = []
223+ obs_dict_2d = {}
224+ for obs_group in self .critic_obs_group_1d :
225+ obs_list_1d .append (obs [obs_group ])
226+ for obs_group in self .critic_obs_group_2d :
227+ obs_dict_2d [obs_group ] = obs [obs_group ]
228+ return torch .cat (obs_list_1d , dim = - 1 ), obs_dict_2d
229+
230+ def update_normalization (self , obs ):
231+ if self .actor_obs_normalization :
232+ actor_obs , _ = self .get_actor_obs (obs )
233+ self .actor_obs_normalizer .update (actor_obs )
234+ if self .critic_obs_normalization :
235+ critic_obs , _ = self .get_critic_obs (obs )
236+ self .critic_obs_normalizer .update (critic_obs )
0 commit comments