Skip to content

Commit 97b9bc7

Browse files
committed
fix trainer.py:multistep_trainer args bug
1 parent 6162b81 commit 97b9bc7

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

ding/framework/middleware/functional/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,8 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
7171

7272
if ctx.train_data is None: # no enough data from data fetcher
7373
return
74-
data = ctx.train_data.to(policy._device)
75-
train_output = policy.forward(data)
74+
# data = ctx.train_data.to(policy._device)
75+
train_output = policy.forward(ctx.train_data)
7676
nonlocal last_log_iter
7777
if ctx.train_iter - last_log_iter >= log_freq:
7878
loss = np.mean([o['total_loss'] for o in train_output])

0 commit comments

Comments
 (0)