1111
1212from .actor_critic import ActorCritic
1313
14- from rsl_rl .networks import MLP , CNN , CNNConfig , EmpiricalNormalization
14+ from rsl_rl .networks import MLP , CNN , EmpiricalNormalization
1515
1616
1717class PerceptiveActorCritic (ActorCritic ):
@@ -24,8 +24,8 @@ def __init__(
2424 critic_obs_normalization : bool = False ,
2525 actor_hidden_dims : list [int ] = [256 , 256 , 256 ],
2626 critic_hidden_dims : list [int ] = [256 , 256 , 256 ],
27- actor_cnn_config : dict [str , CNNConfig ] | CNNConfig | None = None ,
28- critic_cnn_config : dict [str , CNNConfig ] | CNNConfig | None = None ,
27+ actor_cnn_config : dict [str , dict ] | dict | None = None ,
28+ critic_cnn_config : dict [str , dict ] | dict | None = None ,
2929 activation : str = "elu" ,
3030 init_noise_std : float = 1.0 ,
3131 noise_std_type : str = "scalar" ,
@@ -45,10 +45,10 @@ def __init__(
4545 self .actor_obs_group_1d = []
4646 self .actor_obs_group_2d = []
4747 for obs_group in obs_groups ["policy" ]:
48- if len (obs [obs_group ].shape ) == 2 : # FIXME: should be 3???
48+ if len (obs [obs_group ].shape ) == 4 : # B, C, H, W
4949 self .actor_obs_group_2d .append (obs_group )
50- num_actor_in_channels .append (obs [obs_group ].shape [0 ])
51- elif len (obs [obs_group ].shape ) == 1 :
50+ num_actor_in_channels .append (obs [obs_group ].shape [1 ])
51+ elif len (obs [obs_group ].shape ) == 2 : # B, C
5252 self .actor_obs_group_1d .append (obs_group )
5353 num_actor_obs += obs [obs_group ].shape [- 1 ]
5454 else :
@@ -59,36 +59,36 @@ def __init__(
5959 num_critic_obs = 0
6060 num_critic_in_channels = []
6161 for obs_group in obs_groups ["critic" ]:
62- if len (obs [obs_group ].shape ) == 2 : # FIXME: should be 3???
62+ if len (obs [obs_group ].shape ) == 4 : # B, C, H, W
6363 self .critic_obs_group_2d .append (obs_group )
64- num_critic_in_channels .append (obs [obs_group ].shape [0 ])
65- else :
64+ num_critic_in_channels .append (obs [obs_group ].shape [1 ])
65+ elif len ( obs [ obs_group ]. shape ) == 2 : # B, C
6666 self .critic_obs_group_1d .append (obs_group )
6767 num_critic_obs += obs [obs_group ].shape [- 1 ]
68+ else :
69+ raise ValueError (f"Invalid observation shape for { obs_group } : { obs [obs_group ].shape } " )
6870
6971 # actor cnn
7072 if self .actor_obs_group_2d :
7173 assert actor_cnn_config is not None , "Actor CNN config is required for 2D actor observations."
7274
7375 # check if multiple 2D actor observations are provided
74- if len (self .actor_obs_group_2d ) > 1 and isinstance (actor_cnn_config , CNNConfig ):
76+ if len (self .actor_obs_group_2d ) > 1 and all (isinstance (item , dict ) for item in actor_cnn_config .values ()):
77+ assert len (actor_cnn_config ) == len (self .actor_obs_group_2d ), "Number of CNN configs must match number of 2D actor observations."
78+ elif len (self .actor_obs_group_2d ) > 1 :
7579 print (f"Only one CNN config for multiple 2D actor observations given, using the same CNN for all groups." )
7680 actor_cnn_config = dict (zip (self .actor_obs_group_2d , [actor_cnn_config ] * len (self .actor_obs_group_2d )))
77- elif len (self .actor_obs_group_2d ) > 1 and isinstance (actor_cnn_config , dict ):
78- assert len (actor_cnn_config ) == len (self .actor_obs_group_2d ), "Number of CNN configs must match number of 2D actor observations."
79- elif len (self .actor_obs_group_2d ) == 1 and isinstance (actor_cnn_config , CNNConfig ):
80- actor_cnn_config = dict (zip (self .actor_obs_group_2d , [actor_cnn_config ]))
8181 else :
82- raise ValueError ( f"Invalid combination of 2D actor observations { self .actor_obs_group_2d } and actor CNN config { actor_cnn_config } ." )
82+ actor_cnn_config = dict ( zip ( self .actor_obs_group_2d , [ actor_cnn_config ]) )
8383
84- self .actor_cnns = {}
84+ self .actor_cnns = nn . ModuleDict ()
8585 encoding_dims = []
8686 for idx , obs_group in enumerate (self .actor_obs_group_2d ):
87- self .actor_cnns [obs_group ] = CNN (actor_cnn_config [ obs_group ], num_actor_in_channels [idx ], activation )
87+ self .actor_cnns [obs_group ] = CNN (num_actor_in_channels [idx ], activation , ** actor_cnn_config [ obs_group ] )
8888 print (f"Actor CNN for { obs_group } : { self .actor_cnns [obs_group ]} " )
8989
90- # compute the encoding dimension
91- encoding_dims .append (self .actor_cnns [obs_group ](obs [obs_group ]).shape [- 1 ])
90+ # compute the encoding dimension (cpu necessary as model not moved to device yet)
91+ encoding_dims .append (self .actor_cnns [obs_group ](obs [obs_group ]. to ( "cpu" ) ).shape [- 1 ])
9292
9393 encoding_dim = sum (encoding_dims )
9494 else :
@@ -111,24 +111,22 @@ def __init__(
111111 assert critic_cnn_config is not None , "Critic CNN config is required for 2D critic observations."
112112
113113 # check if multiple 2D critic observations are provided
114- if len (self .critic_obs_group_2d ) > 1 and isinstance (critic_cnn_config , CNNConfig ):
114+ if len (self .critic_obs_group_2d ) > 1 and all (isinstance (item , dict ) for item in critic_cnn_config .values ()):
115+ assert len (critic_cnn_config ) == len (self .critic_obs_group_2d ), "Number of CNN configs must match number of 2D critic observations."
116+ elif len (self .critic_obs_group_2d ) > 1 :
115117 print (f"Only one CNN config for multiple 2D critic observations given, using the same CNN for all groups." )
116118 critic_cnn_config = dict (zip (self .critic_obs_group_2d , [critic_cnn_config ] * len (self .critic_obs_group_2d )))
117- elif len (self .critic_obs_group_2d ) > 1 and isinstance (critic_cnn_config , dict ):
118- assert len (critic_cnn_config ) == len (self .critic_obs_group_2d ), "Number of CNN configs must match number of 2D critic observations."
119- elif len (self .critic_obs_group_2d ) == 1 and isinstance (critic_cnn_config , CNNConfig ):
120- critic_cnn_config = dict (zip (self .critic_obs_group_2d , [critic_cnn_config ]))
121119 else :
122- raise ValueError ( f"Invalid combination of 2D critic observations { self .critic_obs_group_2d } and critic CNN config { critic_cnn_config } ." )
120+ critic_cnn_config = dict ( zip ( self .critic_obs_group_2d , [ critic_cnn_config ]) )
123121
124- self .critic_cnns = {}
122+ self .critic_cnns = nn . ModuleDict ()
125123 encoding_dims = []
126124 for idx , obs_group in enumerate (self .critic_obs_group_2d ):
127- self .critic_cnns [obs_group ] = CNN (critic_cnn_config [ obs_group ], num_critic_in_channels [idx ], activation )
125+ self .critic_cnns [obs_group ] = CNN (num_critic_in_channels [idx ], activation , ** critic_cnn_config [ obs_group ] )
128126 print (f"Critic CNN for { obs_group } : { self .critic_cnns [obs_group ]} " )
129127
130- # compute the encoding dimension
131- encoding_dims .append (self .critic_cnns [obs_group ](obs [obs_group ]).shape [- 1 ])
128+ # compute the encoding dimension (cpu necessary as model not moved to device yet)
129+ encoding_dims .append (self .critic_cnns [obs_group ](obs [obs_group ]. to ( "cpu" ) ).shape [- 1 ])
132130
133131 encoding_dim = sum (encoding_dims )
134132 else :
0 commit comments