|
12 | 12 | from deepali.core.enum import PaddingMode, Sampling |
13 | 13 | from deepali.core.grid import ALIGN_CORNERS, Axes, Grid, grid_transform_vectors |
14 | 14 | from deepali.core.tensor import move_dim |
15 | | -from deepali.core.typing import Array, Device, DType, EllipsisType, PathStr, Scalar |
| 15 | +from deepali.core.typing import Array, Device, DType, EllipsisType, PathStr, Scalar, ScalarOrTuple |
16 | 16 |
|
17 | 17 | from .image import Image, ImageBatch |
18 | 18 |
|
@@ -203,12 +203,25 @@ def axes(self: TFlowFields, axes: Optional[Axes] = None) -> Union[Axes, TFlowFie |
203 | 203 | data = move_dim(data, -1, 1) |
204 | 204 | return self._make_instance(data, self._grid, axes) |
205 | 205 |
|
206 | | - def curl(self: TFlowFields, mode: str = "central") -> ImageBatch: |
| 206 | + def curl( |
| 207 | + self: TFlowFields, |
| 208 | + mode: Optional[str] = None, |
| 209 | + sigma: Optional[float] = None, |
| 210 | + spacing: Optional[Union[Scalar, Array]] = None, |
| 211 | + stride: Optional[ScalarOrTuple[int]] = None, |
| 212 | + ) -> ImageBatch: |
207 | 213 | if self.ndim not in (2, 3): |
208 | | - raise RuntimeError("Cannot compute curl of {self.ndim}-dimensional flow field") |
209 | | - spacing = self.spacing() |
210 | | - data = self.tensor() |
211 | | - data = U.curl(data, spacing=spacing, mode=mode) |
| 214 | + raise RuntimeError(f"Cannot compute curl of {self.ndim}-dimensional flow field") |
| 215 | + if spacing is None: |
| 216 | + if self.axes() is Axes.GRID: |
| 217 | + spacing = 1 |
| 218 | + elif self.axes() is Axes.WORLD: |
| 219 | + spacing = self.spacing() |
| 220 | + elif self.axes() is Axes.CUBE: |
| 221 | + spacing = tuple(2 / n for n in self.grid().size()) |
| 222 | + else: |
| 223 | + spacing = tuple(2 / (n - 1) for n in self.grid().size()) |
| 224 | + data = U.curl(self.tensor(), mode=mode, sigma=sigma, spacing=spacing, stride=stride) |
212 | 225 | return ImageBatch(data, self._grid) |
213 | 226 |
|
214 | 227 | def exp( |
@@ -529,10 +542,16 @@ def write(self, path: PathStr, axes: Optional[Axes] = None, compress: bool = Tru |
529 | 542 | disp = disp.axes(axes or Axes.WORLD) |
530 | 543 | Image.write(disp, path, compress=compress) |
531 | 544 |
|
532 | | - def curl(self: TFlowField, mode: str = "central") -> Image: |
| 545 | + def curl( |
| 546 | + self: TFlowField, |
| 547 | + mode: Optional[str] = None, |
| 548 | + sigma: Optional[float] = None, |
| 549 | + spacing: Optional[Union[Scalar, Array]] = None, |
| 550 | + stride: Optional[ScalarOrTuple[int]] = None, |
| 551 | + ) -> Image: |
533 | 552 | r"""Compute curl of vector field.""" |
534 | 553 | batch = self.batch() |
535 | | - rotvec = batch.curl(mode=mode) |
| 554 | + rotvec = batch.curl(mode=mode, sigma=sigma, spacing=spacing, stride=stride) |
536 | 555 | return rotvec[0] |
537 | 556 |
|
538 | 557 | def exp( |
|
0 commit comments