Commit dfae2cc

feature(nyz): add ppof cuda

1 parent 3a9f213

File tree

4 files changed: +9 -1 lines changed

ding/framework/middleware/collector.py

Lines changed: 3 additions & 0 deletions
@@ -105,6 +105,7 @@ def __call__(self, ctx: "OnlineRLContext") -> None:
         Input of ctx:
             - env_step (:obj:`int`): The env steps which will increase during collection.
         """
+        device = self.policy._device
         old = ctx.env_step
         target_size = self.n_sample * self.unroll_len

@@ -113,7 +114,9 @@ def __call__(self, ctx: "OnlineRLContext") -> None:

         while True:
             obs = ttorch.as_tensor(self.env.ready_obs).to(dtype=ttorch.float32)
+            obs = obs.to(device)
             inference_output = self.policy.collect(obs, **ctx.collect_kwargs)
+            inference_output = inference_output.cpu()
             action = inference_output.action.numpy()
             timesteps = self.env.step(action)
             ctx.env_step += len(timesteps)
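
Both hunks implement one pattern: observations come off the environment on CPU, must be moved to the policy's device before inference, and the inference output must return to CPU before `.numpy()` is called, because PyTorch refuses to convert a CUDA tensor directly to a numpy array. A minimal, self-contained sketch of that round trip, using plain torch stand-ins rather than DI-engine's actual collector:

import torch

# Stand-ins: in the real collector, `net` is the policy model and
# `device` comes from self.policy._device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
net = torch.nn.Linear(4, 2).to(device)

# The env produces CPU data; cast and move it to the policy device.
obs = torch.as_tensor([[0.1, 0.2, 0.3, 0.4]], dtype=torch.float32).to(device)

with torch.no_grad():
    logits = net(obs)

# .numpy() only works on CPU tensors, hence the explicit .cpu() first.
action = logits.argmax(dim=-1).cpu().numpy()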

ding/framework/middleware/functional/evaluator.py

Lines changed: 3 additions & 0 deletions
@@ -343,11 +343,14 @@ def _evaluate(ctx: "OnlineRLContext"):
         else:
             env.reset()
             policy.reset()
+        device = policy._device
         eval_monitor = VectorEvalMonitor(env.env_num, n_evaluator_episode)

         while not eval_monitor.is_finished():
             obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
+            obs = obs.to(device)
             inference_output = policy.eval(obs)
+            inference_output = inference_output.cpu()
             if render:
                 eval_monitor.update_video(env.ready_imgs)
             eval_monitor.update_output(inference_output)
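
The evaluator needs the same round trip as the collector. One detail that makes the single `obs.to(device)` call sufficient here: `obs` is a treetensor, and treetensor methods such as `.to(...)` and `.cpu()` apply to every leaf tensor of the nested structure at once. A sketch of that behavior, assuming the DI-treetensor package is installed; the `image`/`vector` keys are made up for illustration:

import torch
import treetensor.torch as ttorch

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# A nested observation, as a vectorized env with dict observations might return.
obs = ttorch.as_tensor({
    'image': torch.zeros(2, 3, 4),
    'vector': torch.zeros(2, 8),
}).to(dtype=ttorch.float32)

obs = obs.to(device)  # moves every leaf tensor in the tree
obs = obs.cpu()       # .cpu() is likewise applied leaf-wise
print(obs.image.device, obs.vector.device)  # both report cpu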

ding/framework/middleware/functional/trainer.py

Lines changed: 2 additions & 1 deletion
@@ -71,7 +71,8 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):

         if ctx.train_data is None:  # no enough data from data fetcher
             return
-        train_output = policy.forward(ctx.train_data)
+        data = ctx.train_data.to(policy._device)
+        train_output = policy.forward(data)
         nonlocal last_log_iter
         if ctx.train_iter - last_log_iter >= log_freq:
             loss = np.mean([o['total_loss'] for o in train_output])
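
Training is the mirror image: the fetched batch lives on CPU while the model's parameters may live on CUDA, and a forward pass mixing the two raises a device-mismatch RuntimeError. Moving the batch once, before `policy.forward`, keeps the rest of the training step device-agnostic. A minimal illustration with plain torch stand-ins, not the actual trainer middleware:

import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = torch.nn.Linear(8, 1).to(device)

train_data = torch.randn(32, 8)     # batch fetched on CPU
train_data = train_data.to(device)  # mirrors ctx.train_data.to(policy._device)

loss = model(train_data).mean()
loss.backward()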

ding/policy/ppof.py

Lines changed: 1 addition & 0 deletions
@@ -64,6 +64,7 @@ def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str]
         self._model = model
         if self._cfg.cuda and torch.cuda.is_available():
             self._device = 'cuda'
+            self._model.cuda()
         else:
             self._device = 'cpu'
         assert self._cfg.action_space in ["continuous", "discrete", "hybrid", 'multi_discrete']
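
This one-line addition is the heart of the commit: the policy already recorded `self._device = 'cuda'`, but without `self._model.cuda()` the parameters stayed on CPU while the middleware above started sending CUDA inputs. Reduced to a standalone skeleton (a hypothetical class, with `cfg` standing in for the EasyDict config):

import torch

class PolicySkeleton:
    """Hypothetical skeleton of the device-selection idiom in ppof.py."""

    def __init__(self, cfg, model: torch.nn.Module) -> None:
        self._model = model
        # Falling back to CPU when CUDA is absent lets the same config
        # run unchanged on machines without a GPU.
        if cfg.cuda and torch.cuda.is_available():
            self._device = 'cuda'
            self._model.cuda()  # move the parameters, not just the label
        else:
            self._device = 'cpu'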
