99import torch .nn as nn
1010from torch .distributions import Normal
1111
12- from . actor_critic import ActorCritic
12+ from rsl_rl . networks import CNN , MLP , EmpiricalNormalization
1313
14- from rsl_rl . networks import MLP , CNN , EmpiricalNormalization
14+ from . actor_critic import ActorCritic
1515
1616
1717class PerceptiveActorCritic (ActorCritic ):
18- def __init__ (
18+ def __init__ ( # noqa: C901
1919 self ,
2020 obs ,
2121 obs_groups ,
@@ -53,7 +53,7 @@ def __init__(
5353 num_actor_obs += obs [obs_group ].shape [- 1 ]
5454 else :
5555 raise ValueError (f"Invalid observation shape for { obs_group } : { obs [obs_group ].shape } " )
56-
56+
5757 self .critic_obs_group_1d = []
5858 self .critic_obs_group_2d = []
5959 num_critic_obs = 0
@@ -71,12 +71,16 @@ def __init__(
7171 # actor cnn
7272 if self .actor_obs_group_2d :
7373 assert actor_cnn_config is not None , "Actor CNN config is required for 2D actor observations."
74-
74+
7575 # check if multiple 2D actor observations are provided
7676 if len (self .actor_obs_group_2d ) > 1 and all (isinstance (item , dict ) for item in actor_cnn_config .values ()):
77- assert len (actor_cnn_config ) == len (self .actor_obs_group_2d ), "Number of CNN configs must match number of 2D actor observations."
77+ assert len (actor_cnn_config ) == len (
78+ self .actor_obs_group_2d
79+ ), "Number of CNN configs must match number of 2D actor observations."
7880 elif len (self .actor_obs_group_2d ) > 1 :
79- print (f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups." )
81+ print (
82+ "Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups."
83+ )
8084 actor_cnn_config = dict (zip (self .actor_obs_group_2d , [actor_cnn_config ] * len (self .actor_obs_group_2d )))
8185 else :
8286 actor_cnn_config = dict (zip (self .actor_obs_group_2d , [actor_cnn_config ]))
@@ -89,15 +93,15 @@ def __init__(
8993
9094 # compute the encoding dimension (cpu necessary as model not moved to device yet)
9195 encoding_dims .append (self .actor_cnns [obs_group ](obs [obs_group ].to ("cpu" )).shape [- 1 ])
92-
96+
9397 encoding_dim = sum (encoding_dims )
9498 else :
9599 self .actor_cnns = None
96100 encoding_dim = 0
97101
98102 # actor mlp
99103 self .actor = MLP (num_actor_obs + encoding_dim , num_actions , actor_hidden_dims , activation )
100-
104+
101105 # actor observation normalization (only for 1D actor observations)
102106 self .actor_obs_normalization = actor_obs_normalization
103107 if actor_obs_normalization :
@@ -109,33 +113,41 @@ def __init__(
109113 # critic cnn
110114 if self .critic_obs_group_2d :
111115 assert critic_cnn_config is not None , "Critic CNN config is required for 2D critic observations."
112-
116+
113117 # check if multiple 2D critic observations are provided
114118 if len (self .critic_obs_group_2d ) > 1 and all (isinstance (item , dict ) for item in critic_cnn_config .values ()):
115- assert len (critic_cnn_config ) == len (self .critic_obs_group_2d ), "Number of CNN configs must match number of 2D critic observations."
119+ assert len (critic_cnn_config ) == len (
120+ self .critic_obs_group_2d
121+ ), "Number of CNN configs must match number of 2D critic observations."
116122 elif len (self .critic_obs_group_2d ) > 1 :
117- print (f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups." )
118- critic_cnn_config = dict (zip (self .critic_obs_group_2d , [critic_cnn_config ] * len (self .critic_obs_group_2d )))
123+ print (
124+ "Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups."
125+ )
126+ critic_cnn_config = dict (
127+ zip (self .critic_obs_group_2d , [critic_cnn_config ] * len (self .critic_obs_group_2d ))
128+ )
119129 else :
120130 critic_cnn_config = dict (zip (self .critic_obs_group_2d , [critic_cnn_config ]))
121131
122132 self .critic_cnns = nn .ModuleDict ()
123133 encoding_dims = []
124134 for idx , obs_group in enumerate (self .critic_obs_group_2d ):
125- self .critic_cnns [obs_group ] = CNN (num_critic_in_channels [idx ], activation , ** critic_cnn_config [obs_group ])
135+ self .critic_cnns [obs_group ] = CNN (
136+ num_critic_in_channels [idx ], activation , ** critic_cnn_config [obs_group ]
137+ )
126138 print (f"Critic CNN for { obs_group } : { self .critic_cnns [obs_group ]} " )
127139
128140 # compute the encoding dimension (cpu necessary as model not moved to device yet)
129141 encoding_dims .append (self .critic_cnns [obs_group ](obs [obs_group ].to ("cpu" )).shape [- 1 ])
130-
142+
131143 encoding_dim = sum (encoding_dims )
132144 else :
133145 self .critic_cnns = None
134146 encoding_dim = 0
135147
136148 # critic mlp
137149 self .critic = MLP (num_critic_obs + encoding_dim , 1 , critic_hidden_dims , activation )
138-
150+
139151 # critic observation normalization (only for 1D critic observations)
140152 self .critic_obs_normalization = critic_obs_normalization
141153 if critic_obs_normalization :
@@ -159,7 +171,7 @@ def __init__(
159171 Normal .set_default_validate_args (False )
160172
161173 def update_distribution (self , mlp_obs : torch .Tensor , cnn_obs : dict [str , torch .Tensor ]):
162-
174+
163175 if self .actor_cnns is not None :
164176 # encode the 2D actor observations
165177 cnn_enc_list = []
@@ -168,7 +180,7 @@ def update_distribution(self, mlp_obs: torch.Tensor, cnn_obs: dict[str, torch.Te
168180 cnn_enc = torch .cat (cnn_enc_list , dim = - 1 )
169181 # update mlp obs
170182 mlp_obs = torch .cat ([mlp_obs , cnn_enc ], dim = - 1 )
171-
183+
172184 super ().update_distribution (mlp_obs )
173185
174186 def act (self , obs , ** kwargs ):
@@ -180,7 +192,7 @@ def act(self, obs, **kwargs):
180192 def act_inference (self , obs ):
181193 mlp_obs , cnn_obs = self .get_actor_obs (obs )
182194 mlp_obs = self .actor_obs_normalizer (mlp_obs )
183-
195+
184196 if self .actor_cnns is not None :
185197 # encode the 2D actor observations
186198 cnn_enc_list = []
@@ -189,7 +201,7 @@ def act_inference(self, obs):
189201 cnn_enc = torch .cat (cnn_enc_list , dim = - 1 )
190202 # update mlp obs
191203 mlp_obs = torch .cat ([mlp_obs , cnn_enc ], dim = - 1 )
192-
204+
193205 return self .actor (mlp_obs )
194206
195207 def evaluate (self , obs , ** kwargs ):
@@ -204,7 +216,7 @@ def evaluate(self, obs, **kwargs):
204216 cnn_enc = torch .cat (cnn_enc_list , dim = - 1 )
205217 # update mlp obs
206218 mlp_obs = torch .cat ([mlp_obs , cnn_enc ], dim = - 1 )
207-
219+
208220 return self .critic (mlp_obs )
209221
210222 def get_actor_obs (self , obs ):
@@ -231,4 +243,4 @@ def update_normalization(self, obs):
231243 self .actor_obs_normalizer .update (actor_obs )
232244 if self .critic_obs_normalization :
233245 critic_obs , _ = self .get_critic_obs (obs )
234- self .critic_obs_normalizer .update (critic_obs )
246+ self .critic_obs_normalizer .update (critic_obs )
0 commit comments