@@ -66,3 +66,44 @@ def test_datamodule_setup_fit(
66
66
z_window_size ,
67
67
* yx_patch_size ,
68
68
)
69
+
70
+
71
+ @mark .parametrize ("z_window_size" , [None , 3 ])
72
+ def test_datamodule_z_window_size (
73
+ preprocessed_hcs_dataset , tracks_hcs_dataset , z_window_size
74
+ ):
75
+ z_range = (4 , 9 )
76
+ yx_patch_size = [32 , 32 ]
77
+ batch_size = 4
78
+ with open_ome_zarr (preprocessed_hcs_dataset ) as dataset :
79
+ channel_names = dataset .channel_names
80
+ dm = TripletDataModule (
81
+ data_path = preprocessed_hcs_dataset ,
82
+ tracks_path = tracks_hcs_dataset ,
83
+ source_channel = channel_names ,
84
+ z_range = z_range ,
85
+ initial_yx_patch_size = (64 , 64 ),
86
+ final_yx_patch_size = (32 , 32 ),
87
+ num_workers = 0 ,
88
+ batch_size = batch_size ,
89
+ return_negative = True ,
90
+ z_window_size = z_window_size ,
91
+ )
92
+ dm .setup (stage = "fit" )
93
+ if z_window_size is None :
94
+ expected_z_shape = z_range [1 ] - z_range [0 ]
95
+ else :
96
+ expected_z_shape = z_window_size
97
+ for batch in dm .train_dataloader ():
98
+ assert batch ["anchor" ].shape == (
99
+ batch_size ,
100
+ len (channel_names ),
101
+ expected_z_shape ,
102
+ * yx_patch_size ,
103
+ )
104
+ assert batch ["negative" ].shape == (
105
+ batch_size ,
106
+ len (channel_names ),
107
+ expected_z_shape ,
108
+ * yx_patch_size ,
109
+ )
0 commit comments