Skip to content

Commit 32dcf3a

Browse files
authored
Merge pull request #59 from WiemKhlifi/fix/fix_cql_loss
fix: fix CQL loss
2 parents 10dead4 + 280d9c4 commit 32dcf3a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

og_marl/tf2_systems/offline/iql_cql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def _tf_train_step(self, train_step: int, experience: Dict[str, Any]) -> Dict[st
226226
#############
227227

228228
# Mask out zero-padded timesteps
229-
loss = td_loss + self.cql_weight + cql_loss
229+
loss = td_loss + self.cql_weight * cql_loss
230230

231231
# Get trainable variables
232232
variables = (*self.q_network.trainable_variables,)

0 commit comments

Comments
 (0)