44
55from typing import Optional , Union
66
7- import torch
87from torch import Tensor
98
10- from deepali .core .typing import ScalarOrTuple , Shape
9+ from deepali .core .typing import Array , Scalar , ScalarOrTuple
1110
1211from . import functional as L
1312from .base import DisplacementLoss
@@ -20,6 +19,7 @@ def __init__(
2019 self ,
2120 mode : Optional [str ] = None ,
2221 sigma : Optional [float ] = None ,
22+ spacing : Optional [Union [Scalar , Array ]] = None ,
2323 stride : Optional [ScalarOrTuple ] = None ,
2424 reduction : str = "mean" ,
2525 ):
@@ -28,31 +28,27 @@ def __init__(
2828 Args:
2929 mode: Method used to approximate :func:`flow_derivatives()`.
3030 sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
31+ spacing: Spacing between grid elements. Should be given in the units of the flow vectors.
32+ By default, flow vectors with respect to normalized grid coordinates are assumed.
3133 stride: Number of output grid points between control points plus one for ``mode='bspline'``.
3234 reduction: Operation to use for reducing spatially distributed loss values.
3335
3436 """
3537 super ().__init__ ()
3638 self .mode = mode
3739 self .sigma = sigma
40+ self .spacing = spacing
3841 self .stride = stride
3942 self .reduction = reduction
4043
41- def _spacing (self , u_shape : Shape ) -> Optional [Tensor ]:
42- ndim = len (u_shape )
43- if ndim < 3 :
44- raise ValueError (f"{ type (self ).__name__ } .forward() 'u' must be at least 3-dimensional" )
45- if ndim == 3 :
46- return None
47- size = torch .tensor (u_shape [- 1 :1 :- 1 ], dtype = torch .float , device = torch .device ("cpu" ))
48- return 2 / (size - 1 )
49-
5044 def extra_repr (self ) -> str :
5145 args = []
5246 if self .mode :
5347 args .append (f"mode={ self .mode !r} " )
5448 if self .sigma :
5549 args .append (f"sigma={ self .sigma !r} " )
50+ if self .spacing :
51+ args .append (f"spacing={ self .spacing !r} " )
5652 if self .stride :
5753 args .append (f"stride={ self .stride !r} " )
5854 args .append (f"reduction={ self .reduction !r} " )
@@ -68,6 +64,7 @@ def __init__(
6864 q : Optional [Union [int , float ]] = 1 ,
6965 mode : Optional [str ] = None ,
7066 sigma : Optional [float ] = None ,
67+ spacing : Optional [Union [Scalar , Array ]] = None ,
7168 stride : Optional [ScalarOrTuple ] = None ,
7269 reduction : str = "mean" ,
7370 ):
@@ -76,24 +73,27 @@ def __init__(
7673 Args:
7774 mode: Method used to approximate :func:`flow_derivatives()`.
7875 sigma: Standard deviation of Gaussian in grid units used to smooth vector field.
76+ spacing: Spacing between grid elements. Should be given in the units of the flow vectors.
77+ By default, flow vectors with respect to normalized grid coordinates are assumed.
7978 stride: Number of output grid points between control points plus one for ``mode='bspline'``.
8079 reduction: Operation to use for reducing spatially distributed loss values.
8180
8281 """
83- super ().__init__ (mode = mode , sigma = sigma , stride = stride , reduction = reduction )
82+ super ().__init__ (
83+ mode = mode , sigma = sigma , spacing = spacing , stride = stride , reduction = reduction
84+ )
8485 self .p = p
8586 self .q = 1 / p if q is None else q
8687
8788 def forward (self , u : Tensor ) -> Tensor :
8889 r"""Evaluate regularization loss for given transformation."""
89- spacing = self ._spacing (u .shape )
9090 return L .grad_loss (
9191 u ,
9292 p = self .p ,
9393 q = self .q ,
9494 mode = self .mode ,
9595 sigma = self .sigma ,
96- spacing = spacing ,
96+ spacing = self . spacing ,
9797 stride = self .stride ,
9898 reduction = self .reduction ,
9999 )
@@ -107,12 +107,11 @@ class Bending(_SpatialDerivativesLoss):
107107
108108 def forward (self , u : Tensor ) -> Tensor :
109109 r"""Evaluate regularization loss for given transformation."""
110- spacing = self ._spacing (u .shape )
111110 return L .bending_loss (
112111 u ,
113112 mode = self .mode ,
114113 sigma = self .sigma ,
115- spacing = spacing ,
114+ spacing = self . spacing ,
116115 stride = self .stride ,
117116 reduction = self .reduction ,
118117 )
@@ -127,12 +126,11 @@ class Curvature(_SpatialDerivativesLoss):
127126
128127 def forward (self , u : Tensor ) -> Tensor :
129128 r"""Evaluate regularization loss for given transformation."""
130- spacing = self ._spacing (u .shape )
131129 return L .curvature_loss (
132130 u ,
133131 mode = self .mode ,
134132 sigma = self .sigma ,
135- spacing = spacing ,
133+ spacing = self . spacing ,
136134 stride = self .stride ,
137135 reduction = self .reduction ,
138136 )
@@ -143,12 +141,11 @@ class Diffusion(_SpatialDerivativesLoss):
143141
144142 def forward (self , u : Tensor ) -> Tensor :
145143 r"""Evaluate regularization loss for given transformation."""
146- spacing = self ._spacing (u .shape )
147144 return L .diffusion_loss (
148145 u ,
149146 mode = self .mode ,
150147 sigma = self .sigma ,
151- spacing = spacing ,
148+ spacing = self . spacing ,
152149 stride = self .stride ,
153150 reduction = self .reduction ,
154151 )
@@ -159,12 +156,11 @@ class Divergence(_SpatialDerivativesLoss):
159156
160157 def forward (self , u : Tensor ) -> Tensor :
161158 r"""Evaluate regularization loss for given transformation."""
162- spacing = self ._spacing (u .shape )
163159 return L .divergence_loss (
164160 u ,
165161 mode = self .mode ,
166162 sigma = self .sigma ,
167- spacing = spacing ,
163+ spacing = self . spacing ,
168164 stride = self .stride ,
169165 reduction = self .reduction ,
170166 )
@@ -183,10 +179,11 @@ def __init__(
183179 shear_modulus : Optional [float ] = None ,
184180 mode : Optional [str ] = None ,
185181 sigma : Optional [float ] = None ,
182+ spacing : Optional [Union [Scalar , Array ]] = None ,
186183 stride : Optional [ScalarOrTuple ] = None ,
187184 reduction : str = "mean" ,
188185 ):
189- super ().__init__ (mode = mode , sigma = sigma , reduction = reduction )
186+ super ().__init__ (mode = mode , sigma = sigma , spacing = spacing , reduction = reduction )
190187 self .material_name = material_name
191188 self .first_parameter = first_parameter
192189 self .second_parameter = second_parameter
@@ -196,7 +193,6 @@ def __init__(
196193
197194 def forward (self , u : Tensor ) -> Tensor :
198195 r"""Evaluate regularization loss for given transformation."""
199- spacing = self ._spacing (u .shape )
200196 return L .elasticity_loss (
201197 u ,
202198 material_name = self .material_name ,
@@ -207,7 +203,7 @@ def forward(self, u: Tensor) -> Tensor:
207203 shear_modulus = self .shear_modulus ,
208204 mode = self .mode ,
209205 sigma = self .sigma ,
210- spacing = spacing ,
206+ spacing = self . spacing ,
211207 stride = self .stride ,
212208 reduction = self .reduction ,
213209 )
@@ -234,12 +230,11 @@ class TotalVariation(_SpatialDerivativesLoss):
234230
235231 def forward (self , u : Tensor ) -> Tensor :
236232 r"""Evaluate regularization loss for given transformation."""
237- spacing = self ._spacing (u .shape )
238233 return L .total_variation_loss (
239234 u ,
240235 mode = self .mode ,
241236 sigma = self .sigma ,
242- spacing = spacing ,
237+ spacing = self . spacing ,
243238 stride = self .stride ,
244239 reduction = self .reduction ,
245240 )
0 commit comments