|
| 1 | +# %% |
| 2 | +import matplotlib.pyplot as plt |
| 3 | +import numpy as np |
| 4 | +import torch |
| 5 | + |
| 6 | +from viscy.transforms._transforms import BatchedRandAffined |
| 7 | + |
| 8 | + |
| 9 | +def create_3d_phantom(shape=(64, 128, 128)): |
| 10 | + """ |
| 11 | + Create a 3D phantom with geometric shapes in ZYX coordinates. |
| 12 | +
|
| 13 | + Parameters |
| 14 | + ---------- |
| 15 | + shape : tuple |
| 16 | + Shape of the phantom in (Z, Y, X) order. |
| 17 | +
|
| 18 | + Returns |
| 19 | + ------- |
| 20 | + np.ndarray |
| 21 | + 3D phantom array. |
| 22 | + """ |
| 23 | + z_size, y_size, x_size = shape |
| 24 | + phantom = np.zeros(shape) |
| 25 | + |
| 26 | + # Create coordinate grids |
| 27 | + z, y, x = np.mgrid[0:z_size, 0:y_size, 0:x_size] |
| 28 | + |
| 29 | + # Center coordinates |
| 30 | + z_center, y_center, x_center = z_size // 2, y_size // 2, x_size // 2 |
| 31 | + |
| 32 | + # Sphere in the center |
| 33 | + sphere_radius = min(shape) // 6 |
| 34 | + sphere_mask = ( |
| 35 | + (z - z_center) ** 2 + (y - y_center) ** 2 + (x - x_center) ** 2 |
| 36 | + ) <= sphere_radius**2 |
| 37 | + phantom[sphere_mask] = 1.0 |
| 38 | + |
| 39 | + # Cylinder along Z-axis |
| 40 | + cyl_radius = sphere_radius // 2 |
| 41 | + cyl_x_offset = x_center + sphere_radius * 2 |
| 42 | + cyl_y_offset = y_center |
| 43 | + cylinder_mask = ((y - cyl_y_offset) ** 2 + (x - cyl_x_offset) ** 2) <= cyl_radius**2 |
| 44 | + phantom[cylinder_mask] = 0.7 |
| 45 | + |
| 46 | + # Box |
| 47 | + box_size = sphere_radius |
| 48 | + box_z_start = z_center - sphere_radius * 2 |
| 49 | + box_z_end = box_z_start + box_size |
| 50 | + box_y_start = y_center - box_size // 2 |
| 51 | + box_y_end = box_y_start + box_size |
| 52 | + box_x_start = x_center - box_size // 2 |
| 53 | + box_x_end = box_x_start + box_size |
| 54 | + |
| 55 | + if box_z_start >= 0 and box_z_end < z_size: |
| 56 | + phantom[box_z_start:box_z_end, box_y_start:box_y_end, box_x_start:box_x_end] = ( |
| 57 | + 0.5 |
| 58 | + ) |
| 59 | + |
| 60 | + return phantom |
| 61 | + |
| 62 | + |
| 63 | +def plot_3d_projections(phantom, title="3D Phantom Projections"): |
| 64 | + """ |
| 65 | + Plot 3D phantom with maximum intensity projections along each axis. |
| 66 | +
|
| 67 | + Parameters |
| 68 | + ---------- |
| 69 | + phantom : np.ndarray |
| 70 | + 3D phantom array in ZYX order. |
| 71 | + title : str |
| 72 | + Title for the figure. |
| 73 | + """ |
| 74 | + fig, axes = plt.subplots(2, 2, figsize=(12, 10)) |
| 75 | + fig.suptitle(title, fontsize=16) |
| 76 | + |
| 77 | + mip_xy = np.max(phantom, axis=0) |
| 78 | + mip_xz = np.max(phantom, axis=1) |
| 79 | + mip_yz = np.max(phantom, axis=2) |
| 80 | + |
| 81 | + z_center = phantom.shape[0] // 2 |
| 82 | + |
| 83 | + im1 = axes[0, 0].imshow(mip_xy, cmap="viridis", origin="lower") |
| 84 | + axes[0, 0].set_title("MIP XY (along Z)") |
| 85 | + axes[0, 0].set_xlabel("X") |
| 86 | + axes[0, 0].set_ylabel("Y") |
| 87 | + plt.colorbar(im1, ax=axes[0, 0]) |
| 88 | + |
| 89 | + im2 = axes[0, 1].imshow(mip_xz, cmap="viridis", origin="lower") |
| 90 | + axes[0, 1].set_title("MIP XZ (along Y)") |
| 91 | + axes[0, 1].set_xlabel("X") |
| 92 | + axes[0, 1].set_ylabel("Z") |
| 93 | + plt.colorbar(im2, ax=axes[0, 1]) |
| 94 | + |
| 95 | + im3 = axes[1, 0].imshow(mip_yz, cmap="viridis", origin="lower") |
| 96 | + axes[1, 0].set_title("MIP YZ (along X)") |
| 97 | + axes[1, 0].set_xlabel("Y") |
| 98 | + axes[1, 0].set_ylabel("Z") |
| 99 | + plt.colorbar(im3, ax=axes[1, 0]) |
| 100 | + |
| 101 | + im4 = axes[1, 1].imshow(phantom[z_center, :, :], cmap="viridis", origin="lower") |
| 102 | + axes[1, 1].set_title(f"Central XY slice (Z={z_center})") |
| 103 | + axes[1, 1].set_xlabel("X") |
| 104 | + axes[1, 1].set_ylabel("Y") |
| 105 | + plt.colorbar(im4, ax=axes[1, 1]) |
| 106 | + |
| 107 | + plt.tight_layout() |
| 108 | + return fig |
| 109 | + |
| 110 | + |
| 111 | +def apply_shear_transform(phantom, shear_values, prob=1.0): |
| 112 | + """ |
| 113 | + Apply shear transformation using BatchedRandAffined. |
| 114 | +
|
| 115 | + Parameters |
| 116 | + ---------- |
| 117 | + phantom : np.ndarray |
| 118 | + 3D phantom array in ZYX order. |
| 119 | + shear_values : list or tuple |
| 120 | + Shear values for each facet [sxy, sxz, syx, syz, szx, szy] in radians. |
| 121 | + prob : float |
| 122 | + Probability of applying transform. |
| 123 | +
|
| 124 | + Returns |
| 125 | + ------- |
| 126 | + np.ndarray |
| 127 | + Transformed phantom. |
| 128 | + """ |
| 129 | + # Convert to torch tensor with batch and channel dimensions |
| 130 | + phantom_tensor = torch.from_numpy(phantom).float() |
| 131 | + phantom_tensor = phantom_tensor.unsqueeze(0).unsqueeze( |
| 132 | + 0 |
| 133 | + ) # Add batch and channel dims |
| 134 | + transform = BatchedRandAffined( |
| 135 | + keys=["image"], prob=prob, shear_range=shear_values, mode="bilinear" |
| 136 | + ) |
| 137 | + sample = {"image": phantom_tensor} |
| 138 | + transformed_sample = transform(sample) |
| 139 | + transformed_phantom = transformed_sample["image"].squeeze().numpy() |
| 140 | + |
| 141 | + return transformed_phantom |
| 142 | + |
| 143 | + |
| 144 | +if __name__ == "__main__": |
| 145 | + phantom = create_3d_phantom((64, 128, 128)) |
| 146 | + |
| 147 | + fig1 = plot_3d_projections(phantom, "Original Phantom") |
| 148 | + shear_names = ["s01", "s02", "s10", "s12", "s20", "s21"] |
| 149 | + |
| 150 | + """ |
| 151 | + s{ij}: |
| 152 | +
|
| 153 | + [ |
| 154 | + [1.0, params[0], params[1], 0.0], |
| 155 | + [params[2], 1.0, params[3], 0.0], |
| 156 | + [params[4], params[5], 1.0, 0.0], |
| 157 | + [0.0, 0.0, 0.0, 1.0], |
| 158 | + ] |
| 159 | + """ |
| 160 | + |
| 161 | + for axis in range(6): |
| 162 | + shear = [0.0] * 6 |
| 163 | + shear[axis] = 0.5 |
| 164 | + phantom_sheared = apply_shear_transform(phantom, shear) |
| 165 | + fig = plot_3d_projections( |
| 166 | + phantom_sheared, f"Shear applied: {shear_names[axis]}=0.5" |
| 167 | + ) |
| 168 | + plt.show() |
| 169 | + |
| 170 | + shear_combined = [0.2, 0.2, 0.0, 0.2, 0.0, 0.2] |
| 171 | + phantom_combined = apply_shear_transform(phantom, shear_combined) |
| 172 | + fig6 = plot_3d_projections(phantom_combined, "Combined Shears") |
| 173 | + |
| 174 | + plt.show() |
| 175 | + |
| 176 | +# %% |
0 commit comments