Skip to content

Commit 9cbcb02

Browse files
committed
[data] Add FlowFields.curl() method
1 parent 55352ea commit 9cbcb02

File tree

1 file changed

+27
-8
lines changed

1 file changed

+27
-8
lines changed

src/deepali/data/flow.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from deepali.core.enum import PaddingMode, Sampling
1313
from deepali.core.grid import ALIGN_CORNERS, Axes, Grid, grid_transform_vectors
1414
from deepali.core.tensor import move_dim
15-
from deepali.core.typing import Array, Device, DType, EllipsisType, PathStr, Scalar
15+
from deepali.core.typing import Array, Device, DType, EllipsisType, PathStr, Scalar, ScalarOrTuple
1616

1717
from .image import Image, ImageBatch
1818

@@ -203,12 +203,25 @@ def axes(self: TFlowFields, axes: Optional[Axes] = None) -> Union[Axes, TFlowFie
203203
data = move_dim(data, -1, 1)
204204
return self._make_instance(data, self._grid, axes)
205205

206-
def curl(self: TFlowFields, mode: str = "central") -> ImageBatch:
206+
def curl(
207+
self: TFlowFields,
208+
mode: Optional[str] = None,
209+
sigma: Optional[float] = None,
210+
spacing: Optional[Union[Scalar, Array]] = None,
211+
stride: Optional[ScalarOrTuple[int]] = None,
212+
) -> ImageBatch:
207213
if self.ndim not in (2, 3):
208-
raise RuntimeError("Cannot compute curl of {self.ndim}-dimensional flow field")
209-
spacing = self.spacing()
210-
data = self.tensor()
211-
data = U.curl(data, spacing=spacing, mode=mode)
214+
raise RuntimeError(f"Cannot compute curl of {self.ndim}-dimensional flow field")
215+
if spacing is None:
216+
if self.axes() is Axes.GRID:
217+
spacing = 1
218+
elif self.axes() is Axes.WORLD:
219+
spacing = self.spacing()
220+
elif self.axes() is Axes.CUBE:
221+
spacing = tuple(2 / n for n in self.grid().size())
222+
else:
223+
spacing = tuple(2 / (n - 1) for n in self.grid().size())
224+
data = U.curl(self.tensor(), mode=mode, sigma=sigma, spacing=spacing, stride=stride)
212225
return ImageBatch(data, self._grid)
213226

214227
def exp(
@@ -529,10 +542,16 @@ def write(self, path: PathStr, axes: Optional[Axes] = None, compress: bool = Tru
529542
disp = disp.axes(axes or Axes.WORLD)
530543
Image.write(disp, path, compress=compress)
531544

532-
def curl(self: TFlowField, mode: str = "central") -> Image:
545+
def curl(
546+
self: TFlowField,
547+
mode: Optional[str] = None,
548+
sigma: Optional[float] = None,
549+
spacing: Optional[Union[Scalar, Array]] = None,
550+
stride: Optional[ScalarOrTuple[int]] = None,
551+
) -> Image:
533552
r"""Compute curl of vector field."""
534553
batch = self.batch()
535-
rotvec = batch.curl(mode=mode)
554+
rotvec = batch.curl(mode=mode, sigma=sigma, spacing=spacing, stride=stride)
536555
return rotvec[0]
537556

538557
def exp(

0 commit comments

Comments
 (0)