@@ -130,6 +130,8 @@ class SACLoss(LossModule):
             valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
             shape of the data and that masking the data results in a valid data structure. Among other things, this may
             not be true in MARL settings or when using RNNs. Defaults to ``False``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.

     Examples:
         >>> import torch
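The docstring entry added in the hunk above describes ``deactivate_vmap`` as swapping vmap calls for a plain for loop. As a rough illustration (not TorchRL's actual ``_vmap_func`` implementation), the two strategies compute the same result when a function is evaluated over a stack of parameter sets; the loop variant trades speed for compatibility with modules that vmap cannot handle:

```python
import torch

# Illustrative sketch only: evaluating one function over a batch of stacked
# weights, once with torch.vmap and once with the plain-loop fallback that
# `deactivate_vmap=True` is meant to select. `qnet` and the shapes below are
# made-up placeholders, not TorchRL internals.
def qnet(obs, weight):
    return obs @ weight

obs = torch.randn(4, 3)         # observations, shared across critics
weights = torch.randn(2, 3, 1)  # two critics' weights stacked on dim 0

# vmap path: batch over dim 0 of `weights`, broadcast `obs` (in_dims=(None, 0))
out_vmap = torch.vmap(qnet, in_dims=(None, 0))(obs, weights)

# plain-loop path: same semantics, one Python iteration per critic
out_loop = torch.stack([qnet(obs, w) for w in weights.unbind(0)], dim=0)

assert torch.allclose(out_vmap, out_loop)
```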
@@ -334,6 +336,7 @@ def __init__(
         separate_losses: bool = False,
         reduction: str = None,
         skip_done_states: bool = False,
+        deactivate_vmap: bool = False,
     ) -> None:
         self._in_keys = None
         self._out_keys = None
@@ -344,6 +347,7 @@ def __init__(

         # Actor
         self.delay_actor = delay_actor
+        self.deactivate_vmap = deactivate_vmap
         self.convert_to_functional(
             actor_network,
             "actor_network",
@@ -445,11 +449,16 @@ def __init__(

     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            self.qvalue_network,
+            (None, 0),
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
         if self._version == 1:
             self._vmap_qnetwork00 = _vmap_func(
-                self.qvalue_network, randomness=self.vmap_randomness
+                self.qvalue_network,
+                randomness=self.vmap_randomness,
+                pseudo_vmap=self.deactivate_vmap,
             )

     @property
@@ -527,11 +536,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TD1Estimator(
                 **hp,
                 value_network=value_net,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.TD0:
             self._value_estimator = TD0Estimator(
                 **hp,
                 value_network=value_net,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.GAE:
             raise NotImplementedError(
@@ -541,6 +552,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TDLambdaEstimator(
                 **hp,
                 value_network=value_net,
+                deactivate_vmap=self.deactivate_vmap,
             )
         else:
             raise NotImplementedError(f"Unknown value type {value_type}")
@@ -673,7 +685,6 @@ def _actor_loss(
             raise RuntimeError(
                 f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}"
             )
-
         return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()}

     @property
@@ -922,6 +933,8 @@ class DiscreteSACLoss(LossModule):
             valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
             shape of the data and that masking the data results in a valid data structure. Among other things, this may
             not be true in MARL settings or when using RNNs. Defaults to ``False``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.

     Examples:
         >>> import torch
@@ -1098,6 +1111,7 @@ def __init__(
         separate_losses: bool = False,
         reduction: str = None,
         skip_done_states: bool = False,
+        deactivate_vmap: bool = False,
     ):
         if reduction is None:
             reduction = "mean"
@@ -1110,6 +1124,7 @@ def __init__(
             "actor_network",
             create_target_params=self.delay_actor,
         )
+        self.deactivate_vmap = deactivate_vmap
         if separate_losses:
             # we want to make sure there are no duplicates in the params: the
             # params of critic must be refs to actor if they're shared
@@ -1184,7 +1199,10 @@ def __init__(

     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            self.qvalue_network,
+            (None, 0),
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )

     def _forward_value_estimator_keys(self, **kwargs) -> None:
@@ -1436,11 +1454,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TD1Estimator(
                 **hp,
                 value_network=None,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.TD0:
             self._value_estimator = TD0Estimator(
                 **hp,
                 value_network=None,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.GAE:
             raise NotImplementedError(
@@ -1450,6 +1470,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TDLambdaEstimator(
                 **hp,
                 value_network=None,
+                deactivate_vmap=self.deactivate_vmap,
             )
         else:
             raise NotImplementedError(f"Unknown value type {value_type}")
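Taken together, the pattern in this diff is: the flag is accepted in ``__init__``, stored as ``self.deactivate_vmap``, forwarded to ``_vmap_func`` as ``pseudo_vmap`` inside ``_make_vmap``, and passed on to the value estimators in ``make_value_estimator``. A hedged sketch of how such a switch could be wired up follows; names like ``vmap_or_loop`` and ``ToyLoss`` are invented for illustration and are not part of TorchRL:

```python
import torch

# Hypothetical helper mirroring the constructor -> _make_vmap -> _vmap_func
# threading shown in the diff. `vmap_or_loop` is an invented stand-in for
# TorchRL's `_vmap_func`; only the (None, 0) in_dims layout used above is handled.
def vmap_or_loop(fn, in_dims=(None, 0), pseudo_vmap=False):
    if not pseudo_vmap:
        return torch.vmap(fn, in_dims=in_dims)

    def looped(shared, batched):
        # Plain for loop over the stacked dimension instead of vmap.
        return torch.stack([fn(shared, b) for b in batched.unbind(0)], dim=0)

    return looped


class ToyLoss:
    """Minimal stand-in for a loss module that evaluates stacked critic weights."""

    def __init__(self, deactivate_vmap: bool = False):
        self.deactivate_vmap = deactivate_vmap
        self._make_vmap()

    def _make_vmap(self):
        # Same shape as the diff: the stored flag selects the execution path.
        self._vmap_qnetworkN0 = vmap_or_loop(
            lambda obs, w: obs @ w,
            (None, 0),
            pseudo_vmap=self.deactivate_vmap,
        )

    def __call__(self, obs, stacked_weights):
        return self._vmap_qnetworkN0(obs, stacked_weights)


obs = torch.randn(4, 3)
weights = torch.randn(2, 3, 1)  # e.g. two critics' weights stacked on dim 0
assert torch.allclose(
    ToyLoss(deactivate_vmap=False)(obs, weights),
    ToyLoss(deactivate_vmap=True)(obs, weights),
)
```

The appeal of the loop fallback is that it keeps working with modules vmap cannot batch (for example, some RNNs or modules with data-dependent control flow), at the cost of running the critics sequentially.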