Skip to content

Commit 55352ea

Browse files
committed
[losses] Fix flow field derivatives based loss functions
1 parent 51c3222 commit 55352ea

File tree

4 files changed

+232
-135
lines changed

4 files changed

+232
-135
lines changed

src/deepali/losses/bspline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class BSplineBending(BSplineLoss):
1111

1212
def forward(self, params: Tensor) -> Tensor:
1313
r"""Evaluate loss term for given free form deformation parameters."""
14-
return L.bspline_bending_loss(params, stride=self.stride, reduction=self.reduction)
14+
return L.bending_loss(params, mode="bspline", stride=self.stride, reduction=self.reduction)
1515

1616

1717
BSplineBendingEnergy = BSplineBending

src/deepali/losses/flow.py

Lines changed: 59 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch
88
from torch import Tensor
99

10-
from deepali.core.typing import Shape
10+
from deepali.core.typing import ScalarOrTuple, Shape
1111

1212
from . import functional as L
1313
from .base import DisplacementLoss
@@ -18,21 +18,24 @@ class _SpatialDerivativesLoss(DisplacementLoss):
1818

1919
def __init__(
2020
self,
21-
mode: str = "central",
21+
mode: Optional[str] = None,
2222
sigma: Optional[float] = None,
23+
stride: Optional[ScalarOrTuple] = None,
2324
reduction: str = "mean",
2425
):
2526
r"""Initialize regularization term.
2627
2728
Args:
28-
mode: Method used to approximate spatial derivatives. See ``spatial_derivatives()``.
29-
sigma: Standard deviation of Gaussian in grid units. See ``spatial_derivatives()``.
29+
mode: Method used to approximate :func:`flow_derivatives()`.
30+
sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
31+
stride: Number of output grid points between control points plus one for ``mode='bspline'``.
3032
reduction: Operation to use for reducing spatially distributed loss values.
3133
3234
"""
3335
super().__init__()
3436
self.mode = mode
35-
self.sigma = float(0 if sigma is None else sigma)
37+
self.sigma = sigma
38+
self.stride = stride
3639
self.reduction = reduction
3740

3841
def _spacing(self, u_shape: Shape) -> Optional[Tensor]:
@@ -45,7 +48,15 @@ def _spacing(self, u_shape: Shape) -> Optional[Tensor]:
4548
return 2 / (size - 1)
4649

4750
def extra_repr(self) -> str:
48-
return f"mode={self.mode!r}, sigma={self.sigma!r}, reduction={self.reduction!r}"
51+
args = []
52+
if self.mode:
53+
args.append(f"mode={self.mode!r}")
54+
if self.sigma:
55+
args.append(f"sigma={self.sigma!r}")
56+
if self.stride:
57+
args.append(f"stride={self.stride!r}")
58+
args.append(f"reduction={self.reduction!r}")
59+
return ", ".join(args)
4960

5061

5162
class GradLoss(_SpatialDerivativesLoss):
@@ -55,21 +66,23 @@ def __init__(
5566
self,
5667
p: Union[int, float] = 2,
5768
q: Optional[Union[int, float]] = 1,
58-
mode: str = "central",
69+
mode: Optional[str] = None,
5970
sigma: Optional[float] = None,
71+
stride: Optional[ScalarOrTuple] = None,
6072
reduction: str = "mean",
6173
):
6274
r"""Initialize regularization term.
6375
6476
Args:
65-
mode: Method used to approximate spatial derivatives. See ``spatial_derivatives()``.
66-
sigma: Standard deviation of Gaussian in grid units. See ``spatial_derivatives()``.
77+
mode: Method used to approximate :func:`flow_derivatives()`.
78+
sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
79+
stride: Number of output grid points between control points plus one for ``mode='bspline'``.
6780
reduction: Operation to use for reducing spatially distributed loss values.
6881
6982
"""
70-
super().__init__(mode=mode, sigma=sigma, reduction=reduction)
83+
super().__init__(mode=mode, sigma=sigma, stride=stride, reduction=reduction)
7184
self.p = p
72-
self.q = q
85+
self.q = 1 / p if q is None else q
7386

7487
def forward(self, u: Tensor) -> Tensor:
7588
r"""Evaluate regularization loss for given transformation."""
@@ -78,9 +91,10 @@ def forward(self, u: Tensor) -> Tensor:
7891
u,
7992
p=self.p,
8093
q=self.q,
81-
spacing=spacing,
8294
mode=self.mode,
8395
sigma=self.sigma,
96+
spacing=spacing,
97+
stride=self.stride,
8498
reduction=self.reduction,
8599
)
86100

