9
9
import warnings
10
10
from copy import copy
11
11
from enum import Enum
12
- from typing import Any , Callable , Iterable
12
+ from typing import Any , Callable , Iterable , TypeVar
13
13
14
14
import torch
15
15
from tensordict import NestedKey , TensorDict , TensorDictBase , unravel_key
@@ -101,54 +101,44 @@ def decorate_context(*args, **kwargs):
101
101
return decorate_context
102
102
103
103
104
# Constrained TypeVar expressing "Tensor in, Tensor out / TensorDict in,
# TensorDict out": helpers annotated with TensorLike return the same
# container type they were given, rather than a union of the two.
TensorLike = TypeVar("TensorLike", Tensor, TensorDict)
104
107
def distance_loss(
    v1: TensorLike,
    v2: TensorLike,
    loss_function: str,
    strict_shape: bool = True,
) -> TensorLike:
    """Computes a distance loss between two tensors.

    Args:
        v1 (Tensor | TensorDict): a tensor or tensordict with a shape compatible with v2.
        v2 (Tensor | TensorDict): a tensor or tensordict with a shape compatible with v1.
        loss_function (str): One of "l2", "l1" or "smooth_l1" representing which loss function is to be used.
        strict_shape (bool): if False, v1 and v2 are allowed to have a different shape.
            Default is ``True``.

    Returns:
        A tensor or tensordict of the shape v1.view_as(v2) or v2.view_as(v1)
        with values equal to the distance loss between the two.

    """
    if v1.shape != v2.shape and strict_shape:
        raise RuntimeError(
            f"The input tensors or tensordicts have shapes {v1.shape} and {v2.shape} which are incompatible."
        )

    # Single dispatch table keeps the name -> functional-loss mapping in one place;
    # every branch uses reduction="none" so the elementwise loss shape is preserved.
    dispatch = {
        "l2": F.mse_loss,
        "l1": F.l1_loss,
        "smooth_l1": F.smooth_l1_loss,
    }
    loss_fn = dispatch.get(loss_function)
    if loss_fn is None:
        raise NotImplementedError(f"Unknown loss {loss_function}.")
    return loss_fn(v1, v2, reduction="none")
152
142
153
143
154
144
class TargetNetUpdater :
@@ -620,13 +610,13 @@ def _reduce(tensor: torch.Tensor, reduction: str) -> float | torch.Tensor:
620
610
621
611
622
612
def _clip_value_loss (
623
- old_state_value : torch .Tensor ,
624
- state_value : torch .Tensor ,
625
- clip_value : torch .Tensor ,
626
- target_return : torch .Tensor ,
627
- loss_value : torch .Tensor ,
613
+ old_state_value : torch .Tensor | TensorDict ,
614
+ state_value : torch .Tensor | TensorDict ,
615
+ clip_value : torch .Tensor | TensorDict ,
616
+ target_return : torch .Tensor | TensorDict ,
617
+ loss_value : torch .Tensor | TensorDict ,
628
618
loss_critic_type : str ,
629
- ):
619
+ ) -> tuple [ torch . Tensor | TensorDict , torch . Tensor ] :
630
620
"""Value clipping method for loss computation.
631
621
632
622
This method computes a clipped state value from the old state value and the state value,
@@ -644,7 +634,7 @@ def _clip_value_loss(
644
634
loss_function = loss_critic_type ,
645
635
)
646
636
# Chose the most pessimistic value prediction between clipped and non-clipped
647
- loss_value = torch .max (loss_value , loss_value_clipped )
637
+ loss_value = torch .maximum (loss_value , loss_value_clipped )
648
638
return loss_value , clip_fraction
649
639
650
640
0 commit comments