
[BUG] Setting target_entropy = N (only "auto" works) with CrossQLoss does not work #3309

@MathieuFonsProjects


Describe the bug

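As the title says: passing a numeric target_entropy to CrossQLoss fails, and only "auto" works. This is the same target_entropy initialization problem that @vmoens corrected in the SAC loss.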

To Reproduce

In CrossQLoss.maybe_init_target_entropy:

def maybe_init_target_entropy(self, fault_tolerant=True):
    """Initialize the target entropy.

    Args:
        fault_tolerant (bool, optional): if ``True``, returns None if the target entropy
            cannot be determined. Raises an exception otherwise. Defaults to ``True``.

    """
    if "_target_entropy" in self._buffers:
        return
    target_entropy = self._target_entropy
    if target_entropy == "auto":
        device = next(self.parameters()).device
        action_spec = self.get_action_spec()
        if action_spec is None:
            if fault_tolerant:
                return
            raise RuntimeError(
                "Cannot infer the dimensionality of the action. Consider providing "
                "the target entropy explicitly or provide the spec of the "
                "action tensor in the actor network."
            )
        if not isinstance(action_spec, Composite):
            action_spec = Composite({self.tensor_keys.action: action_spec})
        elif fault_tolerant and self.tensor_keys.action not in action_spec:
            return
        if (
            isinstance(self.tensor_keys.action, tuple)
            and len(self.tensor_keys.action) > 1
        ):
            action_container_shape = action_spec[self.tensor_keys.action[:-1]].shape
        else:
            action_container_shape = action_spec.shape
        target_entropy = -float(
            action_spec[self.tensor_keys.action]
            .shape[len(action_container_shape) :]
            .numel()
        )
    delattr(self, "_target_entropy")
    self.register_buffer(
        "_target_entropy", torch.tensor(target_entropy, device=device)
    )
    return self._target_entropy
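Reduced to its control flow, the method above assigns device only inside the "auto" branch. The sketch below reproduces that without torchrl; init_target_entropy is a hypothetical name, and the string "cpu" stands in for next(self.parameters()).device:

```python
def init_target_entropy(target_entropy, action_dim=3):
    """Hypothetical reduction of maybe_init_target_entropy's control flow."""
    if target_entropy == "auto":
        # Stand-in for: device = next(self.parameters()).device
        device = "cpu"
        target_entropy = -float(action_dim)
    # Stand-in for torch.tensor(target_entropy, device=device):
    # 'device' is a local that is only bound on the "auto" path.
    return (target_entropy, device)


print(init_target_entropy("auto"))  # (-3.0, 'cpu')
try:
    init_target_entropy(-2.0)
except UnboundLocalError as err:
    print(type(err).__name__)  # UnboundLocalError
```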

If a number is passed, the 'if target_entropy == "auto"' branch is skipped and device is never assigned, so:

self.register_buffer(
    "_target_entropy", torch.tensor(target_entropy, device=device)
)

raises an UnboundLocalError, because device was never assigned.
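One possible fix, sketched here and not taken from torchrl, is to look up the device before the "auto" branch so it is bound on every path. TargetEntropyDemo is a hypothetical nn.Module, not CrossQLoss, and its integer action_dim stands in for the action spec:

```python
import torch
from torch import nn


class TargetEntropyDemo(nn.Module):
    """Hypothetical stand-in for a loss module (not torchrl's CrossQLoss)."""

    def __init__(self, target_entropy="auto", action_dim=3):
        super().__init__()
        self.linear = nn.Linear(4, action_dim)  # gives the module a parameter (and a device)
        self._target_entropy = target_entropy  # "auto" or a number
        self._action_dim = action_dim  # stands in for the action spec

    def maybe_init_target_entropy(self):
        if "_target_entropy" in self._buffers:
            # Already initialized (the original method returns None here).
            return self._target_entropy
        # Resolve the device before branching so it is bound on every path.
        device = next(self.parameters()).device
        target_entropy = self._target_entropy
        if target_entropy == "auto":
            target_entropy = -float(self._action_dim)
        # Remove the plain attribute so register_buffer can reuse the name.
        delattr(self, "_target_entropy")
        self.register_buffer(
            "_target_entropy", torch.tensor(target_entropy, device=device)
        )
        return self._target_entropy


print(TargetEntropyDemo(target_entropy=-2.0).maybe_init_target_entropy().item())  # -2.0
print(TargetEntropyDemo().maybe_init_target_entropy().item())  # -3.0
```

With the lookup hoisted, a numeric target such as -2.0 is stored as a buffer exactly like the value inferred for "auto"; the inference logic itself is unchanged.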

Checklist

  • I have checked that there is no similar issue in the repo (required)
  • I have read the documentation (required)
  • I have provided a minimal working example to reproduce the bug (required)

Metadata


Assignees

Labels

bug (Something isn't working)

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests
