77import torch
88from torch import Tensor
99
10- from deepali .core .typing import Shape
10+ from deepali .core .typing import ScalarOrTuple , Shape
1111
1212from . import functional as L
1313from .base import DisplacementLoss
@@ -18,21 +18,24 @@ class _SpatialDerivativesLoss(DisplacementLoss):
1818
1919 def __init__ (
2020 self ,
21- mode : str = "central" ,
21+ mode : Optional [ str ] = None ,
2222 sigma : Optional [float ] = None ,
23+ stride : Optional [ScalarOrTuple ] = None ,
2324 reduction : str = "mean" ,
2425 ):
2526 r"""Initialize regularization term.
2627
2728 Args:
28- mode: Method used to approximate spatial derivatives. See ``spatial_derivatives()``.
29- sigma: Standard deviation of Gaussian in grid units. See ``spatial_derivatives()``.
29+ mode: Method used to approximate :func:`flow_derivatives()`.
30+ sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
31+ stride: Number of output grid points between control points plus one for ``mode='bspline'``.
3032 reduction: Operation to use for reducing spatially distributed loss values.
3133
3234 """
3335 super ().__init__ ()
3436 self .mode = mode
35- self .sigma = float (0 if sigma is None else sigma )
37+ self .sigma = sigma
38+ self .stride = stride
3639 self .reduction = reduction
3740
3841 def _spacing (self , u_shape : Shape ) -> Optional [Tensor ]:
@@ -45,7 +48,15 @@ def _spacing(self, u_shape: Shape) -> Optional[Tensor]:
4548 return 2 / (size - 1 )
4649
4750 def extra_repr (self ) -> str :
48- return f"mode={ self .mode !r} , sigma={ self .sigma !r} , reduction={ self .reduction !r} "
51+ args = []
52+ if self .mode :
53+ args .append (f"mode={ self .mode !r} " )
54+ if self .sigma :
55+ args .append (f"sigma={ self .sigma !r} " )
56+ if self .stride :
57+ args .append (f"stride={ self .stride !r} " )
58+ args .append (f"reduction={ self .reduction !r} " )
59+ return ", " .join (args )
4960
5061
5162class GradLoss (_SpatialDerivativesLoss ):
@@ -55,21 +66,23 @@ def __init__(
5566 self ,
5667 p : Union [int , float ] = 2 ,
5768 q : Optional [Union [int , float ]] = 1 ,
58- mode : str = "central" ,
69+ mode : Optional [ str ] = None ,
5970 sigma : Optional [float ] = None ,
71+ stride : Optional [ScalarOrTuple ] = None ,
6072 reduction : str = "mean" ,
6173 ):
6274 r"""Initialize regularization term.
6375
6476 Args:
65- mode: Method used to approximate spatial derivatives. See ``spatial_derivatives()``.
66- sigma: Standard deviation of Gaussian in grid units. See ``spatial_derivatives()``.
77+ mode: Method used to approximate :func:`flow_derivatives()`.
78+ sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
79+ stride: Number of output grid points between control points plus one for ``mode='bspline'``.
6780 reduction: Operation to use for reducing spatially distributed loss values.
6881
6982 """
70- super ().__init__ (mode = mode , sigma = sigma , reduction = reduction )
83+ super ().__init__ (mode = mode , sigma = sigma , stride = stride , reduction = reduction )
7184 self .p = p
72- self .q = q
85+ self .q = 1 / p if q is None else q
7386
7487 def forward (self , u : Tensor ) -> Tensor :
7588 r"""Evaluate regularization loss for given transformation."""
@@ -78,9 +91,10 @@ def forward(self, u: Tensor) -> Tensor:
7891 u ,
7992 p = self .p ,
8093 q = self .q ,
81- spacing = spacing ,
8294 mode = self .mode ,
8395 sigma = self .sigma ,
96+ spacing = spacing ,
97+ stride = self .stride ,
8498 reduction = self .reduction ,
8599 )
86100
@@ -96,9 +110,10 @@ def forward(self, u: Tensor) -> Tensor:
96110 spacing = self ._spacing (u .shape )
97111 return L .bending_loss (
98112 u ,
99- spacing = spacing ,
100113 mode = self .mode ,
101114 sigma = self .sigma ,
115+ spacing = spacing ,
116+ stride = self .stride ,
102117 reduction = self .reduction ,
103118 )
104119
@@ -115,9 +130,10 @@ def forward(self, u: Tensor) -> Tensor:
115130 spacing = self ._spacing (u .shape )
116131 return L .curvature_loss (
117132 u ,
118- spacing = spacing ,
119133 mode = self .mode ,
120134 sigma = self .sigma ,
135+ spacing = spacing ,
136+ stride = self .stride ,
121137 reduction = self .reduction ,
122138 )
123139
@@ -130,9 +146,10 @@ def forward(self, u: Tensor) -> Tensor:
130146 spacing = self ._spacing (u .shape )
131147 return L .diffusion_loss (
132148 u ,
133- spacing = spacing ,
134149 mode = self .mode ,
135150 sigma = self .sigma ,
151+ spacing = spacing ,
152+ stride = self .stride ,
136153 reduction = self .reduction ,
137154 )
138155
@@ -145,9 +162,10 @@ def forward(self, u: Tensor) -> Tensor:
145162 spacing = self ._spacing (u .shape )
146163 return L .divergence_loss (
147164 u ,
148- spacing = spacing ,
149165 mode = self .mode ,
150166 sigma = self .sigma ,
167+ spacing = spacing ,
168+ stride = self .stride ,
151169 reduction = self .reduction ,
152170 )
153171
@@ -163,8 +181,9 @@ def __init__(
163181 poissons_ratio : Optional [float ] = None ,
164182 youngs_modulus : Optional [float ] = None ,
165183 shear_modulus : Optional [float ] = None ,
166- mode : str = "central" ,
184+ mode : Optional [ str ] = None ,
167185 sigma : Optional [float ] = None ,
186+ stride : Optional [ScalarOrTuple ] = None ,
168187 reduction : str = "mean" ,
169188 ):
170189 super ().__init__ (mode = mode , sigma = sigma , reduction = reduction )
@@ -180,18 +199,35 @@ def forward(self, u: Tensor) -> Tensor:
180199 spacing = self ._spacing (u .shape )
181200 return L .elasticity_loss (
182201 u ,
183- spacing = spacing ,
184- mode = self .mode ,
185- sigma = self .sigma ,
186- reduction = self .reduction ,
187202 material_name = self .material_name ,
188203 first_parameter = self .first_parameter ,
189204 second_parameter = self .second_parameter ,
190205 poissons_ratio = self .poissons_ratio ,
191206 youngs_modulus = self .youngs_modulus ,
192207 shear_modulus = self .shear_modulus ,
208+ mode = self .mode ,
209+ sigma = self .sigma ,
210+ spacing = spacing ,
211+ stride = self .stride ,
212+ reduction = self .reduction ,
193213 )
194214
215+ def extra_repr (self ) -> str :
216+ args = []
217+ if self .material_name :
218+ args .append (f"material_name={ self .material_name !r} " )
219+ if self .first_parameter is not None :
220+ args .append (f"first_parameter={ self .first_parameter !r} " )
221+ if self .second_parameter is not None :
222+ args .append (f"second_parameter={ self .second_parameter !r} " )
223+ if self .poissons_ratio is not None :
224+ args .append (f"poissons_ratio={ self .poissons_ratio !r} " )
225+ if self .youngs_modulus is not None :
226+ args .append (f"youngs_modulus={ self .youngs_modulus !r} " )
227+ if self .shear_modulus is not None :
228+ args .append (f"shear_modulus={ self .shear_modulus !r} " )
229+ return ", " .join (args ) + ", " + super ().extra_repr ()
230+
195231
196232class TotalVariation (_SpatialDerivativesLoss ):
197233 r"""Total variation of displacement field."""
@@ -201,9 +237,10 @@ def forward(self, u: Tensor) -> Tensor:
201237 spacing = self ._spacing (u .shape )
202238 return L .total_variation_loss (
203239 u ,
204- spacing = spacing ,
205240 mode = self .mode ,
206241 sigma = self .sigma ,
242+ spacing = spacing ,
243+ stride = self .stride ,
207244 reduction = self .reduction ,
208245 )
209246
0 commit comments