Skip to content

Commit b07f0ad

Browse files
committed
initial commit adding resolution
1 parent baa4ee3 commit b07f0ad

File tree

4 files changed

+16
-3
lines changed

4 files changed

+16
-3
lines changed

examples/configs/fit_example.yml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ data:
3737
batch_size: 32
3838
num_workers: 16
3939
yx_patch_size: [256, 256]
40+
pyramid_resolution: 0
4041
normalizations:
4142
- class_path: viscy.transforms.NormalizeSampled
4243
init_args:
@@ -87,4 +88,4 @@ data:
8788
sigma_z: [0.25, 1.5]
8889
sigma_y: [0.25, 1.5]
8990
sigma_x: [0.25, 1.5]
90-
caching: false
91+
caching: false

examples/configs/predict_example.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,5 +63,6 @@ predict:
6363
- 256
6464
caching: false
6565
predict_scale_source: null
66+
pyramid_resolution: 0
6667
return_predictions: false
6768
ckpt_path: null

examples/configs/test_example.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,5 +62,6 @@ data:
6262
- 256
6363
caching: false
6464
ground_truth_masks: null
65+
pyramid_resolution: 0
6566
ckpt_path: null
6667
verbose: true

viscy/data/hcs.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,8 @@ class SlidingWindowDataset(Dataset):
104104
:param ChannelMap channels: source and target channel names,
105105
e.g. ``{'source': 'Phase', 'target': ['Nuclei', 'Membrane']}``
106106
:param int z_window_size: Z window size of the 2.5D U-Net, 1 for 2D
107+
:param int pyramid_resolution: pyramid level.
108+
defaults to 0 (full resolution)
107109
:param DictTransform | None transform:
108110
a callable that transforms data, defaults to None
109111
"""
@@ -113,16 +115,21 @@ def __init__(
113115
positions: list[Position],
114116
channels: ChannelMap,
115117
z_window_size: int,
118+
pyramid_resolution: int = 0,
116119
transform: DictTransform | None = None,
117120
) -> None:
118121
super().__init__()
119122
self.positions = positions
120123
self.channels = {k: _ensure_channel_list(v) for k, v in channels.items()}
121124
self.source_ch_idx = [
122-
positions[0].get_channel_index(c) for c in channels["source"]
125+
positions[pyramid_resolution].get_channel_index(c)
126+
for c in channels["source"]
123127
]
124128
self.target_ch_idx = (
125-
[positions[0].get_channel_index(c) for c in channels["target"]]
129+
[
130+
positions[pyramid_resolution].get_channel_index(c)
131+
for c in channels["target"]
132+
]
126133
if "target" in channels
127134
else None
128135
)
@@ -301,6 +308,8 @@ class HCSDataModule(LightningDataModule):
301308
:param Path | None ground_truth_masks: path to the ground truth masks,
302309
used in the test stage to compute segmentation metrics,
303310
defaults to None
311+
:param int pyramid_resolution: pyramid resolution level.
312+
defaults to 0 (full resolution)
304313
"""
305314

306315
def __init__(
@@ -318,6 +327,7 @@ def __init__(
318327
augmentations: list[MapTransform] = [],
319328
caching: bool = False,
320329
ground_truth_masks: Path | None = None,
330+
pyramid_resolution: int = 0,
321331
):
322332
super().__init__()
323333
self.data_path = Path(data_path)

0 commit comments

Comments
 (0)