Skip to content

Commit 7b1e4ce

Browse files
authored
Merge pull request #22 from wli51/dev-add-patch-dataset
Adding the image crop dataset implementation, with refactor of serialization logic to reduce complexity and redundancy
2 parents 957ae9f + 921d410 commit 7b1e4ce

24 files changed

+2421
-940
lines changed

CHANGELOG.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,26 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
77

88
---
99

10+
## [0.4.3] - 2025-12-16
11+
12+
### Added
13+
14+
#### Crop dataset (`virtual_stain_flow/datasets/`):
15+
16+
Allows the dataset to return user-specified crops dynamically extracted from the full images. Supports serialization and deserialization to facilitate reproducibility.
17+
18+
- **`CropImageDataset`** (`crop_dataset.py`): Dataset class for serving image crops based on a `CropManifest`. Extends `BaseImageDataset` with crop-specific state management and lazy loading via `CropFileState`.
19+
- **`CropManifest`** (`ds_engine/crop_manifest.py`): Immutable collection of crop definitions wrapping a `DatasetManifest` for file access. Supports serialization/deserialization and factory construction from coordinate specifications.
20+
- **`Crop`** (`ds_engine/crop_manifest.py`): Dataclass defining a single crop region with manifest index, position (x, y), and dimensions (width, height).
21+
- **`CropIndexState`** (`ds_engine/crop_manifest.py`): Mutable state tracker for the currently active crop region.
22+
- **`CropFileState`** (`ds_engine/crop_manifest.py`): Lazy image loading backend that wraps `FileState` to load full images and dynamically extract crop regions on demand.
23+
24+
### Removed
25+
26+
#### All obsolete dataset classes
27+
28+
---
29+
1030
## [0.4.2] - 2025-11-17
1131

1232
### Added

examples/2.training_with_logging_example.ipynb

Lines changed: 163 additions & 225 deletions
Large diffs are not rendered by default.

examples/nbconverted/2.training_with_logging_example.py

Lines changed: 77 additions & 187 deletions
Original file line numberDiff line numberDiff line change
@@ -12,25 +12,27 @@
1212

1313

1414
import re
15-
import json
1615
import pathlib
17-
from typing import List, Tuple
16+
from typing import List
1817

1918

20-
import numpy as np
2119
import pandas as pd
2220
import matplotlib.pyplot as plt
2321
from PIL import Image
2422
import torch
25-
from torch.utils.data import Dataset, DataLoader
23+
from torch.utils.data import DataLoader
2624
from PIL import Image
2725
from torchmetrics.image import MultiScaleStructuralSimilarityIndexMeasure
2826
from mlflow.tracking import MlflowClient
2927

28+
from virtual_stain_flow.datasets.base_dataset import BaseImageDataset
29+
from virtual_stain_flow.datasets.crop_dataset import CropImageDataset
30+
from virtual_stain_flow.transforms.normalizations import MaxScaleNormalize
3031
from virtual_stain_flow.trainers.logging_trainer import SingleGeneratorTrainer
3132
from virtual_stain_flow.vsf_logging.MlflowLogger import MlflowLogger
3233
from virtual_stain_flow.vsf_logging.callbacks.PlotCallback import PlotPredictionCallback
3334
from virtual_stain_flow.models.unet import UNet
35+
from virtual_stain_flow.evaluation.visualization import plot_dataset_grid
3436

3537

