Skip to content

Commit 8a137d4

Browse files
committed
added the pyramid option to the slidingwindow
1 parent b07f0ad commit 8a137d4

File tree

1 file changed

+9
-10
lines changed

1 file changed

+9
-10
lines changed

viscy/data/hcs.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -115,26 +115,23 @@ def __init__(
115115
positions: list[Position],
116116
channels: ChannelMap,
117117
z_window_size: int,
118-
pyramid_resolution: int = 0,
118+
pyramid_resolution: str = "0",
119119
transform: DictTransform | None = None,
120120
) -> None:
121121
super().__init__()
122122
self.positions = positions
123123
self.channels = {k: _ensure_channel_list(v) for k, v in channels.items()}
124124
self.source_ch_idx = [
125-
positions[pyramid_resolution].get_channel_index(c)
126-
for c in channels["source"]
125+
positions[0].get_channel_index(c) for c in channels["source"]
127126
]
128127
self.target_ch_idx = (
129-
[
130-
positions[pyramid_resolution].get_channel_index(c)
131-
for c in channels["target"]
132-
]
128+
[positions[0].get_channel_index(c) for c in channels["target"]]
133129
if "target" in channels
134130
else None
135131
)
136132
self.z_window_size = z_window_size
137133
self.transform = transform
134+
self.pyramid_resolution = pyramid_resolution
138135
self._get_windows()
139136

140137
def _get_windows(self) -> None:
@@ -145,7 +142,7 @@ def _get_windows(self) -> None:
145142
self.window_arrays = []
146143
self.window_norm_meta: list[NormMeta | None] = []
147144
for fov in self.positions:
148-
img_arr: ImageArray = fov["0"]
145+
img_arr: ImageArray = fov[str(self.pyramid_resolution)]
149146
ts = img_arr.frames
150147
zs = img_arr.slices - self.z_window_size + 1
151148
w += ts * zs
@@ -226,7 +223,7 @@ def __getitem__(self, index: int) -> Sample:
226223
sample = {
227224
"index": sample_index,
228225
"source": self._stack_channels(sample_images, "source"),
229-
"norm_meta": norm_meta,
226+
# "norm_meta": norm_meta,
230227
}
231228
if self.target_ch_idx is not None:
232229
sample["target"] = self._stack_channels(sample_images, "target")
@@ -327,7 +324,7 @@ def __init__(
327324
augmentations: list[MapTransform] = [],
328325
caching: bool = False,
329326
ground_truth_masks: Path | None = None,
330-
pyramid_resolution: int = 0,
327+
pyramid_resolution: str = "0",
331328
):
332329
super().__init__()
333330
self.data_path = Path(data_path)
@@ -344,6 +341,7 @@ def __init__(
344341
self.caching = caching
345342
self.ground_truth_masks = ground_truth_masks
346343
self.prepare_data_per_node = True
344+
self.pyramid_resolution = pyramid_resolution
347345

348346
@property
349347
def cache_path(self):
@@ -400,6 +398,7 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]:
400398
return {
401399
"channels": {"source": self.source_channel},
402400
"z_window_size": self.z_window_size,
401+
"pyramid_resolution": self.pyramid_resolution,
403402
}
404403

405404
def setup(self, stage: Literal["fit", "validate", "test", "predict"]):

0 commit comments

Comments
 (0)