2
2
#
3
3
# This source code is licensed under the MIT license found in the
4
4
# LICENSE file in the root directory of this source tree.
5
+ from __future__ import annotations
6
+
5
7
import functools
6
8
7
9
import torch .nn
@@ -221,8 +223,8 @@ def make_cql_model(cfg, train_env, eval_env, device="cpu"):
221
223
# distribution_kwargs=TensorDictParams(
222
224
# TensorDict(
223
225
# {
224
- # "low": action_spec.space.low,
225
- # "high": action_spec.space.high,
226
+ # "low": torch.as_tensor( action_spec.space.low, device=device) ,
227
+ # "high": torch.as_tensor( action_spec.space.high, device=device) ,
226
228
# "tanh_loc": NonTensorData(False),
227
229
# }
228
230
# ),
@@ -326,7 +328,7 @@ def make_cql_modules_state(model_cfg, proof_environment):
326
328
# ---------
327
329
328
330
329
- def make_continuous_loss (loss_cfg , model ):
331
+ def make_continuous_loss (loss_cfg , model , device : torch . device | None = None ):
330
332
loss_module = CQLLoss (
331
333
model [0 ],
332
334
model [1 ],
@@ -339,19 +341,19 @@ def make_continuous_loss(loss_cfg, model):
339
341
with_lagrange = loss_cfg .with_lagrange ,
340
342
lagrange_thresh = loss_cfg .lagrange_thresh ,
341
343
)
342
- loss_module .make_value_estimator (gamma = loss_cfg .gamma )
344
+ loss_module .make_value_estimator (gamma = loss_cfg .gamma , device = device )
343
345
target_net_updater = SoftUpdate (loss_module , tau = loss_cfg .tau )
344
346
345
347
return loss_module , target_net_updater
346
348
347
349
348
def make_discrete_loss(loss_cfg, model, device: torch.device | None = None):
    """Build the discrete CQL loss together with its target-network updater.

    Args:
        loss_cfg: Configuration object; ``loss_function``, ``gamma`` and
            ``tau`` are read from it.
        model: Passed as the first positional argument to
            ``DiscreteCQLLoss``.
        device: Forwarded unchanged as the ``device`` keyword of
            ``make_value_estimator``. Defaults to ``None``.

    Returns:
        A ``(loss_module, target_net_updater)`` tuple.
    """
    cql_loss = DiscreteCQLLoss(
        model, loss_function=loss_cfg.loss_function, delay_value=True
    )
    # Configure the value estimator before the updater is created, keeping
    # the same call order as the original implementation.
    cql_loss.make_value_estimator(gamma=loss_cfg.gamma, device=device)
    return cql_loss, SoftUpdate(cql_loss, tau=loss_cfg.tau)
0 commit comments