1212
1313
1414import re
15- import json
1615import pathlib
17- from typing import List , Tuple
16+ from typing import List
1817
1918
20- import numpy as np
2119import pandas as pd
2220import matplotlib .pyplot as plt
2321from PIL import Image
2422import torch
25- from torch .utils .data import Dataset , DataLoader
23+ from torch .utils .data import DataLoader
2624from PIL import Image
2725from torchmetrics .image import MultiScaleStructuralSimilarityIndexMeasure
2826from mlflow .tracking import MlflowClient
2927
28+ from virtual_stain_flow .datasets .base_dataset import BaseImageDataset
29+ from virtual_stain_flow .datasets .crop_dataset import CropImageDataset
30+ from virtual_stain_flow .transforms .normalizations import MaxScaleNormalize
3031from virtual_stain_flow .trainers .logging_trainer import SingleGeneratorTrainer
3132from virtual_stain_flow .vsf_logging .MlflowLogger import MlflowLogger
3233from virtual_stain_flow .vsf_logging .callbacks .PlotCallback import PlotPredictionCallback
3334from virtual_stain_flow .models .unet import UNet
35+ from virtual_stain_flow .evaluation .visualization import plot_dataset_grid
3436
3537
3638# ## Pathing and Additional utils
@@ -74,216 +76,101 @@ def _collect_field_prefixes(
7476 break
7577 return prefixes
7678
77-
78- def _load_single_channel (
79+ def build_file_index (
7980 plate_dir : pathlib .Path ,
80- field_prefix : str ,
81- channel : int ,
82- normalize : bool = True ,
83- ) -> np .ndarray :
84- """
85- Load a single channel image for a given field prefix and channel index.
86-
87- :param plate_dir: Directory containing TIFF files for one JUMP plate
88- :param field_prefix: Prefix like 'r01c01f01p01'.
89- :param channel: Channel index, e.g. 5 for Hoechst, 7 for BF mid-z.
90- :param normalize: If True, convert to float32 and divide by dtype max
91- :return: Image array of shape (H, W), float32.
92- """
93- fname = f"{ field_prefix } -ch{ channel :d} sk1fk1fl1.tiff"
94- path = plate_dir / fname
95- if not path .exists ():
96- raise FileNotFoundError (f"Expected file not found: { path } " )
97-
98- arr = np .array (Image .open (path )) # typically uint16
99-
100- if normalize :
101- if np .issubdtype (arr .dtype , np .integer ):
102- info = np .iinfo (arr .dtype )
103- arr = arr .astype ("float32" ) / float (info .max )
104- else :
105- arr = arr .astype ("float32" )
106- else :
107- arr = arr .astype ("float32" )
108-
109- return arr # (H, W), float32
110-
111-
112- def load_jump_bf_hoechst (
113- plate_dir : str | pathlib .Path ,
114- max_fields : int = 32 ,
115- bf_channel : int = 7 ,
116- dna_channel : int = 5 ,
117- normalize : bool = True ,
118- ) -> Tuple [np .ndarray , np .ndarray , List [str ]]:
81+ max_fields : int = 16 ,
82+ ) -> pd .DataFrame :
11983 """
120- Load a small BF->Hoechst subset from a CPJUMP1 plate.
121-
122- :param plate_dir: Directory containing TIFF files for one JUMP plate
123- :param max_fields: Maximum number of fields to load
124- :param bf_channel: Channel index for BF mid-z (default 7)
125- :param dna_channel: Channel index for Hoechst (default 5)
126- :param normalize: If True, convert to float32 and divide by dtype max
84+ Helper function to build a file index that specifies
85+ the relationship of images across channels and field/fovs.
86+ The result can directly be supplied to BaseImageDataset to create a
87+ dataset with the correct image pairs.
12788 """
128- plate_dir = pathlib .Path (plate_dir )
129-
130- if not plate_dir .exists () or not plate_dir .is_dir ():
131- raise FileNotFoundError (
132- f"Plate directory { plate_dir } does not exist or is not a directory."
133- )
13489
135- prefixes = _collect_field_prefixes (plate_dir , max_fields = max_fields )
136- if not prefixes :
137- raise RuntimeError (f"No valid JUMP image files found in { plate_dir } " )
138-
139- bf_list : list [np .ndarray ] = []
140- dna_list : list [np .ndarray ] = []
141- used_prefixes : list [str ] = []
142-
143- for prefix in prefixes :
144- try :
145- bf = _load_single_channel (
146- plate_dir , prefix , bf_channel , normalize = normalize
147- )
148- dna = _load_single_channel (
149- plate_dir , prefix , dna_channel , normalize = normalize
150- )
151- except FileNotFoundError :
152- # Skip incomplete fields (missing channels)
153- continue
154-
155- # Add channel axis: (1, H, W)
156- bf_list .append (bf [None , ...])
157- dna_list .append (dna [None , ...])
158- used_prefixes .append (prefix )
90+ fields = _collect_field_prefixes (
91+ plate_dir ,
92+ max_fields = max_fields ,
93+ )
15994
160- if not bf_list :
161- raise RuntimeError (
162- f"No complete BF + DNA pairs found in { plate_dir } "
163- f"for bf_channel={ bf_channel } , dna_channel={ dna_channel } "
164- )
95+ file_index_list = []
96+ for field in fields :
97+ sample = {}
98+ for chan in DATA_PATH .glob (f"**/{ field } *.tiff" ):
99+ match = FIELD_RE .match (chan .name )
100+ if match and match .groups ()[1 ]:
101+ sample [f"ch{ match .groups ()[1 ]} " ] = str (chan )
165102
166- X = np .stack (bf_list , axis = 0 ) # (N, 1, H, W)
167- Y = np .stack (dna_list , axis = 0 ) # (N, 1, H, W)
103+ file_index_list .append (sample )
168104
169- return X , Y , used_prefixes
105+ file_index = pd .DataFrame (file_index_list )
106+ file_index .dropna (how = 'all' , inplace = True )
107+ if file_index .empty :
108+ raise ValueError (f"No files found in { plate_dir } matching the expected pattern." )
170109
110+ return file_index .loc [:, sorted (file_index .columns )]
171111
172- # Dataset object for training
173112
174113# In[3]:
175114
176115
177- class SimpleDataset (Dataset ):
178- """
179- Simple dataset for demo purposes.
180- Loads images from disk, crops the center, and returns as tensors.
181- """
182- def __init__ (self , X : np .ndarray , Y : np .ndarray , crop_size : int = 256 ):
183- self .X = X
184- self .Y = Y
185- self .crop_size = crop_size
186-
187- def __len__ (self ):
188- return len (self .X )
189-
190- def __getitem__ (self , idx ):
191- x = self .X [idx , 0 , :, :]
192- y = self .Y [idx , 0 , :, :]
193-
194- # Get image dimensions
195- height , width = x .shape
196-
197- # Calculate crop coordinates for center
198- left = (width - self .crop_size ) // 2
199- top = (height - self .crop_size ) // 2
200- right = left + self .crop_size
201- bottom = top + self .crop_size
202-
203- # Crop center
204- x_crop = x [top :bottom ,left :right ]
205- y_crop = y [top :bottom ,left :right ]
206-
207- # Convert to tensor
208- x_tensor = torch .from_numpy (x_crop ).unsqueeze (0 ) # Add channel dimension
209- y_tensor = torch .from_numpy (y_crop ).unsqueeze (0 ) # Add channel dimension
210-
211- return x_tensor , y_tensor
212-
213-
214- # ## Load subsetted demo data
215-
216- # In[ ]:
217-
218-
219116# Load very small subset of CJUMP1, BF and Hoechst channel as input-target pairs
220117# for demo purposes
221118# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1 for details
222- X , Y , prefixes = load_jump_bf_hoechst (
223- plate_dir = DATA_PATH ,
224- # retrieve up to 64 fields (different positions of images)
225- # this results in a very small sample size good for demo purposes
226- # for better training results, increase this number/load the full dataset
227- max_fields = 64 ,
228- bf_channel = 7 , # mid-z BF for CPJUMP1
229- dna_channel = 5 , # Hoechst
230- )
231-
232- # Print and visualize first 3 images from the loaded data
233- print ("X (BF):" , X .shape , X .dtype ) # (N, 1, H, W)
234- print ("Y (DNA):" , Y .shape , Y .dtype ) # (N, 1, H, W)
235- print ("First few fields:" , prefixes [:5 ])
236-
237- panel_width = 3
238- indices = [1 , 2 , 3 ]
239- fig , ax = plt .subplots (len (indices ), 2 , figsize = (panel_width * 2 , panel_width * len (indices )))
240-
241- for i , j in enumerate (indices ):
242- input , target = X [j ], Y [j ]
243- ax [i ][0 ].imshow (input [0 ], cmap = 'gray' )
244- ax [i ][0 ].set_title (f'No.{ j } Input' )
245- ax [i ][0 ].axis ('off' )
246- ax [i ][1 ].imshow (target [0 ], cmap = 'gray' )
247- ax [i ][1 ].set_title (f'No.{ j } Target' )
248- ax [i ][1 ].axis ('off' )
249- plt .tight_layout ()
250- plt .show ()
119+ file_index = build_file_index (DATA_PATH , max_fields = 64 )
120+ print (file_index .head ())
251121
252122
253123# ## Create dataset that returns tensors needed for training, and visualize several patches
254124
255- # In[5 ]:
125+ # In[4 ]:
256126
257127
258- # Create dataset instance
259- dataset = SimpleDataset (X , Y , crop_size = 256 )
260- print (f"Dataset created with { len (dataset )} samples" )
128+ # Create a dataset with Brightfield as input and Hoechst as target
129+ # See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1
130+ # for which channel codes correspond to which channel
131+ dataset = BaseImageDataset (
132+ file_index = file_index ,
133+ check_exists = True ,
134+ pil_image_mode = "I;16" ,
135+ input_channel_keys = ["ch7" ],
136+ target_channel_keys = ["ch5" ],
137+ )
138+ print (f"Dataset length: { len (dataset )} " )
139+ print (
140+ f"Input channels: { dataset .input_channel_keys } , target channels: { dataset ._target_channel_keys } "
141+ )
142+ plot_dataset_grid (
143+ dataset = dataset ,
144+ indices = [0 ,1 ,2 ,3 ],
145+ wspace = 0.025 ,
146+ hspace = 0.05
147+ )
261148
262- # Plot the first 5 samples from the dataset
263- fig , axes = plt .subplots (5 , 2 , figsize = (8 , 16 ))
264149
265- for i in range (5 ):
266- brightfield , dna = dataset [i ]
267- brightfield = brightfield .numpy ().squeeze ()
268- dna = dna .numpy ().squeeze ()
150+ # ## Generate cropped dataset by taking the center 256 x 256 square using built in utilities.
151+ # Also visualize the first few crops
269152
270- # Plot brightfield image
271- axes [i , 0 ].imshow (brightfield .squeeze (), cmap = 'gray' )
272- axes [i , 0 ].set_title (f'Sample { i } - Brightfield' )
273- axes [i , 0 ].axis ('off' )
153+ # In[5]:
274154
275- # Plot DNA image
276- axes [i , 1 ].imshow (dna .squeeze (), cmap = 'gray' )
277- axes [i , 1 ].set_title (f'Sample { i } - DNA' )
278- axes [i , 1 ].axis ('off' )
279155
280- plt .tight_layout ()
281- plt .show ()
156+ cropped_dataset = CropImageDataset .from_base_dataset (
157+ dataset ,
158+ crop_size = 256 ,
159+ transforms = MaxScaleNormalize (
160+ normalization_factor = '16bit'
161+ )
162+ )
163+ plot_dataset_grid (
164+ dataset = cropped_dataset ,
165+ indices = [0 ,1 ,2 ,3 ],
166+ wspace = 0.025 ,
167+ hspace = 0.05
168+ )
282169
283170
284171# ## Configure and train
285172
286- # In[ ]:
173+ # In[6 ]:
287174
288175
289176## Hyperparameters
@@ -303,7 +190,7 @@ def __getitem__(self, idx):
303190device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
304191
305192# Batch with DataLoader
306- train_loader = DataLoader (dataset , batch_size = batch_size , shuffle = True )
193+ train_loader = DataLoader (cropped_dataset , batch_size = batch_size , shuffle = True )
307194
308195# Model & Optimizer
309196fully_conv_unet = UNet (
@@ -325,11 +212,14 @@ def __getitem__(self, idx):
325212# plots to the training.
326213plot_callback = PlotPredictionCallback (
327214 name = "plot_callback_with_train_data" ,
328- dataset = dataset ,
215+ dataset = cropped_dataset ,
329216 indices = [0 ,1 ,2 ,3 ,4 ], # first 5 samples
330217 plot_metrics = [torch .nn .L1Loss ()],
331218 every_n_epochs = 5 ,
332- show_plot = False
219+ # kwargs passed to plotting backend
220+ show_plot = False , # don't show plot in notebook
221+ wspace = 0.025 , # small spacing between subplots
222+ hspace = 0.05 # small spacing between subplots
333223)
334224
335225# MLflow Logger
@@ -381,7 +271,7 @@ def __getitem__(self, idx):
381271
382272# ### Display the last logged prediction plot artifact
383273
384- # In[ ]:
274+ # In[7 ]:
385275
386276
387277# Create MLflow client
0 commit comments