Skip to content

Commit 41bcc55

Browse files
committed
[core] Fix spatial derivatives when using mode='bspline' or mode='gaussian'
1 parent 580a05d commit 41bcc55

File tree

2 files changed

+236
-14
lines changed

2 files changed

+236
-14
lines changed

src/deepali/core/image.py

Lines changed: 31 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1486,7 +1486,10 @@ def spatial_derivatives(
14861486
If ``None``, ``forward_central_backward`` is used as default mode.
14871487
sigma: Standard deviation of Gaussian kernel in grid units. If ``None`` or zero,
14881488
no Gaussian smoothing is used for calculation of finite differences, and a
1489-
default standard deviation of 0.4 is used when ``mode="gaussian"``.
1489+
default standard deviation of 0.7355 is used when ``mode="gaussian"``. With a smaller
1490+
standard deviation, the magnitude of the derivative values starts to deviate between
1491+
``mode="gaussian"`` and finite differences of a Gaussian smoothed input. This is likely
1492+
due to a too small discretized Gaussian filter and its derivative.
14901493
spacing: Physical spacing between image grid points, e.g., ``(sx, sy, sz)``.
14911494
When a scalar is given, the same spacing is used for each image and spatial dimension.
14921495
If a sequence is given, it must be of length equal to the number of spatial dimensions ``D``,
@@ -1556,7 +1559,7 @@ def spatial_derivatives(
15561559
if mode in ("forward", "backward", "central", "forward_central_backward", "prewitt", "sobel"):
15571560
if sigma and sigma > 0:
15581561
blur = gaussian1d(sigma, dtype=torch.float, device=data.device)
1559-
data = conv(data, blur, padding=PaddingMode.ZEROS)
1562+
data = conv(data, blur, padding=PaddingMode.REPLICATE)
15601563
if mode in ("prewitt", "sobel"):
15611564
avg_kernel = torch.tensor([1, 1 if mode == "prewitt" else 2, 1], dtype=data.dtype)
15621565
avg_kernel /= avg_kernel.sum()
@@ -1589,7 +1592,7 @@ def spatial_derivatives(
15891592

15901593
if sigma and sigma > 0:
15911594
blur = gaussian1d(sigma, dtype=torch.float, device=data.device)
1592-
data = conv(data, blur, padding=PaddingMode.ZEROS)
1595+
data = conv(data, blur, padding=PaddingMode.REPLICATE)
15931596

15941597
if stride is None:
15951598
stride = 1
@@ -1616,27 +1619,41 @@ def bspline1d(s: int, d: int) -> Tensor:
16161619
for spatial_dim in SpatialDerivativeKeys.split(code):
16171620
order[spatial_dim] += 1
16181621
kernel = [bspline1d(s, d) for s, d in zip(stride, order)]
1619-
derivs[code] = evaluate_cubic_bspline(data, kernel=kernel)
1622+
deriv = evaluate_cubic_bspline(data, kernel=kernel)
1623+
if sum(order) > 0:
1624+
denom = torch.ones(N, dtype=spacing.dtype, device=spacing.device)
1625+
for delta, d in zip(spacing.transpose(0, 1), order):
1626+
if d > 0:
1627+
denom.mul_(delta.pow(d))
1628+
denom = denom.reshape((N,) + (1,) * (deriv.ndim - 1))
1629+
deriv = deriv.div_(denom.to(deriv))
1630+
derivs[code] = deriv
16201631

16211632
elif mode == "gaussian":
1633+
1634+
def pad_spatial_dim(data: Tensor, sdim: int, padding: int) -> Tensor:
1635+
pad = [(padding, padding) if d == sdim else (0, 0) for d in range(data.ndim - 2)]
1636+
pad = [n for v in pad for n in v]
1637+
return F.pad(data, pad, mode="replicate")
1638+
16221639
if not sigma:
1623-
sigma = 0.4
1624-
kernel_0 = gaussian1d(sigma, normalize=False, dtype=torch.float)
1625-
kernel_1 = gaussian1d_I(sigma, normalize=False, dtype=torch.float)
1626-
norm = kernel_0.sum()
1627-
kernel_0 = kernel_0.div_(norm).to(data.device)
1628-
kernel_1 = kernel_1.div_(norm).to(data.device)
1640+
sigma = 0.7355 # same default value as used in downsample()
1641+
kernel_0 = gaussian1d(sigma, normalize=False, dtype=torch.float, device=data.device)
1642+
kernel_1 = gaussian1d_I(sigma, normalize=False, dtype=torch.float, device=data.device)
16291643
for i in range(max_order):
16301644
for code in unique_keys:
16311645
key = code[: i + 1]
16321646
if i < len(code) and key not in derivs:
16331647
sdim = SpatialDim.from_arg(code[i])
1634-
result = data if i == 0 else derivs[code[:i]]
1648+
deriv = data if i == 0 else derivs[code[:i]]
16351649
for d in range(D):
1636-
dim = SpatialDim(d).tensor_dim(result.ndim)
1650+
dim = SpatialDim(d).tensor_dim(deriv.ndim)
16371651
kernel = kernel_1 if sdim == d else kernel_0
1638-
result = conv1d(result, kernel, dim=dim, padding=len(kernel) // 2)
1639-
derivs[key] = result
1652+
deriv = pad_spatial_dim(deriv, d, len(kernel) // 2)
1653+
deriv = conv1d(deriv, kernel, dim=dim, padding=0)
1654+
denom = spacing.narrow(1, sdim, 1).reshape((N,) + (1,) * (deriv.ndim - 1))
1655+
deriv = deriv.div_(denom.to(deriv))
1656+
derivs[key] = deriv
16401657
derivs = {key: derivs[SpatialDerivativeKeys.sorted(key)] for key in which}
16411658

16421659
else:

tests/_test_core_flow_deriv.py

Lines changed: 205 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,205 @@
1+
r"""Interactive test and visualization of vector flow derivatives."""
2+
3+
# %%
4+
# Imports
5+
from typing import Dict, Optional, Sequence
6+
7+
import matplotlib.pyplot as plt
8+
9+
import torch
10+
from torch import Tensor
11+
from torch.random import Generator
12+
13+
from deepali.core import Axes, Grid
14+
import deepali.core.bspline as B
15+
import deepali.core.functional as U
16+
17+
18+
# %%
19+
# Auxiliary functions
20+
def change_axes(flow: Tensor, grid: Grid, axes: Axes, to_axes: Axes) -> Tensor:
21+
if axes != to_axes:
22+
flow = U.move_dim(flow, 1, -1)
23+
flow = grid.transform_vectors(flow, axes=axes, to_axes=to_axes)
24+
flow = U.move_dim(flow, -1, 1)
25+
return flow
26+
27+
28+
def flow_derivatives(
29+
flow: Tensor, grid: Grid, axes: Axes, to_axes: Optional[Axes] = None, **kwargs
30+
) -> Dict[str, Tensor]:
31+
if to_axes is None:
32+
to_axes = axes
33+
flow = change_axes(flow, grid, axes, to_axes)
34+
axes = to_axes
35+
if "spacing" not in kwargs:
36+
if axes == Axes.CUBE:
37+
spacing = tuple(2 / n for n in grid.size())
38+
elif axes == Axes.CUBE_CORNERS:
39+
spacing = tuple(2 / (n - 1) for n in grid.size())
40+
elif axes == Axes.GRID:
41+
spacing = 1
42+
elif axes == Axes.WORLD:
43+
spacing = grid.spacing()
44+
else:
45+
spacing = None
46+
kwargs["spacing"] = spacing
47+
return U.flow_derivatives(flow, **kwargs)
48+
49+
50+
def random_svf(
51+
size: Sequence[int],
52+
stride: int = 1,
53+
generator: Optional[Generator] = None,
54+
) -> Tensor:
55+
cp_grid_size = B.cubic_bspline_control_point_grid_size(size, stride=stride)
56+
cp_grid_size = tuple(reversed(cp_grid_size))
57+
data = torch.randn((1, 3) + cp_grid_size, generator=generator)
58+
data = U.fill_border(data, margin=3, value=0, inplace=True)
59+
return B.evaluate_cubic_bspline(data, size=size, stride=stride)
60+
61+
62+
def visualize_flow(
63+
ax: plt.Axes,
64+
flow: Tensor,
65+
grid: Optional[Grid] = None,
66+
axes: Optional[Axes] = None,
67+
label: Optional[str] = None,
68+
) -> None:
69+
if grid is None:
70+
grid = Grid(shape=flow.shape[2:])
71+
if axes is None:
72+
axes = grid.axes()
73+
flow = change_axes(flow, grid, axes, grid.axes())
74+
x = grid.coords(channels_last=False, dtype=flow.dtype, device=flow.device)
75+
x = U.move_dim(x.unsqueeze_(0).add_(flow), 1, -1)
76+
target_grid = U.grid_image(shape=flow.shape[2:], inverted=True, stride=(5, 5))
77+
warped_grid = U.warp_image(target_grid, x, align_corners=grid.align_corners())
78+
ax.imshow(warped_grid[0, 0, flow.shape[2] // 2], cmap="gray")
79+
if label:
80+
ax.set_title(label, fontsize=24)
81+
82+
83+
# %%
84+
# Random velocity fields
85+
generator = torch.Generator().manual_seed(42)
86+
grid = Grid(size=(128, 128, 64), spacing=(0.5, 0.5, 1.0))
87+
flow = random_svf(grid.size(), stride=8, generator=generator).mul_(0.1)
88+
89+
fig, axes = plt.subplots(1, 1, figsize=(4, 4))
90+
91+
ax = axes
92+
ax.set_title("v", fontsize=24, pad=20)
93+
visualize_flow(ax, flow, grid=grid, axes=grid.axes())
94+
95+
96+
# %%
97+
# Visualise first order derivatives for different modes
98+
configs = [
99+
dict(mode="forward_central_backward"),
100+
dict(mode="bspline"),
101+
dict(mode="gaussian", sigma=0.7355),
102+
]
103+
104+
fig, axes = plt.subplots(len(configs), 4, figsize=(16, 4 * len(configs)))
105+
106+
for i, config in enumerate(configs):
107+
derivs = flow_derivatives(
108+
flow,
109+
grid=grid,
110+
axes=grid.axes(),
111+
to_axes=Axes.GRID,
112+
which=["du/dx", "du/dy", "dv/dx", "dv/dy"],
113+
**config,
114+
)
115+
for ax, (key, deriv) in zip(axes[i], derivs.items()):
116+
if i == 0:
117+
ax.set_title(key, fontsize=24, pad=20)
118+
ax.imshow(deriv[0, 0, deriv.shape[2] // 2], vmin=-1, vmax=1)
119+
120+
121+
# %%
122+
# Compare magnitudes of first order derivatives for different modes
123+
flow_axes = [Axes.GRID, Axes.WORLD, Axes.CUBE_CORNERS]
124+
125+
sigma = 0.7355
126+
127+
configs = [
128+
dict(mode="bspline"),
129+
dict(mode="gaussian", sigma=sigma),
130+
dict(mode="forward_central_backward", sigma=sigma),
131+
dict(mode="forward_central_backward"),
132+
]
133+
134+
for to_axes in flow_axes:
135+
for config in configs:
136+
print(f"axes={to_axes}, " + ", ".join(f"{k}={v!r}" for k, v in config.items()))
137+
derivs = flow_derivatives(
138+
flow,
139+
grid=grid,
140+
axes=grid.axes(),
141+
to_axes=to_axes,
142+
which=["du/dx", "du/dy", "dv/dx", "dv/dy"],
143+
**config,
144+
)
145+
for key, deriv in derivs.items():
146+
print(f"- max(abs({key})): {deriv.abs().max().item():.5f}")
147+
print()
148+
print("\n")
149+
150+
151+
# %%
152+
# Visualise second order derivatives for different modes
153+
configs = [
154+
dict(mode="forward_central_backward"),
155+
dict(mode="bspline"),
156+
dict(mode="gaussian", sigma=0.7355),
157+
]
158+
159+
fig, axes = plt.subplots(len(configs), 4, figsize=(16, 4 * len(configs)))
160+
161+
for i, config in enumerate(configs):
162+
derivs = flow_derivatives(
163+
flow,
164+
grid=grid,
165+
axes=grid.axes(),
166+
to_axes=Axes.GRID,
167+
which=["du/dxx", "du/dxy", "dv/dxy", "dv/dyy"],
168+
**config,
169+
)
170+
for ax, (key, deriv) in zip(axes[i], derivs.items()):
171+
if i == 0:
172+
ax.set_title(key, fontsize=24, pad=20)
173+
ax.imshow(deriv[0, 0, deriv.shape[2] // 2], vmin=-0.4, vmax=0.4)
174+
175+
176+
# %%
177+
# Compare magnitudes of second order derivatives for different modes
178+
flow_axes = [Axes.GRID, Axes.WORLD, Axes.CUBE_CORNERS]
179+
180+
sigma = 0.7355
181+
182+
configs = [
183+
dict(mode="bspline"),
184+
dict(mode="gaussian", sigma=sigma),
185+
dict(mode="forward_central_backward", sigma=sigma),
186+
dict(mode="forward_central_backward"),
187+
]
188+
189+
for to_axes in flow_axes:
190+
for config in configs:
191+
print(f"axes={to_axes}, " + ", ".join(f"{k}={v!r}" for k, v in config.items()))
192+
derivs = flow_derivatives(
193+
flow,
194+
grid=grid,
195+
axes=grid.axes(),
196+
to_axes=to_axes,
197+
which=["du/dxx", "du/dxy", "dv/dxy", "dv/dyy"],
198+
**config,
199+
)
200+
for key, deriv in derivs.items():
201+
print(f"- max(abs({key})): {deriv.abs().max().item():.5f}")
202+
print()
203+
print("\n")
204+
205+
# %%

0 commit comments

Comments
 (0)