We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6162b81 · commit 97b9bc7 · Copy full SHA for 97b9bc7
ding/framework/middleware/functional/trainer.py
@@ -71,8 +71,8 @@ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
71
72
if ctx.train_data is None: # no enough data from data fetcher
73
return
74
- data = ctx.train_data.to(policy._device)
75
- train_output = policy.forward(data)
+ # data = ctx.train_data.to(policy._device)
+ train_output = policy.forward(ctx.train_data)
76
nonlocal last_log_iter
77
if ctx.train_iter - last_log_iter >= log_freq:
78
loss = np.mean([o['total_loss'] for o in train_output])
0 commit comments