Commit b6635d4

Resolve merge conflicts for docstring standardization
- Fixed conflicts in data modules (combined.py, triplet.py, hcs.py)
- Resolved classification.py parameter conflicts
- Updated transforms __init__.py imports
- Cleaned up _transforms.py duplicate classes
1 parent: cf23055

9 files changed: +435 −271 lines
viscy/data/cell_classification.py

Lines changed: 18 additions & 5 deletions
```diff
@@ -46,6 +46,7 @@ def __init__(
         transform: Callable | None,
         initial_yx_patch_size: tuple[int, int],
         return_indices: bool = False,
+        label_column: str = "infection_state",
     ):
         self.plate = plate
         self.z_range = z_range
@@ -65,6 +66,7 @@ def __init__(
             annotation["y"].between(*y_range, inclusive="neither")
             & annotation["x"].between(*x_range, inclusive="neither")
         ]
+        self.label_column = label_column

     def __len__(self):
         """Return the number of samples in the dataset."""
@@ -103,7 +105,7 @@ def __getitem__(
         img = (image - norm_meta["mean"]) / norm_meta["std"]
         if self.transform is not None:
             img = self.transform(img)
-        label = torch.tensor(row["infection_state"]).float()[None]
+        label = torch.tensor(row[self.label_column]).float()[None]
         if self.return_indices:
             return img, label, row[INDEX_COLUMNS].to_dict()
         else:
```
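
The `label_column` addition replaces the hard-coded `"infection_state"` lookup in `__getitem__`, so the same dataset class can train against any annotation column. A minimal sketch of the new lookup, assuming a pandas row with a numeric label (the `division_state` column name is hypothetical):

```python
import pandas as pd
import torch

# One annotation row; "division_state" stands in for any label column.
row = pd.Series({"fov_name": "/A/1/0", "division_state": 1})
label_column = "division_state"  # previously hard-coded to "infection_state"

# Same expression as in the diff: scalar -> float tensor of shape (1,)
label = torch.tensor(row[label_column]).float()[None]
print(label)  # tensor([1.])
```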
```diff
@@ -149,25 +151,27 @@ def __init__(
         val_fovs: list[str] | None,
         channel_name: str,
         z_range: tuple[int, int],
-        train_exlude_timepoints: list[int],
+        train_exclude_timepoints: list[int],
         train_transforms: list[Callable] | None,
         val_transforms: list[Callable] | None,
         initial_yx_patch_size: tuple[int, int],
         batch_size: int,
         num_workers: int,
+        label_column: str = "infection_state",
     ):
         super().__init__()
         self.image_path = image_path
         self.annotation_path = annotation_path
         self.val_fovs = val_fovs
         self.channel_name = channel_name
         self.z_range = z_range
-        self.train_exlude_timepoints = train_exlude_timepoints
+        self.train_exclude_timepoints = train_exclude_timepoints
         self.train_transform = Compose(train_transforms)
         self.val_transform = Compose(val_transforms)
         self.initial_yx_patch_size = initial_yx_patch_size
         self.batch_size = batch_size
         self.num_workers = num_workers
+        self.label_column = label_column

     def _subset(
         self,
@@ -189,6 +193,7 @@ def _subset(
             transform=transform,
             initial_yx_patch_size=self.initial_yx_patch_size,
             return_indices=return_indices,
+            label_column=self.label_column,
         )

     def setup(self, stage=None) -> None:
@@ -208,8 +213,16 @@ def setup(self, stage=None) -> None:
             If stage is unknown.
         """
         plate = open_ome_zarr(self.image_path)
-        all_fovs = ["/" + name for (name, _) in plate.positions()]
         annotation = pd.read_csv(self.annotation_path)
+        all_fovs = [name for (name, _) in plate.positions()]
+        if annotation["fov_name"].iloc[0].startswith("/"):
+            all_fovs = ["/" + name for name in all_fovs]
+        if all_fovs[0].startswith("/"):
+            if not self.val_fovs[0].startswith("/"):
+                self.val_fovs = ["/" + name for name in self.val_fovs]
+        else:
+            if self.val_fovs[0].startswith("/"):
+                self.val_fovs = [name[1:] for name in self.val_fovs]
         for column in ("t", "y", "x"):
             annotation[column] = annotation[column].astype(int)
         if stage in (None, "fit", "validate"):
```
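
The rewritten `setup()` no longer assumes FOV names carry a leading slash; it mirrors whatever convention the annotation CSV uses and then normalizes `val_fovs` to match. A standalone sketch of that normalization with illustrative values (the helper function itself is not part of the commit):

```python
def normalize_fov_names(
    all_fovs: list[str], val_fovs: list[str], annotated_fov: str
) -> tuple[list[str], list[str]]:
    """Make plate FOVs and val FOVs follow the annotation's slash convention."""
    if annotated_fov.startswith("/"):
        all_fovs = ["/" + name for name in all_fovs]
    if all_fovs[0].startswith("/"):
        if not val_fovs[0].startswith("/"):
            val_fovs = ["/" + name for name in val_fovs]
    elif val_fovs[0].startswith("/"):
        val_fovs = [name[1:] for name in val_fovs]
    return all_fovs, val_fovs

print(normalize_fov_names(["A/1/0", "B/1/0"], ["B/1/0"], "/A/1/0"))
# (['/A/1/0', '/B/1/0'], ['/B/1/0'])
```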
```diff
@@ -219,7 +232,7 @@ def setup(self, stage=None) -> None:
             annotation,
             train_fovs,
             transform=self.train_transform,
-            exclude_timepoints=self.train_exlude_timepoints,
+            exclude_timepoints=self.train_exclude_timepoints,
         )
         self.val_dataset = self._subset(
             plate,
```

viscy/data/combined.py

Lines changed: 42 additions & 6 deletions
```diff
@@ -189,7 +189,7 @@ def _get_sample_indices(self, idx: int) -> tuple[int, int]:
         sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
         return dataset_idx, sample_idx

-    def __getitems__(self, indices: list[int]) -> list:
+    def __getitems__(self, indices: list[int]) -> list[dict[str, torch.Tensor]]:
         """Retrieve multiple items by indices with batched dataset access.

         Groups indices by source dataset and performs batched retrieval
@@ -202,19 +202,22 @@ def __getitems__(self, indices: list[int]) -> list:

         Returns
         -------
-        list
+        list[dict[str, torch.Tensor]]
             Samples from all requested indices, maintaining order.
         """
         grouped_indices = defaultdict(list)
         for idx in indices:
             dataset_idx, sample_indices = self._get_sample_indices(idx)
             grouped_indices[dataset_idx].append(sample_indices)
         _logger.debug(f"Grouped indices: {grouped_indices}")
-        sub_batches = []
+
+        micro_batches = []
         for dataset_idx, sample_indices in grouped_indices.items():
-            sub_batch = self.datasets[dataset_idx].__getitems__(sample_indices)
-            sub_batches.extend(sub_batch)
-        return sub_batches
+            micro_batch = self.datasets[dataset_idx].__getitems__(sample_indices)
+            micro_batch["_dataset_idx"] = dataset_idx
+            micro_batches.append(micro_batch)
+
+        return micro_batches


 class ConcatDataModule(LightningDataModule):
```
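
`__getitems__` now returns one already-collated micro-batch per source dataset, each tagged with a `_dataset_idx` key, instead of a flat list of samples. A toy illustration of the index grouping, assuming `_get_sample_indices` bisects over cumulative dataset sizes (the `dataset_idx == 0` branch is not visible in the hunk and is an assumption here):

```python
from bisect import bisect_right
from collections import defaultdict

cumulative_sizes = [10, 25, 40]  # e.g. three datasets of sizes 10, 15, 15

def get_sample_indices(idx: int) -> tuple[int, int]:
    """Map a global index to (dataset index, local sample index)."""
    dataset_idx = bisect_right(cumulative_sizes, idx)
    sample_idx = idx if dataset_idx == 0 else idx - cumulative_sizes[dataset_idx - 1]
    return dataset_idx, sample_idx

grouped_indices = defaultdict(list)
for idx in [3, 12, 14, 30]:
    dataset_idx, sample_idx = get_sample_indices(idx)
    grouped_indices[dataset_idx].append(sample_idx)
print(dict(grouped_indices))  # {0: [3], 1: [2, 4], 2: [5]}
```

Each group is then fetched in one batched call per dataset, and the resulting micro-batch dict is tagged with its `_dataset_idx` before being returned.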
```diff
@@ -369,6 +372,7 @@ def train_dataloader(self) -> ThreadDataLoader:
             batch_size=self.batch_size,
             shuffle=True,
             drop_last=True,
+            collate_fn=lambda x: x,
             **self._dataloader_kwargs(),
         )
```
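
The identity `collate_fn` is what keeps those micro-batches intact: `__getitems__` already returns collated dicts, so the loader must pass the list through rather than let `default_collate` try to re-stack it. A self-contained sketch with a plain PyTorch `DataLoader` (which MONAI's `ThreadDataLoader` subclasses); it assumes a PyTorch version whose map-style fetcher dispatches batched reads to `__getitems__`:

```python
from torch.utils.data import DataLoader, Dataset

class ToyConcat(Dataset):
    """Stands in for the concatenated dataset; returns micro-batch dicts."""

    def __len__(self) -> int:
        return 4

    def __getitem__(self, idx: int) -> int:
        return idx  # unused when __getitems__ is defined

    def __getitems__(self, indices: list[int]) -> list[dict]:
        return [{"index": indices, "_dataset_idx": 0}]

loader = DataLoader(ToyConcat(), batch_size=2, collate_fn=lambda x: x)
print(next(iter(loader)))  # [{'index': [0, 1], '_dataset_idx': 0}]
```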

```diff
@@ -387,9 +391,41 @@ def val_dataloader(self) -> ThreadDataLoader:
             batch_size=self.batch_size,
             shuffle=False,
             drop_last=False,
+            collate_fn=lambda x: x,
             **self._dataloader_kwargs(),
         )

+    def on_after_batch_transfer(self, batch, dataloader_idx: int):
+        """Apply GPU transforms from constituent data modules to micro-batches."""
+        processed_micro_batches = []
+        for micro_batch in batch:
+            dataset_idx = micro_batch.pop("_dataset_idx")
+            dm = self.data_modules[dataset_idx]
+            if hasattr(dm, "on_after_batch_transfer"):
+                processed_micro_batch = dm.on_after_batch_transfer(
+                    micro_batch, dataloader_idx
+                )
+            else:
+                processed_micro_batch = micro_batch
+            processed_micro_batches.append(processed_micro_batch)
+        combined_batch = {}
+        for key in processed_micro_batches[0].keys():
+            if isinstance(processed_micro_batches[0][key], list):
+                combined_batch[key] = []
+                for micro_batch in processed_micro_batches:
+                    if key in micro_batch:
+                        combined_batch[key].extend(micro_batch[key])
+            else:
+                tensors_to_concat = [
+                    micro_batch[key]
+                    for micro_batch in processed_micro_batches
+                    if key in micro_batch
+                ]
+                if tensors_to_concat:
+                    combined_batch[key] = torch.cat(tensors_to_concat, dim=0)
+
+        return combined_batch
+

 class CachedConcatDataModule(LightningDataModule):
     """Cached concatenated data module for distributed training.
```

viscy/data/hcs.py

Lines changed: 12 additions & 10 deletions
```diff
@@ -735,22 +735,24 @@ def _fit_transform(self) -> tuple[Compose, Compose]:
             Training and validation transform compositions
         """
         # TODO: These have a fixed order for now... ()
-        final_crop = [
-            CenterSpatialCropd(
-                keys=self.source_channel + self.target_channel,
-                roi_size=(
-                    self.z_window_size,
-                    self.yx_patch_size[0],
-                    self.yx_patch_size[1],
-                ),
-            )
-        ]
+        final_crop = [self._final_crop()]
         train_transform = Compose(
             self.normalizations + self._train_transform() + final_crop
         )
         val_transform = Compose(self.normalizations + final_crop)
         return train_transform, val_transform

+    def _final_crop(self) -> CenterSpatialCropd:
+        """Setup final cropping: center crop to the target size."""
+        return CenterSpatialCropd(
+            keys=self.source_channel + self.target_channel,
+            roi_size=(
+                self.z_window_size,
+                self.yx_patch_size[0],
+                self.yx_patch_size[1],
+            ),
+        )
+
     def _train_transform(self) -> list[Callable]:
         """Set up training augmentations.
```
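
This is a behavior-preserving refactor: the center crop is extracted into a `_final_crop()` helper shared by the train and validation compositions. A minimal sketch of the same MONAI transform applied to a toy sample (channel name and sizes are illustrative):

```python
import torch
from monai.transforms import CenterSpatialCropd

# Center-crop the (Z, Y, X) dims of every listed channel to the target size.
crop = CenterSpatialCropd(keys=["Phase3D"], roi_size=(5, 256, 256))
sample = {"Phase3D": torch.zeros(1, 9, 512, 512)}  # (C, Z, Y, X)
print(crop(sample)["Phase3D"].shape)  # torch.Size([1, 5, 256, 256])
```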
