Skip to content

Commit bc95cbb

Browse files
matteobettini authored and vmoens committed
[BugFix] Fix multiple context syntax in multiagent examples (#1943)
1 parent 1187fc5 commit bc95cbb

File tree

5 files changed

+6
-5
lines changed

5 files changed

+6
-5
lines changed

examples/multiagent/iql.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821
206206
and cfg.logger.backend
207207
):
208208
evaluation_start = time.time()
209-
with torch.no_grad() and set_exploration_type(ExplorationType.MEAN):
209+
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
210210
env_test.frames = []
211211
rollouts = env_test.rollout(
212212
max_steps=cfg.env.max_steps,

examples/multiagent/maddpg_iddpg.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821
230230
and cfg.logger.backend
231231
):
232232
evaluation_start = time.time()
233-
with torch.no_grad() and set_exploration_type(ExplorationType.MEAN):
233+
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
234234
env_test.frames = []
235235
rollouts = env_test.rollout(
236236
max_steps=cfg.env.max_steps,

examples/multiagent/mappo_ippo.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from torchrl.data.replay_buffers.storages import LazyTensorStorage
1818
from torchrl.envs import RewardSum, TransformedEnv
1919
from torchrl.envs.libs.vmas import VmasEnv
20+
from torchrl.envs.utils import ExplorationType, set_exploration_type
2021
from torchrl.modules import ProbabilisticActor, TanhNormal, ValueOperator
2122
from torchrl.modules.models.multiagent import MultiAgentMLP
2223
from torchrl.objectives import ClipPPOLoss, ValueEstimators
@@ -235,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821
235236
and cfg.logger.backend
236237
):
237238
evaluation_start = time.time()
238-
with torch.no_grad():
239+
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
239240
env_test.frames = []
240241
rollouts = env_test.rollout(
241242
max_steps=cfg.env.max_steps,

examples/multiagent/qmix_vdn.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821
241241
and cfg.logger.backend
242242
):
243243
evaluation_start = time.time()
244-
with torch.no_grad() and set_exploration_type(ExplorationType.MEAN):
244+
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
245245
env_test.frames = []
246246
rollouts = env_test.rollout(
247247
max_steps=cfg.env.max_steps,

examples/multiagent/sac.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821
300300
and cfg.logger.backend
301301
):
302302
evaluation_start = time.time()
303-
with torch.no_grad() and set_exploration_type(ExplorationType.MODE):
303+
with torch.no_grad(), set_exploration_type(ExplorationType.MODE):
304304
env_test.frames = []
305305
rollouts = env_test.rollout(
306306
max_steps=cfg.env.max_steps,

0 commit comments

Comments (0)