Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from sb3_contrib.ars import ARS
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.ppo_attention import AttentionPPO
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC
from sb3_contrib.trpo import TRPO
Expand All @@ -16,6 +17,7 @@
"ARS",
"MaskablePPO",
"RecurrentPPO",
"AttentionPPO",
"QRDQN",
"TQC",
"TRPO",
Expand Down
Empty file.
390 changes: 390 additions & 0 deletions sb3_contrib/common/attention/buffers.py

Large diffs are not rendered by default.

673 changes: 673 additions & 0 deletions sb3_contrib/common/attention/policies - Copie.py
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

a forgotten file?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yep, my bad

Large diffs are not rendered by default.

628 changes: 628 additions & 0 deletions sb3_contrib/common/attention/policies.py

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions sb3_contrib/common/attention/type_aliases.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
from typing import NamedTuple, Tuple

import torch as th
from stable_baselines3.common.type_aliases import TensorDict


class AttnMemory(NamedTuple):
    """Attention memory bundled with rollout-buffer samples.

    Currently holds a single tensor under ``pi``.
    """

    # NOTE(review): presumably the policy (actor) memory; its shape is not
    # visible from this file -- confirm against the attention policy.
    # A tuple-of-tensors layout (``Tuple[th.Tensor, ...]``) and a separate
    # ``vf`` entry were sketched previously but are not enabled.
    pi: th.Tensor


class AttentionRolloutBufferSamples(NamedTuple):
    """Named tuple describing a batch of rollout samples with attention memory.

    Field order is part of the interface: positional construction and tuple
    unpacking both depend on it. Every field is a ``th.Tensor`` except
    ``attn_memory``, which is an ``AttnMemory``.

    NOTE(review): the per-field remarks below are inferred from the field
    names only; shapes and dtypes are not visible in this file -- confirm
    against the buffer that builds these samples.
    """

    observations: th.Tensor
    actions: th.Tensor
    old_values: th.Tensor  # presumably value estimates stored at collection time
    old_log_prob: th.Tensor  # presumably action log-probs stored at collection time
    advantages: th.Tensor
    returns: th.Tensor
    attn_memory: AttnMemory
    episode_starts: th.Tensor  # presumably flags the first step of an episode
    mask: th.Tensor  # presumably marks valid (non-padded) entries -- TODO confirm


class AttentionDictRolloutBufferSamples(NamedTuple):
    """Named tuple describing a batch of rollout samples with dict observations.

    Field names and order are the same as in ``AttentionRolloutBufferSamples``;
    the only difference is that ``observations`` is a ``TensorDict`` instead
    of a single ``th.Tensor``. Field order is part of the interface:
    positional construction and tuple unpacking both depend on it.

    NOTE(review): the per-field remarks below are inferred from the field
    names only; shapes and dtypes are not visible in this file -- confirm
    against the buffer that builds these samples.
    """

    observations: TensorDict
    actions: th.Tensor
    old_values: th.Tensor  # presumably value estimates stored at collection time
    old_log_prob: th.Tensor  # presumably action log-probs stored at collection time
    advantages: th.Tensor
    returns: th.Tensor
    attn_memory: AttnMemory
    episode_starts: th.Tensor  # presumably flags the first step of an episode
    mask: th.Tensor  # presumably marks valid (non-padded) entries -- TODO confirm
4 changes: 4 additions & 0 deletions sb3_contrib/ppo_attention/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from sb3_contrib.ppo_attention.policies import CnnAttnPolicy, MlpAttnPolicy, MultiInputAttnPolicy
from sb3_contrib.ppo_attention.ppo_attention import AttentionPPO

# Public names exported by this package: the three attention policy classes
# and the AttentionPPO algorithm imported above.
__all__ = [
    "CnnAttnPolicy",
    "MlpAttnPolicy",
    "MultiInputAttnPolicy",
    "AttentionPPO",
]
Loading