Skip to content

Commit ac7ac63

Browse files
authored
Expose Final Z cropping in TripletDataModule (#270)
* expose final z cropping * test final cropping * cleanup test case
1 parent 9acd61a commit ac7ac63

File tree

2 files changed

+45
-1
lines changed

2 files changed

+45
-1
lines changed

tests/data/test_triplet.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,3 +66,44 @@ def test_datamodule_setup_fit(
6666
z_window_size,
6767
*yx_patch_size,
6868
)
69+
70+
71+
@mark.parametrize("z_window_size", [None, 3])
72+
def test_datamodule_z_window_size(
73+
preprocessed_hcs_dataset, tracks_hcs_dataset, z_window_size
74+
):
75+
z_range = (4, 9)
76+
yx_patch_size = [32, 32]
77+
batch_size = 4
78+
with open_ome_zarr(preprocessed_hcs_dataset) as dataset:
79+
channel_names = dataset.channel_names
80+
dm = TripletDataModule(
81+
data_path=preprocessed_hcs_dataset,
82+
tracks_path=tracks_hcs_dataset,
83+
source_channel=channel_names,
84+
z_range=z_range,
85+
initial_yx_patch_size=(64, 64),
86+
final_yx_patch_size=(32, 32),
87+
num_workers=0,
88+
batch_size=batch_size,
89+
return_negative=True,
90+
z_window_size=z_window_size,
91+
)
92+
dm.setup(stage="fit")
93+
if z_window_size is None:
94+
expected_z_shape = z_range[1] - z_range[0]
95+
else:
96+
expected_z_shape = z_window_size
97+
for batch in dm.train_dataloader():
98+
assert batch["anchor"].shape == (
99+
batch_size,
100+
len(channel_names),
101+
expected_z_shape,
102+
*yx_patch_size,
103+
)
104+
assert batch["negative"].shape == (
105+
batch_size,
106+
len(channel_names),
107+
expected_z_shape,
108+
*yx_patch_size,
109+
)

viscy/data/triplet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,7 @@ def __init__(
321321
persistent_workers: bool = False,
322322
prefetch_factor: int | None = None,
323323
pin_memory: bool = False,
324+
z_window_size: int | None = None,
324325
):
325326
"""Lightning data module for triplet sampling of patches.
326327
@@ -374,12 +375,14 @@ def __init__(
374375
Number of batches loaded in advance by each worker, by default None
375376
pin_memory : bool, optional
376377
Whether to pin memory in CPU for faster GPU transfer, by default False
378+
z_window_size : int, optional
379+
Size of the final Z window, by default None (inferred from z_range)
377380
"""
378381
super().__init__(
379382
data_path=data_path,
380383
source_channel=source_channel,
381384
target_channel=[],
382-
z_window_size=z_range[1] - z_range[0],
385+
z_window_size=z_window_size or z_range[1] - z_range[0],
383386
split_ratio=split_ratio,
384387
batch_size=batch_size,
385388
num_workers=num_workers,

0 commit comments

Comments
 (0)