Skip to content

Commit 51c3222

Browse files
committed
[modules] Add layer which computes Curl of 3D vector field
1 parent df3c60d commit 51c3222

File tree

2 files changed

+36
-1
lines changed

2 files changed

+36
-1
lines changed

src/deepali/modules/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from .basic import Reshape
1313
from .basic import View
1414

15+
from .flow import Curl
1516
from .flow import ExpFlow
1617

1718
from .image import BlurImage
@@ -37,6 +38,7 @@
3738
__all__ = (
3839
"AlignImage",
3940
"BlurImage",
41+
"Curl",
4042
"DeviceProperty",
4143
"ExpFlow",
4244
"FilterImage",

src/deepali/modules/flow.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,48 @@
33
from __future__ import annotations
44

55
from copy import copy as shallow_copy
6-
from typing import Optional
6+
from typing import Optional, Union
77

88
from torch import Tensor
99
from torch.nn import Module
1010

1111
from deepali.core import ALIGN_CORNERS
12+
from deepali.core import Array, Scalar, ScalarOrTuple
1213
from deepali.core import functional as U
1314

1415

16+
class Curl(Module):
17+
r"""Layer which calculates the curl of a vector field."""
18+
19+
def __init__(
20+
self,
21+
mode: Optional[str] = None,
22+
sigma: Optional[float] = None,
23+
spacing: Optional[Union[Scalar, Array]] = None,
24+
stride: Optional[ScalarOrTuple[int]] = None,
25+
) -> None:
26+
super().__init__()
27+
self.mode = mode
28+
self.sigma = sigma
29+
self.spacing = spacing
30+
self.stride = stride
31+
32+
def forward(self, x: Tensor) -> Tensor:
33+
return U.curl(x, mode=self.mode, sigma=self.sigma, spacing=self.spacing, stride=self.stride)
34+
35+
def extra_repr(self) -> str:
36+
args = []
37+
if self.mode is not None:
38+
args.append(f"mode={self.mode!r}")
39+
if self.sigma is not None:
40+
args.append(f"sigma={self.sigma!r}")
41+
if self.spacing is not None:
42+
args.append(f"spacing={self.spacing!r}")
43+
if self.stride is not None:
44+
args.append(f"stride={self.stride!r}")
45+
return ", ".join(args)
46+
47+
1548
class ExpFlow(Module):
1649
r"""Layer that computes exponential map of flow field."""
1750

0 commit comments

Comments
 (0)