Skip to content

Commit 96571f3

Browse files
committed
Adding explanations
1 parent 8b559ba commit 96571f3

File tree

1 file changed

+73
-43
lines changed

1 file changed

+73
-43
lines changed

tutorials/2-examples/DTEx252_phase_mask_optimization.ipynb

Lines changed: 73 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -5,23 +5,26 @@
55
"id": "5e65db32",
66
"metadata": {},
77
"source": [
8-
"# Particle localization and phase mask optimization"
8+
"# Particle Localization and Phase Mask Optimization"
99
]
1010
},
1111
{
1212
"cell_type": "markdown",
1313
"id": "d4aa8b6b",
1414
"metadata": {},
1515
"source": [
16-
"This tutorial demonstrates how to optimize a phase mask to improve localization of closely spaced particles. The localization is performed by a convolutional neural network (CNN), trained jointly with the phase mask to improve both the optical setup and the network performance."
16+
"This tutorial demonstrates how to jointly optimize an optical phase mask and a neural network to improve the localization of closely spaced particles in microscopy images.\n",
17+
"\n",
18+
"The phase mask, inserted in the optical pupil plane, shapes the microscope’s point spread function (PSF). Simultaneously, a convolutional neural network (CNN) is trained to reconstruct the 3D particle positions from simulated images. Through backpropagation, gradients flow not only through the network but also through the optical model. As a result, the optical system learns to produce images that are easier for the CNN to interpret, leading to improved localization accuracy."
1719
]
1820
},
1921
{
2022
"cell_type": "markdown",
2123
"id": "82dfa30b",
2224
"metadata": {},
2325
"source": [
24-
"# 1. Create a simulation pipeline"
26+
"# 1. Create a simulation pipeline\n",
27+
"The first step is to define a simulation pipeline that generates synthetic data consisting of fluorescent particles imaged through a fluorescence microscope. For each simulated image, the pipeline also provides the true 3D positions of the particles, which serve as the ground truth during training."
2528
]
2629
},
2730
{
@@ -31,28 +34,24 @@
3134
"metadata": {},
3235
"outputs": [
3336
{
34-
"name": "stdout",
37+
"name": "stderr",
3538
"output_type": "stream",
3639
"text": [
37-
"c:\\Users/xgrmir/Documents/DeepTrack2/DeepTrack2\\deeptrack\\__init__.py\n"
40+
"WARNING:pint.util:Redefining '[magnetic_flux]' (<class 'pint.delegates.txt_defparser.plain.DerivedDimensionDefinition'>)\n"
3841
]
3942
}
4043
],
4144
"source": [
42-
"import sys\n",
43-
"\n",
44-
"sys.path.insert(0, '/Users/xgrmir/Documents/DeepTrack2/DeepTrack2')\n",
45-
"\n",
46-
"import deeptrack as dt\n",
47-
"print(dt.__file__)"
45+
"import deeptrack as dt"
4846
]
4947
},
5048
{
5149
"cell_type": "markdown",
5250
"id": "8a7638e5",
5351
"metadata": {},
5452
"source": [
55-
"## 1.1 Set the backend"
53+
"## 1.1 Set the backend\n",
54+
"DeepTrack supports both NumPy and PyTorch backends. In this tutorial, the PyTorch backend is selected to enable backpropagation through the optical setup. Enabling CUDA (optional) allows the computations to run on the GPU, which accelerates both simulation and training."
5655
]
5756
},
5857
{
@@ -86,13 +85,14 @@
8685
"metadata": {},
8786
"source": [
8887
"## 1.2 Define the trainable phase mask\n",
88+
"A phase mask modifies the wavefront of light in the microscope pupil plane, shaping the resulting point spread function (PSF). By making the phase mask trainable, we allow the optical system itself to learn — optimizing its design through gradient descent to improve particle localization performance.\n",
8989
"\n",
90-
"I don't know if inheriting from both nn.Module and dt.Aberration is correct? I inherit from dt.Aberration to get a dt feature that is compatible with the pipeline, and I inherit from nn.Module to make the layer trainable"
90+
"The LearnablePhaseMask class inherits from both `nn.Module` and `dt.Aberration`. `nn.Module` makes the mask compatible with PyTorch’s autograd and optimizer system, allowing the phase parameters to be learned, while `dt.Aberration` ensures the class integrates seamlessly into DeepTrack’s feature graph, making it compatible with the rest of the optical pipeline."
9191
]
9292
},
9393
{
9494
"cell_type": "code",
95-
"execution_count": 3,
95+
"execution_count": null,
9696
"id": "3470f1bf",
9797
"metadata": {},
9898
"outputs": [],
@@ -107,19 +107,19 @@
107107
" dt.Aberration.__init__(self, **kwargs)\n",
108108
" super().__init__(**kwargs)\n",
109109
" \n",
110-
" # Create a torch parameter for the phase\n",
110+
" # Create a trainable tensor representing the phase\n",
111111
" self.phase = nn.Parameter(torch.zeros(shape, dtype=torch.float32))\n",
112112
"\n",
113113
" def forward(self, pupil: torch.Tensor, **kwargs) -> torch.Tensor:\n",
114-
" \"\"\"PyTorch forward pass\"\"\"\n",
114+
" \"\"\"PyTorch forward pass for use in training\"\"\"\n",
115115
" return self.apply_phase(pupil)\n",
116116
"\n",
117117
" def get(self, pupil: torch.Tensor, **kwargs) -> torch.Tensor:\n",
118118
" \"\"\"DeepTrack Feature graph call\"\"\"\n",
119119
" return self.apply_phase(pupil)\n",
120120
"\n",
121121
" def apply_phase(self, pupil: torch.Tensor) -> torch.Tensor:\n",
122-
" \"\"\"Shared implementation\"\"\"\n",
122+
" \"\"\"Shared implementation of phase modulation\"\"\"\n",
123123
" phase = self.phase.to(device)\n",
124124
" phase_mask = torch.cos(phase) + 1j * torch.sin(phase)\n",
125125
" return pupil * phase_mask"
@@ -130,7 +130,8 @@
130130
"id": "e8a1553c",
131131
"metadata": {},
132132
"source": [
133-
"## 1.3 Define the optical setup"
133+
"## 1.3 Define the optical setup\n",
134+
"The optical setup is simulated using `dt.Fluorescence`, incorporating the trainable phase mask as the pupil function."
134135
]
135136
},
136137
{
@@ -166,25 +167,10 @@
166167
"id": "8fe88c92",
167168
"metadata": {},
168169
"source": [
169-
"## 1.4 Simulate particles"
170-
]
171-
},
172-
{
173-
"cell_type": "code",
174-
"execution_count": 6,
175-
"id": "d5c1e312",
176-
"metadata": {},
177-
"outputs": [],
178-
"source": [
179-
"import matplotlib.pyplot as plt"
180-
]
181-
},
182-
{
183-
"cell_type": "markdown",
184-
"id": "06cffcce",
185-
"metadata": {},
186-
"source": [
187-
"### Random positions in 3D, and gt in 3D:"
170+
"## 1.4 Simulate particles\n",
171+
"A simulation pipeline is defined to generate images of randomly positioned 3D particles, together with their corresponding ground-truth localization targets.\n",
172+
"\n",
173+
"A custom DeepTrack feature, `Positions`, creates a binary 3D mask containing randomly placed particles. Each voxel corresponding to a particle is assigned a value of 1, while background voxels remain 0. The true particle coordinates are stored in self.points for later retrieval."
188174
]
189175
},
190176
{
@@ -212,6 +198,14 @@
212198
" return mask + image"
213199
]
214200
},
201+
{
202+
"cell_type": "markdown",
203+
"id": "38c4a606",
204+
"metadata": {},
205+
"source": [
206+
"The number of points for each sample is randomized between 25 and 50 to provide diverse training examples."
207+
]
208+
},
215209
{
216210
"cell_type": "code",
217211
"execution_count": 8,
@@ -223,14 +217,24 @@
223217
"xyz = Positions(num_points=num_points)"
224218
]
225219
},
220+
{
221+
"cell_type": "markdown",
222+
"id": "446fee06",
223+
"metadata": {},
224+
"source": [
225+
"To make the training data physically plausible, a combination of Poisson noise and Gaussian noise is added.\n",
226+
"\n",
227+
"Since true Poisson noise is not differentiable, a Gaussian approximation that preserves differentiability during backpropagation is used."
228+
]
229+
},
226230
{
227231
"cell_type": "code",
228-
"execution_count": 9,
232+
"execution_count": null,
229233
"id": "7d339aaf",
230234
"metadata": {},
231235
"outputs": [],
232236
"source": [
233-
"class poisson_noise_approx(dt.Noise): # Because the real poisson noise is not compatible with backpropagation\n",
237+
"class poisson_noise_approx(dt.Noise):\n",
234238
" def __init__(\n",
235239
" self,\n",
236240
" **kwargs,\n",
@@ -250,6 +254,14 @@
250254
" return noisy_image"
251255
]
252256
},
257+
{
258+
"cell_type": "markdown",
259+
"id": "9ac539dc",
260+
"metadata": {},
261+
"source": [
262+
"Next, the components, including the particle positions, the optics, and the noise, are combined to form the simulation pipeline."
263+
]
264+
},
253265
{
254266
"cell_type": "code",
255267
"execution_count": 10,
@@ -275,9 +287,17 @@
275287
"pip = (im_pip & gt_pip) >> dt.MoveAxis(2, 0)"
276288
]
277289
},
290+
{
291+
"cell_type": "markdown",
292+
"id": "6bff75c1",
293+
"metadata": {},
294+
"source": [
295+
"An example simulation is visualized below."
296+
]
297+
},
278298
{
279299
"cell_type": "code",
280-
"execution_count": 36,
300+
"execution_count": null,
281301
"id": "2f78ddd7",
282302
"metadata": {},
283303
"outputs": [
@@ -303,13 +323,23 @@
303323
}
304324
],
305325
"source": [
326+
"import matplotlib.pyplot as plt\n",
327+
"\n",
306328
"pip.update()\n",
307329
"im, gt = pip.resolve()\n",
308330
"\n",
309331
"fig, axs = plt.subplots(1, 2, figsize=(8,4))\n",
310332
"\n",
311333
"axs[0].imshow(im[0].cpu().detach().numpy(), cmap='gray')\n",
312-
"axs[1].imshow(gt.max(dim=0)[0])"
334+
"axs[1].imshow(gt.max(dim=0)[0]);"
335+
]
336+
},
337+
{
338+
"cell_type": "markdown",
339+
"id": "297f7ade",
340+
"metadata": {},
341+
"source": [
342+
"The left image shows a synthetic microscopy image produced by the optical system with the current, untrained phase mask, while the right image displays the corresponding ground-truth particle positions projected onto the x–y plane."
313343
]
314344
},
315345
{
@@ -2182,7 +2212,7 @@
21822212
},
21832213
{
21842214
"cell_type": "code",
2185-
"execution_count": 30,
2215+
"execution_count": null,
21862216
"id": "edd1bcd0",
21872217
"metadata": {},
21882218
"outputs": [],
@@ -2193,7 +2223,7 @@
21932223
" output_region=(0,0, image_size, image_size),\n",
21942224
")\n",
21952225
"\n",
2196-
"im_pip_no_phase_mask = optics_no_phase_mask(particle ^ num_points)\n",
2226+
"im_pip_no_phase_mask = optics_no_phase_mask(particle ^ num_points) ## Add noise!\n",
21972227
"\n",
21982228
"pip_no_phase_mask = (im_pip_no_phase_mask & gt_pip) >> dt.MoveAxis(2, 0)"
21992229
]

0 commit comments

Comments
 (0)