File tree Expand file tree Collapse file tree 5 files changed +6
-5
lines changed Expand file tree Collapse file tree 5 files changed +6
-5
lines changed Original file line number Diff line number Diff line change @@ -206,7 +206,7 @@ def train(cfg: "DictConfig"): # noqa: F821
206206 and cfg .logger .backend
207207 ):
208208 evaluation_start = time .time ()
209- with torch .no_grad () and set_exploration_type (ExplorationType .MEAN ):
209+ with torch .no_grad (), set_exploration_type (ExplorationType .MODE ):
210210 env_test .frames = []
211211 rollouts = env_test .rollout (
212212 max_steps = cfg .env .max_steps ,
Original file line number Diff line number Diff line change @@ -230,7 +230,7 @@ def train(cfg: "DictConfig"): # noqa: F821
230230 and cfg .logger .backend
231231 ):
232232 evaluation_start = time .time ()
233- with torch .no_grad () and set_exploration_type (ExplorationType .MEAN ):
233+ with torch .no_grad (), set_exploration_type (ExplorationType .MODE ):
234234 env_test .frames = []
235235 rollouts = env_test .rollout (
236236 max_steps = cfg .env .max_steps ,
Original file line number Diff line number Diff line change 1717from torchrl .data .replay_buffers .storages import LazyTensorStorage
1818from torchrl .envs import RewardSum , TransformedEnv
1919from torchrl .envs .libs .vmas import VmasEnv
20+ from torchrl .envs .utils import ExplorationType , set_exploration_type
2021from torchrl .modules import ProbabilisticActor , TanhNormal , ValueOperator
2122from torchrl .modules .models .multiagent import MultiAgentMLP
2223from torchrl .objectives import ClipPPOLoss , ValueEstimators
@@ -235,7 +236,7 @@ def train(cfg: "DictConfig"): # noqa: F821
235236 and cfg .logger .backend
236237 ):
237238 evaluation_start = time .time ()
238- with torch .no_grad ():
239+ with torch .no_grad (), set_exploration_type ( ExplorationType . MODE ) :
239240 env_test .frames = []
240241 rollouts = env_test .rollout (
241242 max_steps = cfg .env .max_steps ,
Original file line number Diff line number Diff line change @@ -241,7 +241,7 @@ def train(cfg: "DictConfig"): # noqa: F821
241241 and cfg .logger .backend
242242 ):
243243 evaluation_start = time .time ()
244- with torch .no_grad () and set_exploration_type (ExplorationType .MEAN ):
244+ with torch .no_grad (), set_exploration_type (ExplorationType .MODE ):
245245 env_test .frames = []
246246 rollouts = env_test .rollout (
247247 max_steps = cfg .env .max_steps ,
Original file line number Diff line number Diff line change @@ -300,7 +300,7 @@ def train(cfg: "DictConfig"): # noqa: F821
300300 and cfg .logger .backend
301301 ):
302302 evaluation_start = time .time ()
303- with torch .no_grad () and set_exploration_type (ExplorationType .MODE ):
303+ with torch .no_grad (), set_exploration_type (ExplorationType .MODE ):
304304 env_test .frames = []
305305 rollouts = env_test .rollout (
306306 max_steps = cfg .env .max_steps ,
You can’t perform that action at this time.
0 commit comments