
Commit 53b90ea

[losses] Add 'spacing' option to flow field loss modules
1 parent 41bcc55 commit 53b90ea

File tree

1 file changed: +22 -27 lines changed


src/deepali/losses/flow.py

Lines changed: 22 additions & 27 deletions
@@ -4,10 +4,9 @@
 
 from typing import Optional, Union
 
-import torch
 from torch import Tensor
 
-from deepali.core.typing import ScalarOrTuple, Shape
+from deepali.core.typing import Array, Scalar, ScalarOrTuple
 
 from . import functional as L
 from .base import DisplacementLoss
@@ -20,6 +19,7 @@ def __init__(
         self,
         mode: Optional[str] = None,
         sigma: Optional[float] = None,
+        spacing: Optional[Union[Scalar, Array]] = None,
         stride: Optional[ScalarOrTuple] = None,
         reduction: str = "mean",
     ):
@@ -28,31 +28,27 @@ def __init__(
         Args:
             mode: Method used to approximate :func:`flow_derivatives()`.
             sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
+            spacing: Spacing between grid elements. Should be given in the units of the flow vectors.
+                By default, flow vectors with respect to normalized grid coordinates are assumed.
             stride: Number of output grid points between control points plus one for ``mode='bspline'``.
             reduction: Operation to use for reducing spatially distributed loss values.
 
         """
         super().__init__()
         self.mode = mode
         self.sigma = sigma
+        self.spacing = spacing
         self.stride = stride
         self.reduction = reduction
 
-    def _spacing(self, u_shape: Shape) -> Optional[Tensor]:
-        ndim = len(u_shape)
-        if ndim < 3:
-            raise ValueError(f"{type(self).__name__}.forward() 'u' must be at least 3-dimensional")
-        if ndim == 3:
-            return None
-        size = torch.tensor(u_shape[-1:1:-1], dtype=torch.float, device=torch.device("cpu"))
-        return 2 / (size - 1)
-
     def extra_repr(self) -> str:
         args = []
         if self.mode:
             args.append(f"mode={self.mode!r}")
         if self.sigma:
             args.append(f"sigma={self.sigma!r}")
+        if self.spacing:
+            args.append(f"spacing={self.spacing!r}")
         if self.stride:
             args.append(f"stride={self.stride!r}")
         args.append(f"reduction={self.reduction!r}")
@@ -68,6 +64,7 @@ def __init__(
         q: Optional[Union[int, float]] = 1,
         mode: Optional[str] = None,
         sigma: Optional[float] = None,
+        spacing: Optional[Union[Scalar, Array]] = None,
         stride: Optional[ScalarOrTuple] = None,
         reduction: str = "mean",
     ):
@@ -76,24 +73,27 @@ def __init__(
         Args:
             mode: Method used to approximate :func:`flow_derivatives()`.
             sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
+            spacing: Spacing between grid elements. Should be given in the units of the flow vectors.
+                By default, flow vectors with respect to normalized grid coordinates are assumed.
             stride: Number of output grid points between control points plus one for ``mode='bspline'``.
             reduction: Operation to use for reducing spatially distributed loss values.
 
         """
-        super().__init__(mode=mode, sigma=sigma, stride=stride, reduction=reduction)
+        super().__init__(
+            mode=mode, sigma=sigma, spacing=spacing, stride=stride, reduction=reduction
+        )
         self.p = p
         self.q = 1 / p if q is None else q
 
     def forward(self, u: Tensor) -> Tensor:
         r"""Evaluate regularization loss for given transformation."""
-        spacing = self._spacing(u.shape)
         return L.grad_loss(
             u,
             p=self.p,
             q=self.q,
             mode=self.mode,
             sigma=self.sigma,
-            spacing=spacing,
+            spacing=self.spacing,
             stride=self.stride,
             reduction=self.reduction,
         )
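Since the module now forwards `self.spacing` unchanged to the functional API, the effect of the new option can be illustrated directly with `deepali.losses.functional.grad_loss` (imported as `L` in this file). A hedged sketch: the tensor shape and spacing values are illustrative assumptions, and the spacing is passed as a tensor on the assumption that the `Array` annotation accepts one.

import torch
import deepali.losses.functional as L

# Hypothetical 3D displacement field: batch 1, 3 vector components, 8 x 16 x 16 grid.
u = torch.randn(1, 3, 8, 16, 16)

# Default: derivatives with respect to normalized grid coordinates.
loss_norm = L.grad_loss(u, p=2, q=1, reduction="mean")

# New option: spacing in the units of the flow vectors (order and values assumed).
loss_world = L.grad_loss(u, p=2, q=1, spacing=torch.tensor([1.0, 1.0, 2.0]), reduction="mean")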
@@ -107,12 +107,11 @@ class Bending(_SpatialDerivativesLoss):
 
     def forward(self, u: Tensor) -> Tensor:
         r"""Evaluate regularization loss for given transformation."""
-        spacing = self._spacing(u.shape)
         return L.bending_loss(
             u,
             mode=self.mode,
             sigma=self.sigma,
-            spacing=spacing,
+            spacing=self.spacing,
             stride=self.stride,
             reduction=self.reduction,
         )
@@ -127,12 +126,11 @@ class Curvature(_SpatialDerivativesLoss):
 
     def forward(self, u: Tensor) -> Tensor:
         r"""Evaluate regularization loss for given transformation."""
-        spacing = self._spacing(u.shape)
         return L.curvature_loss(
             u,
             mode=self.mode,
             sigma=self.sigma,
-            spacing=spacing,
+            spacing=self.spacing,
             stride=self.stride,
             reduction=self.reduction,
         )
@@ -143,12 +141,11 @@ class Diffusion(_SpatialDerivativesLoss):
 
     def forward(self, u: Tensor) -> Tensor:
         r"""Evaluate regularization loss for given transformation."""
-        spacing = self._spacing(u.shape)
         return L.diffusion_loss(
             u,
             mode=self.mode,
             sigma=self.sigma,
-            spacing=spacing,
+            spacing=self.spacing,
             stride=self.stride,
             reduction=self.reduction,
         )
@@ -159,12 +156,11 @@ class Divergence(_SpatialDerivativesLoss):
 
     def forward(self, u: Tensor) -> Tensor:
         r"""Evaluate regularization loss for given transformation."""
-        spacing = self._spacing(u.shape)
         return L.divergence_loss(
             u,
             mode=self.mode,
             sigma=self.sigma,
-            spacing=spacing,
+            spacing=self.spacing,
             stride=self.stride,
             reduction=self.reduction,
         )
@@ -183,10 +179,11 @@ def __init__(
         shear_modulus: Optional[float] = None,
         mode: Optional[str] = None,
         sigma: Optional[float] = None,
+        spacing: Optional[Union[Scalar, Array]] = None,
         stride: Optional[ScalarOrTuple] = None,
         reduction: str = "mean",
     ):
-        super().__init__(mode=mode, sigma=sigma, reduction=reduction)
+        super().__init__(mode=mode, sigma=sigma, spacing=spacing, reduction=reduction)
         self.material_name = material_name
         self.first_parameter = first_parameter
         self.second_parameter = second_parameter
@@ -196,7 +193,6 @@ def __init__(
 
     def forward(self, u: Tensor) -> Tensor:
         r"""Evaluate regularization loss for given transformation."""
-        spacing = self._spacing(u.shape)
         return L.elasticity_loss(
             u,
             material_name=self.material_name,
@@ -207,7 +203,7 @@ def forward(self, u: Tensor) -> Tensor:
             shear_modulus=self.shear_modulus,
             mode=self.mode,
             sigma=self.sigma,
-            spacing=spacing,
+            spacing=self.spacing,
             stride=self.stride,
             reduction=self.reduction,
         )
@@ -234,12 +230,11 @@ class TotalVariation(_SpatialDerivativesLoss):
 
     def forward(self, u: Tensor) -> Tensor:
         r"""Evaluate regularization loss for given transformation."""
-        spacing = self._spacing(u.shape)
         return L.total_variation_loss(
             u,
             mode=self.mode,
             sigma=self.sigma,
-            spacing=spacing,
+            spacing=self.spacing,
             stride=self.stride,
             reduction=self.reduction,
         )
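Putting the change together, a hedged usage sketch of the updated loss modules. The class names `Bending` and `Diffusion` appear in the hunks above; the import path follows this file's location (the classes may also be re-exported elsewhere), and the dimension order of the spacing values is an assumption.

import torch
from deepali.losses.flow import Bending, Diffusion

# Hypothetical displacement field with flow vectors in world units (e.g. mm).
u = torch.randn(1, 3, 8, 16, 16, requires_grad=True)

# With spacing, derivatives are taken in the units of the flow vectors;
# without it, normalized grid coordinates are assumed, as before this commit.
bending = Bending(spacing=torch.tensor([1.0, 1.0, 2.0]))  # assumed (x, y, z) voxel size
diffusion = Diffusion()

loss = bending(u) + diffusion(u)
loss.backward()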
