Commit 530f71a

Fix: make act_inference return policy mean (without std dev) at deployment time (#118)
* Make act_inference return only the policy mean at deployment time when the policy was trained with a state-dependent standard deviation. Previously it returned the full actor output.
1 parent 1c63d8e commit 530f71a

2 files changed: +8 -2 lines changed

rsl_rl/modules/actor_critic.py

Lines changed: 4 additions & 1 deletion
@@ -148,7 +148,10 @@ def act(self, obs, **kwargs):
     def act_inference(self, obs):
         obs = self.get_actor_obs(obs)
         obs = self.actor_obs_normalizer(obs)
-        return self.actor(obs)
+        if self.state_dependent_std:
+            return self.actor(obs)[..., 0, :]
+        else:
+            return self.actor(obs)
 
     def evaluate(self, obs, **kwargs):
         obs = self.get_critic_obs(obs)

rsl_rl/modules/actor_critic_recurrent.py

Lines changed: 4 additions & 1 deletion
@@ -167,7 +167,10 @@ def act_inference(self, obs):
         obs = self.get_actor_obs(obs)
         obs = self.actor_obs_normalizer(obs)
         out_mem = self.memory_a(obs).squeeze(0)
-        return self.actor(out_mem)
+        if self.state_dependent_std:
+            return self.actor(out_mem)[..., 0, :]
+        else:
+            return self.actor(out_mem)
 
     def evaluate(self, obs, masks=None, hidden_states=None):
         obs = self.get_critic_obs(obs)
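
Note on the `[..., 0, :]` indexing used in both files: the sketch below illustrates what the new branch selects. It is not code from the repository. It assumes, based only on the index expression and the commit message, that with state-dependent std the actor output has shape (..., 2, num_actions), with the mean at index 0 of the second-to-last axis. NumPy arrays stand in for torch tensors here (the ellipsis indexing behaves the same way), and `select_mean` is a hypothetical helper name.

```python
import numpy as np


def select_mean(actor_out, state_dependent_std):
    """Hypothetical helper mirroring the branch added in this commit."""
    if state_dependent_std:
        # Assumed layout: (..., 2, num_actions); index 0 holds the mean.
        return actor_out[..., 0, :]
    # Without state-dependent std the actor output is taken to be the mean.
    return actor_out


batch, num_actions = 4, 3
mean = np.arange(batch * num_actions, dtype=np.float32).reshape(batch, num_actions)
std = np.full((batch, num_actions), 0.5, dtype=np.float32)

# Stack mean and std on a new second-to-last axis: shape (4, 2, 3).
stacked = np.stack([mean, std], axis=-2)

deployed = select_mean(stacked, state_dependent_std=True)   # shape (4, 3)
plain = select_mean(mean, state_dependent_std=False)        # shape (4, 3)
```

Under that assumed layout, both calls return an array with one row of actions per batch element, so downstream deployment code sees the same shape regardless of how the policy was trained.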
