Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion examples/configs/fit_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ data:
batch_size: 32
num_workers: 16
yx_patch_size: [256, 256]
pyramid_resolution: "0"
normalizations:
- class_path: viscy.transforms.NormalizeSampled
init_args:
Expand Down Expand Up @@ -87,4 +88,4 @@ data:
sigma_z: [0.25, 1.5]
sigma_y: [0.25, 1.5]
sigma_x: [0.25, 1.5]
caching: false
caching: false
1 change: 1 addition & 0 deletions examples/configs/predict_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -63,5 +63,6 @@ predict:
- 256
caching: false
predict_scale_source: null
pyramid_resolution: "0"
return_predictions: false
ckpt_path: null
1 change: 1 addition & 0 deletions examples/configs/test_example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,5 +62,6 @@ data:
- 256
caching: false
ground_truth_masks: null
pyramid_resolution: "0"
ckpt_path: null
verbose: true
13 changes: 11 additions & 2 deletions viscy/data/hcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,8 @@ class SlidingWindowDataset(Dataset):
:param ChannelMap channels: source and target channel names,
e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}``
:param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D
:param str pyramid_resolution: pyramid level.
defaults to 0 (full resolution)
:param DictTransform | None transform:
a callable that transforms data, defaults to None
"""
Expand All @@ -113,6 +115,7 @@ def __init__(
positions: list[Position],
channels: ChannelMap,
z_window_size: int,
pyramid_resolution: str = "0",
transform: DictTransform | None = None,
) -> None:
super().__init__()
Expand All @@ -128,6 +131,7 @@ def __init__(
)
self.z_window_size = z_window_size
self.transform = transform
self.pyramid_resolution = pyramid_resolution
self._get_windows()

def _get_windows(self) -> None:
Expand All @@ -138,7 +142,7 @@ def _get_windows(self) -> None:
self.window_arrays = []
self.window_norm_meta: list[NormMeta | None] = []
for fov in self.positions:
img_arr: ImageArray = fov["0"]
img_arr: ImageArray = fov[str(self.pyramid_resolution)]
ts = img_arr.frames
zs = img_arr.slices - self.z_window_size + 1
w += ts * zs
Expand Down Expand Up @@ -219,7 +223,7 @@ def __getitem__(self, index: int) -> Sample:
sample = {
"index": sample_index,
"source": self._stack_channels(sample_images, "source"),
"norm_meta": norm_meta,
# "norm_meta": norm_meta,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is this for?

}
if self.target_ch_idx is not None:
sample["target"] = self._stack_channels(sample_images, "target")
Expand Down Expand Up @@ -301,6 +305,8 @@ class HCSDataModule(LightningDataModule):
:param Path | None ground_truth_masks: path to the ground truth masks,
used in the test stage to compute segmentation metrics,
defaults to None
:param str pyramid_resolution: pyramid resolution level.
defaults to 0 (full resolution)
"""

def __init__(
Expand All @@ -318,6 +324,7 @@ def __init__(
augmentations: list[MapTransform] = [],
caching: bool = False,
ground_truth_masks: Path | None = None,
pyramid_resolution: str = "0",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#218 used array_key for this parameter, which is a more general naming:

array_key: str = "0",

):
super().__init__()
self.data_path = Path(data_path)
Expand All @@ -334,6 +341,7 @@ def __init__(
self.caching = caching
self.ground_truth_masks = ground_truth_masks
self.prepare_data_per_node = True
self.pyramid_resolution = pyramid_resolution

@property
def cache_path(self):
Expand Down Expand Up @@ -390,6 +398,7 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]:
return {
"channels": {"source": self.source_channel},
"z_window_size": self.z_window_size,
"pyramid_resolution": self.pyramid_resolution,
}

def setup(self, stage: Literal["fit", "validate", "test", "predict"]):
Expand Down