@@ -27,6 +27,9 @@ def __init__(
         :param entropy_coefficient: weight on entropy term
         :param discrete: whether the action space is discrete
         """
+
+        self.model = model
+
         super().__init__(algo, model, discrete)
         self.entropy_coefficient = entropy_coefficient
         self.discrete = discrete
@@ -84,33 +87,24 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
         # 4-tuple case (Tanh squashing): (action, z, mean, log_std)
         elif isinstance(model_output, tuple) and len(model_output) == 4:
             action, z, mean, log_std = model_output
-            log_prob = sample_nondeterministic_logprobs(
-                z=z,
-                mean=mean,
-                log_std=log_std,
-                sac=self.algo == "sac",
-            )
+            if self.algo != "sac":
+                log_prob = sample_nondeterministic_logprobs(
+                    z=z,
+                    mean=mean,
+                    log_std=log_std,
+                    sac=False,
+                )
+            else:
+                log_prob = self.model.policy_log_prob(z, mean, log_std)

             if return_logp:
                 return action.detach().cpu().numpy(), log_prob
             else:
                 weighted_log_prob = log_prob * self.entropy_coefficient
                 return action.detach().cpu().numpy(), weighted_log_prob

-        # Legacy 2-tuple case: (mean, std)
-        elif isinstance(model_output, tuple) and len(model_output) == 2:
-            mean, std = model_output
-            dist = Normal(mean, std)
-            z = dist.rsample()  # [batch, action_dim]
-            action = torch.tanh(z)  # [batch, action_dim]
-
-            log_prob = sample_nondeterministic_logprobs(
-                z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac"
-            )
-            entropy = dist.entropy().sum(dim=-1, keepdim=True)  # [batch, 1]
-            weighted_log_prob = log_prob * entropy
-            return action.detach().cpu().numpy(), weighted_log_prob
-
         # Check for model attribute-based approaches
         elif hasattr(self.model, "continuous_action") and getattr(
             self.model, "continuous_action"
@@ -126,9 +120,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
             elif len(model_output) == 4:
                 # Tanh squashing mode: (action, z, mean, log_std)
                 action, z, mean, log_std = model_output
-                log_prob = sample_nondeterministic_logprobs(
-                    z=z, mean=mean, log_std=log_std, sac=self.algo == "sac"
-                )
+                if self.algo != "sac":
+                    log_prob = sample_nondeterministic_logprobs(
+                        z=z,
+                        mean=mean,
+                        log_std=log_std,
+                        sac=False,
+                    )
+                else:
+                    log_prob = self.model.policy_log_prob(z, mean, log_std)
             else:
                 raise ValueError(
                     f"Unexpected model output length: {len(model_output)}"
@@ -145,9 +146,15 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
             if self.model.output_style == "squashed_gaussian":
                 # Should be 4-tuple: (action, z, mean, log_std)
                 action, z, mean, log_std = model_output
-                log_prob = sample_nondeterministic_logprobs(
-                    z=z, mean=mean, log_std=log_std, sac=self.algo == "sac"
-                )
+                if self.algo != "sac":
+                    log_prob = sample_nondeterministic_logprobs(
+                        z=z,
+                        mean=mean,
+                        log_std=log_std,
+                        sac=False,
+                    )
+                else:
+                    log_prob = self.model.policy_log_prob(z, mean, log_std)

                 if return_logp:
                     return action.detach().cpu().numpy(), log_prob
@@ -162,9 +169,16 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
                 z = dist.rsample()
                 action = torch.tanh(z)

-                log_prob = sample_nondeterministic_logprobs(
-                    z=z, mean=mean, log_std=torch.log(std), sac=self.algo == "sac"
-                )
+                if self.algo != "sac":
+                    log_prob = sample_nondeterministic_logprobs(
+                        z=z,
+                        mean=mean,
+                        log_std=torch.log(std),  # this branch only defines std, not log_std
+                        sac=False,
+                    )
+                else:
+                    log_prob = self.model.policy_log_prob(z, mean, torch.log(std))
+
                 entropy = dist.entropy().sum(dim=-1, keepdim=True)
                 weighted_log_prob = log_prob * entropy
                 return action.detach().cpu().numpy(), weighted_log_prob
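
The same SAC/non-SAC dispatch now appears in four branches of `explore`. As a hedged sketch (the helper name `squashed_log_prob` is hypothetical and not part of this diff; `policy_log_prob` and `sample_nondeterministic_logprobs` come from the surrounding code), the branching could be factored once:

```python
def squashed_log_prob(algo, model, z, mean, log_std):
    """Return the log-probability of a tanh-squashed action sample."""
    if algo == "sac":
        # SAC path: the model applies its own tanh change-of-variables correction.
        return model.policy_log_prob(z, mean, log_std)
    # Non-SAC path: generic helper with the SAC-specific term disabled.
    return sample_nondeterministic_logprobs(
        z=z, mean=mean, log_std=log_std, sac=False
    )
```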
@@ -175,14 +189,11 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tensor]:
             )

         # Special handling for SACModel
-        elif isinstance(self.model, SACModel):
+        elif self.algo == "sac" and isinstance(self.model, SACModel):
             action, z, mean, log_std = self.model(state, deterministic=False)
-            std = torch.exp(log_std)
-            dist = Normal(mean, std)
-
-            log_pz = dist.log_prob(z).sum(dim=-1, keepdim=True)
-            weighted_log_prob = log_pz * self.entropy_coefficient
-            return action.detach().cpu().numpy(), weighted_log_prob
+            # CRITICAL: use the model's policy_log_prob, which includes the tanh correction
+            log_prob = self.model.policy_log_prob(z, mean, log_std)
+            return action.detach().cpu().numpy(), log_prob

         else:
             raise RuntimeError(
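
`SACModel.policy_log_prob` itself is not shown in this diff. For context, a minimal sketch of what a tanh-corrected log-prob conventionally computes for a squashed diagonal Gaussian; the signature and the `eps` constant here are assumptions, not the repository's actual implementation:

```python
import torch
from torch.distributions import Normal

def policy_log_prob(z, mean, log_std, eps=1e-6):
    """Log-probability of a tanh-squashed Gaussian action.

    z:       pre-squash sample, shape [batch, action_dim]
    mean:    Gaussian mean, shape [batch, action_dim]
    log_std: Gaussian log std-dev, shape [batch, action_dim]
    """
    # Log-density of z under the diagonal Gaussian.
    log_pz = Normal(mean, torch.exp(log_std)).log_prob(z)
    # Change-of-variables correction for a = tanh(z):
    # log pi(a) = log p(z) - sum_i log(1 - tanh(z_i)^2 + eps)
    correction = torch.log(1.0 - torch.tanh(z).pow(2) + eps)
    return (log_pz - correction).sum(dim=-1, keepdim=True)
```

Dropping the correction term, as the removed `log_pz` code did, overstates the density of actions near the edges of the tanh range, which is why the diff routes SAC through the model's own implementation.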