|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | from copy import copy as shallow_copy |
6 | | -from typing import Optional |
| 6 | +from typing import Optional, Union |
7 | 7 |
|
8 | 8 | from torch import Tensor |
9 | 9 | from torch.nn import Module |
10 | 10 |
|
11 | 11 | from deepali.core import ALIGN_CORNERS |
| 12 | +from deepali.core import Array, Scalar, ScalarOrTuple |
12 | 13 | from deepali.core import functional as U |
13 | 14 |
|
14 | 15 |
|
| 16 | +class Curl(Module): |
| 17 | + r"""Layer which calculates the curl of a vector field.""" |
| 18 | + |
| 19 | + def __init__( |
| 20 | + self, |
| 21 | + mode: Optional[str] = None, |
| 22 | + sigma: Optional[float] = None, |
| 23 | + spacing: Optional[Union[Scalar, Array]] = None, |
| 24 | + stride: Optional[ScalarOrTuple[int]] = None, |
| 25 | + ) -> None: |
| 26 | + super().__init__() |
| 27 | + self.mode = mode |
| 28 | + self.sigma = sigma |
| 29 | + self.spacing = spacing |
| 30 | + self.stride = stride |
| 31 | + |
| 32 | + def forward(self, x: Tensor) -> Tensor: |
| 33 | + return U.curl(x, mode=self.mode, sigma=self.sigma, spacing=self.spacing, stride=self.stride) |
| 34 | + |
| 35 | + def extra_repr(self) -> str: |
| 36 | + args = [] |
| 37 | + if self.mode is not None: |
| 38 | + args.append(f"mode={self.mode!r}") |
| 39 | + if self.sigma is not None: |
| 40 | + args.append(f"sigma={self.sigma!r}") |
| 41 | + if self.spacing is not None: |
| 42 | + args.append(f"spacing={self.spacing!r}") |
| 43 | + if self.stride is not None: |
| 44 | + args.append(f"stride={self.stride!r}") |
| 45 | + return ", ".join(args) |
| 46 | + |
| 47 | + |
15 | 48 | class ExpFlow(Module): |
16 | 49 | r"""Layer that computes exponential map of flow field.""" |
17 | 50 |
|
|
0 commit comments