Skip to content

Commit 03cbc6a

Browse files
edyoshikunziw-liu
andauthored
Example for Shearing transform (#305)
* add example for visualizing the shearing transform * plot dimensions in a loop * fix comment * ruff --------- Co-authored-by: Ziwen Liu <[email protected]>
1 parent 3135bd5 commit 03cbc6a

File tree

1 file changed

+176
-0
lines changed

1 file changed

+176
-0
lines changed
Lines changed: 176 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,176 @@
1+
# %%
2+
import matplotlib.pyplot as plt
3+
import numpy as np
4+
import torch
5+
6+
from viscy.transforms._transforms import BatchedRandAffined
7+
8+
9+
def create_3d_phantom(shape=(64, 128, 128)):
10+
"""
11+
Create a 3D phantom with geometric shapes in ZYX coordinates.
12+
13+
Parameters
14+
----------
15+
shape : tuple
16+
Shape of the phantom in (Z, Y, X) order.
17+
18+
Returns
19+
-------
20+
np.ndarray
21+
3D phantom array.
22+
"""
23+
z_size, y_size, x_size = shape
24+
phantom = np.zeros(shape)
25+
26+
# Create coordinate grids
27+
z, y, x = np.mgrid[0:z_size, 0:y_size, 0:x_size]
28+
29+
# Center coordinates
30+
z_center, y_center, x_center = z_size // 2, y_size // 2, x_size // 2
31+
32+
# Sphere in the center
33+
sphere_radius = min(shape) // 6
34+
sphere_mask = (
35+
(z - z_center) ** 2 + (y - y_center) ** 2 + (x - x_center) ** 2
36+
) <= sphere_radius**2
37+
phantom[sphere_mask] = 1.0
38+
39+
# Cylinder along Z-axis
40+
cyl_radius = sphere_radius // 2
41+
cyl_x_offset = x_center + sphere_radius * 2
42+
cyl_y_offset = y_center
43+
cylinder_mask = ((y - cyl_y_offset) ** 2 + (x - cyl_x_offset) ** 2) <= cyl_radius**2
44+
phantom[cylinder_mask] = 0.7
45+
46+
# Box
47+
box_size = sphere_radius
48+
box_z_start = z_center - sphere_radius * 2
49+
box_z_end = box_z_start + box_size
50+
box_y_start = y_center - box_size // 2
51+
box_y_end = box_y_start + box_size
52+
box_x_start = x_center - box_size // 2
53+
box_x_end = box_x_start + box_size
54+
55+
if box_z_start >= 0 and box_z_end < z_size:
56+
phantom[box_z_start:box_z_end, box_y_start:box_y_end, box_x_start:box_x_end] = (
57+
0.5
58+
)
59+
60+
return phantom
61+
62+
63+
def plot_3d_projections(phantom, title="3D Phantom Projections"):
64+
"""
65+
Plot 3D phantom with maximum intensity projections along each axis.
66+
67+
Parameters
68+
----------
69+
phantom : np.ndarray
70+
3D phantom array in ZYX order.
71+
title : str
72+
Title for the figure.
73+
"""
74+
fig, axes = plt.subplots(2, 2, figsize=(12, 10))
75+
fig.suptitle(title, fontsize=16)
76+
77+
mip_xy = np.max(phantom, axis=0)
78+
mip_xz = np.max(phantom, axis=1)
79+
mip_yz = np.max(phantom, axis=2)
80+
81+
z_center = phantom.shape[0] // 2
82+
83+
im1 = axes[0, 0].imshow(mip_xy, cmap="viridis", origin="lower")
84+
axes[0, 0].set_title("MIP XY (along Z)")
85+
axes[0, 0].set_xlabel("X")
86+
axes[0, 0].set_ylabel("Y")
87+
plt.colorbar(im1, ax=axes[0, 0])
88+
89+
im2 = axes[0, 1].imshow(mip_xz, cmap="viridis", origin="lower")
90+
axes[0, 1].set_title("MIP XZ (along Y)")
91+
axes[0, 1].set_xlabel("X")
92+
axes[0, 1].set_ylabel("Z")
93+
plt.colorbar(im2, ax=axes[0, 1])
94+
95+
im3 = axes[1, 0].imshow(mip_yz, cmap="viridis", origin="lower")
96+
axes[1, 0].set_title("MIP YZ (along X)")
97+
axes[1, 0].set_xlabel("Y")
98+
axes[1, 0].set_ylabel("Z")
99+
plt.colorbar(im3, ax=axes[1, 0])
100+
101+
im4 = axes[1, 1].imshow(phantom[z_center, :, :], cmap="viridis", origin="lower")
102+
axes[1, 1].set_title(f"Central XY slice (Z={z_center})")
103+
axes[1, 1].set_xlabel("X")
104+
axes[1, 1].set_ylabel("Y")
105+
plt.colorbar(im4, ax=axes[1, 1])
106+
107+
plt.tight_layout()
108+
return fig
109+
110+
111+
def apply_shear_transform(phantom, shear_values, prob=1.0):
112+
"""
113+
Apply shear transformation using BatchedRandAffined.
114+
115+
Parameters
116+
----------
117+
phantom : np.ndarray
118+
3D phantom array in ZYX order.
119+
shear_values : list or tuple
120+
Shear values for each facet [sxy, sxz, syx, syz, szx, szy] in radians.
121+
prob : float
122+
Probability of applying transform.
123+
124+
Returns
125+
-------
126+
np.ndarray
127+
Transformed phantom.
128+
"""
129+
# Convert to torch tensor with batch and channel dimensions
130+
phantom_tensor = torch.from_numpy(phantom).float()
131+
phantom_tensor = phantom_tensor.unsqueeze(0).unsqueeze(
132+
0
133+
) # Add batch and channel dims
134+
transform = BatchedRandAffined(
135+
keys=["image"], prob=prob, shear_range=shear_values, mode="bilinear"
136+
)
137+
sample = {"image": phantom_tensor}
138+
transformed_sample = transform(sample)
139+
transformed_phantom = transformed_sample["image"].squeeze().numpy()
140+
141+
return transformed_phantom
142+
143+
144+
if __name__ == "__main__":
145+
phantom = create_3d_phantom((64, 128, 128))
146+
147+
fig1 = plot_3d_projections(phantom, "Original Phantom")
148+
shear_names = ["s01", "s02", "s10", "s12", "s20", "s21"]
149+
150+
"""
151+
s{ij}:
152+
153+
[
154+
[1.0, params[0], params[1], 0.0],
155+
[params[2], 1.0, params[3], 0.0],
156+
[params[4], params[5], 1.0, 0.0],
157+
[0.0, 0.0, 0.0, 1.0],
158+
]
159+
"""
160+
161+
for axis in range(6):
162+
shear = [0.0] * 6
163+
shear[axis] = 0.5
164+
phantom_sheared = apply_shear_transform(phantom, shear)
165+
fig = plot_3d_projections(
166+
phantom_sheared, f"Shear applied: {shear_names[axis]}=0.5"
167+
)
168+
plt.show()
169+
170+
shear_combined = [0.2, 0.2, 0.0, 0.2, 0.0, 0.2]
171+
phantom_combined = apply_shear_transform(phantom, shear_combined)
172+
fig6 = plot_3d_projections(phantom_combined, "Combined Shears")
173+
174+
plt.show()
175+
176+
# %%

0 commit comments

Comments
 (0)