@@ -115,26 +115,23 @@ def __init__(
115
115
positions : list [Position ],
116
116
channels : ChannelMap ,
117
117
z_window_size : int ,
118
- pyramid_resolution : int = 0 ,
118
+ pyramid_resolution : str = "0" ,
119
119
transform : DictTransform | None = None ,
120
120
) -> None :
121
121
super ().__init__ ()
122
122
self .positions = positions
123
123
self .channels = {k : _ensure_channel_list (v ) for k , v in channels .items ()}
124
124
self .source_ch_idx = [
125
- positions [pyramid_resolution ].get_channel_index (c )
126
- for c in channels ["source" ]
125
+ positions [0 ].get_channel_index (c ) for c in channels ["source" ]
127
126
]
128
127
self .target_ch_idx = (
129
- [
130
- positions [pyramid_resolution ].get_channel_index (c )
131
- for c in channels ["target" ]
132
- ]
128
+ [positions [0 ].get_channel_index (c ) for c in channels ["target" ]]
133
129
if "target" in channels
134
130
else None
135
131
)
136
132
self .z_window_size = z_window_size
137
133
self .transform = transform
134
+ self .pyramid_resolution = pyramid_resolution
138
135
self ._get_windows ()
139
136
140
137
def _get_windows (self ) -> None :
@@ -145,7 +142,7 @@ def _get_windows(self) -> None:
145
142
self .window_arrays = []
146
143
self .window_norm_meta : list [NormMeta | None ] = []
147
144
for fov in self .positions :
148
- img_arr : ImageArray = fov ["0" ]
145
+ img_arr : ImageArray = fov [str ( self . pyramid_resolution ) ]
149
146
ts = img_arr .frames
150
147
zs = img_arr .slices - self .z_window_size + 1
151
148
w += ts * zs
@@ -226,7 +223,7 @@ def __getitem__(self, index: int) -> Sample:
226
223
sample = {
227
224
"index" : sample_index ,
228
225
"source" : self ._stack_channels (sample_images , "source" ),
229
- "norm_meta" : norm_meta ,
226
+ # "norm_meta": norm_meta,
230
227
}
231
228
if self .target_ch_idx is not None :
232
229
sample ["target" ] = self ._stack_channels (sample_images , "target" )
@@ -327,7 +324,7 @@ def __init__(
327
324
augmentations : list [MapTransform ] = [],
328
325
caching : bool = False ,
329
326
ground_truth_masks : Path | None = None ,
330
- pyramid_resolution : int = 0 ,
327
+ pyramid_resolution : str = "0" ,
331
328
):
332
329
super ().__init__ ()
333
330
self .data_path = Path (data_path )
@@ -344,6 +341,7 @@ def __init__(
344
341
self .caching = caching
345
342
self .ground_truth_masks = ground_truth_masks
346
343
self .prepare_data_per_node = True
344
+ self .pyramid_resolution = pyramid_resolution
347
345
348
346
@property
349
347
def cache_path (self ):
@@ -400,6 +398,7 @@ def _base_dataset_settings(self) -> dict[str, dict[str, list[str]] | int]:
400
398
return {
401
399
"channels" : {"source" : self .source_channel },
402
400
"z_window_size" : self .z_window_size ,
401
+ "pyramid_resolution" : self .pyramid_resolution ,
403
402
}
404
403
405
404
def setup (self , stage : Literal ["fit" , "validate" , "test" , "predict" ]):
0 commit comments