@@ -27,13 +27,12 @@ def __init__(
2727 :param entropy_coefficient: weight on entropy term
2828 :param discrete: whether the action space is discrete
2929 """
30-
30+
3131 self .model = model
32-
32+
3333 super ().__init__ (algo , model , discrete )
3434 self .entropy_coefficient = entropy_coefficient
3535 self .discrete = discrete
36-
3736
3837 # --- override sample_action only for continuous SAC ---
3938 if not discrete and isinstance (model , SACModel ):
@@ -88,9 +87,9 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
8887 # 4-tuple case (Tanh squashing): (action, z, mean, log_std)
8988 elif isinstance (model_output , tuple ) and len (model_output ) == 4 :
9089 action , z , mean , log_std = model_output
91-
92- if not isinstance ( self .model , SACModel ) :
93-
90+
91+ if not self .algo == "sac" :
92+
9493 log_prob = sample_nondeterministic_logprobs (
9594 z = z ,
9695 mean = mean ,
@@ -121,8 +120,8 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
121120 elif len (model_output ) == 4 :
122121 # Tanh squashing mode: (action, z, mean, log_std)
123122 action , z , mean , log_std = model_output
124- if not isinstance ( self .model , SACModel ) :
125-
123+ if not self .algo == "sac" :
124+
126125 log_prob = sample_nondeterministic_logprobs (
127126 z = z ,
128127 mean = mean ,
@@ -147,7 +146,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
147146 if self .model .output_style == "squashed_gaussian" :
148147 # Should be 4-tuple: (action, z, mean, log_std)
149148 action , z , mean , log_std = model_output
150- if not isinstance ( self .model , SACModel ) :
149+ if not self .algo == "sac" :
151150 log_prob = sample_nondeterministic_logprobs (
152151 z = z ,
153152 mean = mean ,
@@ -170,7 +169,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
170169 z = dist .rsample ()
171170 action = torch .tanh (z )
172171
173- if not isinstance ( self .model , SACModel ) :
172+ if not self .algo == "sac" :
174173 log_prob = sample_nondeterministic_logprobs (
175174 z = z ,
176175 mean = mean ,
@@ -179,7 +178,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
179178 )
180179 else :
181180 log_prob = self .model .policy_log_prob (z , mean , log_std )
182-
181+
183182 entropy = dist .entropy ().sum (dim = - 1 , keepdim = True )
184183 weighted_log_prob = log_prob * entropy
185184 return action .detach ().cpu ().numpy (), weighted_log_prob
@@ -190,7 +189,7 @@ def explore(self, s, return_logp, metrics=None) -> Tuple[np.ndarray, torch.Tenso
190189 )
191190
192191 # Special handling for SACModel
193- elif isinstance (self .model , SACModel ):
192+ elif self . algo == "sac" and isinstance (self .model , SACModel ):
194193 action , z , mean , log_std = self .model (state , deterministic = False )
195194 # CRITICAL: Use the model's policy_log_prob which includes tanh correction
196195 log_prob = self .model .policy_log_prob (z , mean , log_std )
0 commit comments