@@ -96,9 +110,10 @@ def forward(self, u: Tensor) -> Tensor:
96110
spacing = self._spacing(u.shape)
97111
return L.bending_loss(
98112
u,
99-
spacing=spacing,
100113
mode=self.mode,
101114
sigma=self.sigma,
115+
spacing=spacing,
116+
stride=self.stride,
102117
reduction=self.reduction,
103118
)
104119

@@ -115,9 +130,10 @@ def forward(self, u: Tensor) -> Tensor:
115130
spacing = self._spacing(u.shape)
116131
return L.curvature_loss(
117132
u,
118-
spacing=spacing,
119133
mode=self.mode,
120134
sigma=self.sigma,
135+
spacing=spacing,
136+
stride=self.stride,
121137
reduction=self.reduction,
122138
)
123139

@@ -130,9 +146,10 @@ def forward(self, u: Tensor) -> Tensor:
130146
spacing = self._spacing(u.shape)
131147
return L.diffusion_loss(
132148
u,
133-
spacing=spacing,
134149
mode=self.mode,
135150
sigma=self.sigma,
151+
spacing=spacing,
152+
stride=self.stride,
136153
reduction=self.reduction,
137154
)
138155

@@ -145,9 +162,10 @@ def forward(self, u: Tensor) -> Tensor:
145162
spacing = self._spacing(u.shape)
146163
return L.divergence_loss(
147164
u,
148-
spacing=spacing,
149165
mode=self.mode,
150166
sigma=self.sigma,
167+
spacing=spacing,
168+
stride=self.stride,
151169
reduction=self.reduction,
152170
)
153171

@@ -163,8 +181,9 @@ def __init__(
163181
poissons_ratio: Optional[float] = None,
164182
youngs_modulus: Optional[float] = None,
165183
shear_modulus: Optional[float] = None,
166-
mode: str = "central",
184+
mode: Optional[str] = None,
167185
sigma: Optional[float] = None,
186+
stride: Optional[ScalarOrTuple] = None,
168187
reduction: str = "mean",
169188
):
170189
super().__init__(mode=mode, sigma=sigma, reduction=reduction)
@@ -180,18 +199,35 @@ def forward(self, u: Tensor) -> Tensor:
180199
spacing = self._spacing(u.shape)
181200
return L.elasticity_loss(
182201
u,
183-
spacing=spacing,
184-
mode=self.mode,
185-
sigma=self.sigma,
186-
reduction=self.reduction,
187202
material_name=self.material_name,
188203
first_parameter=self.first_parameter,
189204
second_parameter=self.second_parameter,
190205
poissons_ratio=self.poissons_ratio,
191206
youngs_modulus=self.youngs_modulus,
192207
shear_modulus=self.shear_modulus,
208+
mode=self.mode,
209+
sigma=self.sigma,
210+
spacing=spacing,
211+
stride=self.stride,
212+
reduction=self.reduction,
193213
)
194214

215+
def extra_repr(self) -> str:
216+
args = []
217+
if self.material_name:
218+
args.append(f"material_name={self.material_name!r}")
219+
if self.first_parameter is not None:
220+
args.append(f"first_parameter={self.first_parameter!r}")
221+
if self.second_parameter is not None:
222+
args.append(f"second_parameter={self.second_parameter!r}")
223+
if self.poissons_ratio is not None:
224+
args.append(f"poissons_ratio={self.poissons_ratio!r}")
225+
if self.youngs_modulus is not None:
226+
args.append(f"youngs_modulus={self.youngs_modulus!r}")
227+
if self.shear_modulus is not None:
228+
args.append(f"shear_modulus={self.shear_modulus!r}")
229+
return ", ".join(args) + ", " + super().extra_repr()
230+
195231

196232
class TotalVariation(_SpatialDerivativesLoss):
197233
r"""Total variation of displacement field."""
@@ -201,9 +237,10 @@ def forward(self, u: Tensor) -> Tensor:
201237
spacing = self._spacing(u.shape)
202238
return L.total_variation_loss(
203239
u,
204-
spacing=spacing,
205240
mode=self.mode,
206241
sigma=self.sigma,
242+
spacing=spacing,
243+
stride=self.stride,
207244
reduction=self.reduction,
208245
)
209246

0 commit comments

Comments
 (0)