3638
# ## Pathing and Additional utils
@@ -74,216 +76,101 @@ def _collect_field_prefixes(
7476
break
7577
return prefixes
7678

77-
78-
def _load_single_channel(
79+
def build_file_index(
7980
plate_dir: pathlib.Path,
80-
field_prefix: str,
81-
channel: int,
82-
normalize: bool = True,
83-
) -> np.ndarray:
84-
"""
85-
Load a single channel image for a given field prefix and channel index.
86-
87-
:param plate_dir: Directory containing TIFF files for one JUMP plate
88-
:param field_prefix: Prefix like 'r01c01f01p01'.
89-
:param channel: Channel index, e.g. 5 for Hoechst, 7 for BF mid-z.
90-
:param normalize: If True, convert to float32 and divide by dtype max
91-
:return: Image array of shape (H, W), float32.
92-
"""
93-
fname = f"{field_prefix}-ch{channel:d}sk1fk1fl1.tiff"
94-
path = plate_dir / fname
95-
if not path.exists():
96-
raise FileNotFoundError(f"Expected file not found: {path}")
97-
98-
arr = np.array(Image.open(path)) # typically uint16
99-
100-
if normalize:
101-
if np.issubdtype(arr.dtype, np.integer):
102-
info = np.iinfo(arr.dtype)
103-
arr = arr.astype("float32") / float(info.max)
104-
else:
105-
arr = arr.astype("float32")
106-
else:
107-
arr = arr.astype("float32")
108-
109-
return arr # (H, W), float32
110-
111-
112-
def load_jump_bf_hoechst(
113-
plate_dir: str | pathlib.Path,
114-
max_fields: int = 32,
115-
bf_channel: int = 7,
116-
dna_channel: int = 5,
117-
normalize: bool = True,
118-
) -> Tuple[np.ndarray, np.ndarray, List[str]]:
81+
max_fields: int = 16,
82+
) -> pd.DataFrame:
11983
"""
120-
Load a small BF->Hoechst subset from a CPJUMP1 plate.
121-
122-
:param plate_dir: Directory containing TIFF files for one JUMP plate
123-
:param max_fields: Maximum number of fields to load
124-
:param bf_channel: Channel index for BF mid-z (default 7)
125-
:param dna_channel: Channel index for Hoechst (default 5)
126-
:param normalize: If True, convert to float32 and divide by dtype max
84+
Helper function to build a file index that specifies
85+
the relationship of images across channels and field/fovs.
86+
The result can directly be supplied to BaseImageDataset to create a
87+
dataset with the correct image pairs.
12788
"""
128-
plate_dir = pathlib.Path(plate_dir)
129-
130-
if not plate_dir.exists() or not plate_dir.is_dir():
131-
raise FileNotFoundError(
132-
f"Plate directory {plate_dir} does not exist or is not a directory."
133-
)
13489

135-
prefixes = _collect_field_prefixes(plate_dir, max_fields=max_fields)
136-
if not prefixes:
137-
raise RuntimeError(f"No valid JUMP image files found in {plate_dir}")
138-
139-
bf_list: list[np.ndarray] = []
140-
dna_list: list[np.ndarray] = []
141-
used_prefixes: list[str] = []
142-
143-
for prefix in prefixes:
144-
try:
145-
bf = _load_single_channel(
146-
plate_dir, prefix, bf_channel, normalize=normalize
147-
)
148-
dna = _load_single_channel(
149-
plate_dir, prefix, dna_channel, normalize=normalize
150-
)
151-
except FileNotFoundError:
152-
# Skip incomplete fields (missing channels)
153-
continue
154-
155-
# Add channel axis: (1, H, W)
156-
bf_list.append(bf[None, ...])
157-
dna_list.append(dna[None, ...])
158-
used_prefixes.append(prefix)
90+
fields = _collect_field_prefixes(
91+
plate_dir,
92+
max_fields=max_fields,
93+
)
15994

160-
if not bf_list:
161-
raise RuntimeError(
162-
f"No complete BF + DNA pairs found in {plate_dir} "
163-
f"for bf_channel={bf_channel}, dna_channel={dna_channel}"
164-
)
95+
file_index_list = []
96+
for field in fields:
97+
sample = {}
98+
for chan in DATA_PATH.glob(f"**/{field}*.tiff"):
99+
match = FIELD_RE.match(chan.name)
100+
if match and match.groups()[1]:
101+
sample[f"ch{match.groups()[1]}"] = str(chan)
165102

166-
X = np.stack(bf_list, axis=0) # (N, 1, H, W)
167-
Y = np.stack(dna_list, axis=0) # (N, 1, H, W)
103+
file_index_list.append(sample)
168104

169-
return X, Y, used_prefixes
105+
file_index = pd.DataFrame(file_index_list)
106+
file_index.dropna(how='all', inplace=True)
107+
if file_index.empty:
108+
raise ValueError(f"No files found in {plate_dir} matching the expected pattern.")
170109

110+
return file_index.loc[:, sorted(file_index.columns)]
171111

172-
# Dataset object for training
173112

174113
# In[3]:
175114

176115

177-
class SimpleDataset(Dataset):
178-
"""
179-
Simple dataset for demo purposes.
180-
Loads images from disk, crops the center, and returns as tensors.
181-
"""
182-
def __init__(self, X: np.ndarray, Y: np.ndarray, crop_size: int = 256):
183-
self.X = X
184-
self.Y = Y
185-
self.crop_size = crop_size
186-
187-
def __len__(self):
188-
return len(self.X)
189-
190-
def __getitem__(self, idx):
191-
x = self.X[idx, 0, :, :]
192-
y = self.Y[idx, 0, :, :]
193-
194-
# Get image dimensions
195-
height, width = x.shape
196-
197-
# Calculate crop coordinates for center
198-
left = (width - self.crop_size) // 2
199-
top = (height - self.crop_size) // 2
200-
right = left + self.crop_size
201-
bottom = top + self.crop_size
202-
203-
# Crop center
204-
x_crop = x[top:bottom,left:right]
205-
y_crop = y[top:bottom,left:right]
206-
207-
# Convert to tensor
208-
x_tensor = torch.from_numpy(x_crop).unsqueeze(0) # Add channel dimension
209-
y_tensor = torch.from_numpy(y_crop).unsqueeze(0) # Add channel dimension
210-
211-
return x_tensor, y_tensor
212-
213-
214-
# ## Load subsetted demo data
215-
216-
# In[ ]:
217-
218-
219116
# Load very small subset of CPJUMP1, BF and Hoechst channels as input-target pairs
220117
# for demo purposes
221118
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1 for details
222-
X, Y, prefixes = load_jump_bf_hoechst(
223-
plate_dir=DATA_PATH,
224-
# retrieve up to 64 fields (different positions of images)
225-
# this results in a very small sample size good for demo purposes
226-
# for better training results, increase this number/load the full dataset
227-
max_fields=64,
228-
bf_channel=7, # mid-z BF for CPJUMP1
229-
dna_channel=5, # Hoechst
230-
)
231-
232-
# Print and visualize first 3 images from the loaded data
233-
print("X (BF):", X.shape, X.dtype) # (N, 1, H, W)
234-
print("Y (DNA):", Y.shape, Y.dtype) # (N, 1, H, W)
235-
print("First few fields:", prefixes[:5])
236-
237-
panel_width = 3
238-
indices = [1, 2, 3]
239-
fig, ax = plt.subplots(len(indices), 2, figsize=(panel_width * 2, panel_width * len(indices)))
240-
241-
for i, j in enumerate(indices):
242-
input, target = X[j], Y[j]
243-
ax[i][0].imshow(input[0], cmap='gray')
244-
ax[i][0].set_title(f'No.{j} Input')
245-
ax[i][0].axis('off')
246-
ax[i][1].imshow(target[0], cmap='gray')
247-
ax[i][1].set_title(f'No.{j} Target')
248-
ax[i][1].axis('off')
249-
plt.tight_layout()
250-
plt.show()
119+
file_index = build_file_index(DATA_PATH, max_fields=64)
120+
print(file_index.head())
251121

252122

253123
# ## Create dataset that returns tensors needed for training, and visualize several patches
254124

255-
# In[5]:
125+
# In[4]:
256126

257127

258-
# Create dataset instance
259-
dataset = SimpleDataset(X, Y, crop_size=256)
260-
print(f"Dataset created with {len(dataset)} samples")
128+
# Create a dataset with Brightfield as input and Hoechst as target
129+
# See https://github.com/jump-cellpainting/2024_Chandrasekaran_NatureMethods_CPJUMP1
130+
# for which channel codes correspond to which channel
131+
dataset = BaseImageDataset(
132+
file_index=file_index,
133+
check_exists=True,
134+
pil_image_mode="I;16",
135+
input_channel_keys=["ch7"],
136+
target_channel_keys=["ch5"],
137+
)
138+
print(f"Dataset length: {len(dataset)}")
139+
print(
140+
f"Input channels: {dataset.input_channel_keys}, target channels: {dataset._target_channel_keys}"
141+
)
142+
plot_dataset_grid(
143+
dataset=dataset,
144+
indices=[0,1,2,3],
145+
wspace=0.025,
146+
hspace=0.05
147+
)
261148

262-
# Plot the first 5 samples from the dataset
263-
fig, axes = plt.subplots(5, 2, figsize=(8, 16))
264149

265-
for i in range(5):
266-
brightfield, dna = dataset[i]
267-
brightfield = brightfield.numpy().squeeze()
268-
dna = dna.numpy().squeeze()
150+
# ## Generate cropped dataset by taking the center 256 x 256 square using built-in utilities.
151+
# Also visualize the first few crops
269152

270-
# Plot brightfield image
271-
axes[i, 0].imshow(brightfield.squeeze(), cmap='gray')
272-
axes[i, 0].set_title(f'Sample {i} - Brightfield')
273-
axes[i, 0].axis('off')
153+
# In[5]:
274154

275-
# Plot DNA image
276-
axes[i, 1].imshow(dna.squeeze(), cmap='gray')
277-
axes[i, 1].set_title(f'Sample {i} - DNA')
278-
axes[i, 1].axis('off')
279155

280-
plt.tight_layout()
281-
plt.show()
156+
cropped_dataset = CropImageDataset.from_base_dataset(
157+
dataset,
158+
crop_size=256,
159+
transforms=MaxScaleNormalize(
160+
normalization_factor='16bit'
161+
)
162+
)
163+
plot_dataset_grid(
164+
dataset=cropped_dataset,
165+
indices=[0,1,2,3],
166+
wspace=0.025,
167+
hspace=0.05
168+
)
282169

283170

284171
# ## Configure and train
285172

286-
# In[ ]:
173+
# In[6]:
287174

288175

289176
## Hyperparameters
@@ -303,7 +190,7 @@ def __getitem__(self, idx):
303190
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
304191

305192
# Batch with DataLoader
306-
train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
193+
train_loader = DataLoader(cropped_dataset, batch_size=batch_size, shuffle=True)
307194

308195
# Model & Optimizer
309196
fully_conv_unet = UNet(
@@ -325,11 +212,14 @@ def __getitem__(self, idx):
325212
# plots to the training.
326213
plot_callback = PlotPredictionCallback(
327214
name="plot_callback_with_train_data",
328-
dataset=dataset,
215+
dataset=cropped_dataset,
329216
indices=[0,1,2,3,4], # first 5 samples
330217
plot_metrics=[torch.nn.L1Loss()],
331218
every_n_epochs=5,
332-
show_plot=False
219+
# kwargs passed to plotting backend
220+
show_plot=False, # don't show plot in notebook
221+
wspace=0.025, # small spacing between subplots
222+
hspace=0.05 # small spacing between subplots
333223
)
334224

335225
# MLflow Logger
@@ -381,7 +271,7 @@ def __getitem__(self, idx):
381271

382272
# ### Display the last logged prediction plot artifact
383273

384-
# In[ ]:
274+
# In[7]:
385275

386276

387277
# Create MLflow client

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ dependencies = [
2929
"jupyter",
3030
"notebook",
3131
"tifffile",
32+
"pandera[pandas]",
3233
]
3334

3435
[project.optional-dependencies]

0 commit comments

Comments
 (0)