Skip to content

Commit aeb326c

Browse files
functional fixes
1 parent 128aa1c commit aeb326c

File tree

3 files changed

+3
-3
lines changed

3 files changed

+3
-3
lines changed

rsl_rl/algorithms/distillation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def __init__(
5353

5454
# Initialize the transition
5555
self.transition = RolloutStorage.Transition()
56-
self.last_hidden_states = None
56+
self.last_hidden_states = (None, None)
5757

5858
# Distillation parameters
5959
self.num_learning_epochs = num_learning_epochs

rsl_rl/networks/memory.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ class Memory(nn.Module):
2020
def __init__(self, input_size: int, hidden_dim: int = 256, num_layers: int = 1, type: str = "lstm") -> None:
2121
super().__init__()
2222
rnn_cls = nn.GRU if type.lower() == "gru" else nn.LSTM
23-
self.rnn = rnn_cls(input_size=input_size, hidden_dim=hidden_dim, num_layers=num_layers)
23+
self.rnn = rnn_cls(input_size=input_size, hidden_size=hidden_dim, num_layers=num_layers)
2424
self.hidden_states = None
2525

2626
def forward(

rsl_rl/runners/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
"""Implementation of runners for environment-agent interaction."""
77

8+
from .on_policy_runner import OnPolicyRunner # noqa: I001
89
from .distillation_runner import DistillationRunner
9-
from .on_policy_runner import OnPolicyRunner
1010

1111
__all__ = ["DistillationRunner", "OnPolicyRunner"]

0 commit comments

Comments
 (0)