66(chunked in one resp. two dimensions), with support for efficient storage and retrieval
77using the Zarr library.
88"""
9+ from __future__ import annotations
910
1011import logging
1112from abc import ABC , abstractmethod
12- from typing import Callable , Generator , Generic , List , Optional , Tuple , Union
13+ from typing import (
14+ Callable ,
15+ Generator ,
16+ Generic ,
17+ Iterator ,
18+ List ,
19+ Optional ,
20+ Tuple ,
21+ Union ,
22+ cast ,
23+ )
1324
1425import zarr
1526from numpy .typing import NDArray
27+ from tqdm import tqdm
1628from zarr .storage import StoreLike
1729
1830from ..utils import log_duration
@@ -35,9 +47,12 @@ def from_numpy(self, x: NDArray) -> TensorType:
3547
3648class SequenceAggregator (Generic [TensorType ], ABC ):
3749 @abstractmethod
38- def __call__ (self , tensor_generator : Generator [TensorType , None , None ]):
50+ def __call__ (
51+ self ,
52+ tensor_sequence : LazyChunkSequence ,
53+ ):
3954 """
40- Aggregates tensors from a generator .
55+ Aggregates tensors from a sequence .
4156
4257 Implement this method to define how a sequence of tensors, provided by a
4358 generator, should be combined.
@@ -46,31 +61,37 @@ def __call__(self, tensor_generator: Generator[TensorType, None, None]):
4661
4762class ListAggregator (SequenceAggregator ):
4863 def __call__ (
49- self , tensor_generator : Generator [TensorType , None , None ]
64+ self ,
65+ tensor_sequence : LazyChunkSequence ,
5066 ) -> List [TensorType ]:
5167 """
5268 Aggregates tensors from a single-level generator into a list. This method simply
5369 collects each tensor emitted by the generator into a single list.
5470
5571 Args:
56- tensor_generator: A generator that yields TensorType objects.
72+ tensor_sequence: Object wrapping a generator that yields `TensorType`
73+ objects.
5774
5875 Returns:
5976 A list containing all the tensors provided by the tensor_generator.
6077 """
61- return [t for t in tensor_generator ]
78+
79+ gen = cast (Iterator [TensorType ], tensor_sequence .generator_factory ())
80+
81+ if tensor_sequence .len_generator is not None :
82+ gen = cast (
83+ Iterator [TensorType ],
84+ tqdm (gen , total = tensor_sequence .len_generator , desc = "Blocks" ),
85+ )
86+
87+ return [t for t in gen ]
6288
6389
6490class NestedSequenceAggregator (Generic [TensorType ], ABC ):
6591 @abstractmethod
66- def __call__ (
67- self ,
68- nested_generators_of_tensors : Generator [
69- Generator [TensorType , None , None ], None , None
70- ],
71- ):
92+ def __call__ (self , nested_sequence_of_tensors : NestedLazyChunkSequence ):
7293 """
73- Aggregates tensors from a generator of generators .
94+ Aggregates tensors from a nested sequence of tensors .
7495
7596 Implement this method to specify how tensors, nested in two layers of
7697 generators, should be combined. Useful for complex data structures where tensors
@@ -81,27 +102,36 @@ def __call__(
81102class NestedListAggregator (NestedSequenceAggregator ):
82103 def __call__ (
83104 self ,
84- nested_generators_of_tensors : Generator [
85- Generator [TensorType , None , None ], None , None
86- ],
105+ nested_sequence_of_tensors : NestedLazyChunkSequence ,
87106 ) -> List [List [TensorType ]]:
88107 """
89108 Aggregates tensors from a nested generator structure into a list of lists.
90109 Each inner generator is converted into a list of tensors, resulting in a nested
91110 list structure.
92111
93112 Args:
94- nested_generators_of_tensors: A generator of generators, where each inner
95- generator yields TensorType objects.
113+ nested_sequence_of_tensors: Object wrapping a generator of generators,
114+ where each inner generator yields TensorType objects.
96115
97116 Returns:
98117 A list of lists, where each inner list contains tensors returned from one
99118 of the inner generators.
100119 """
101- return [list (tensor_gen ) for tensor_gen in nested_generators_of_tensors ]
120+ outer_gen = cast (
121+ Iterator [Iterator [TensorType ]],
122+ nested_sequence_of_tensors .generator_factory (),
123+ )
124+ len_outer_gen = nested_sequence_of_tensors .len_outer_generator
125+ if len_outer_gen is not None :
126+ outer_gen = cast (
127+ Iterator [Iterator [TensorType ]],
128+ tqdm (outer_gen , total = len_outer_gen , desc = "Row blocks" ),
129+ )
102130
131+ return [list (tensor_gen ) for tensor_gen in outer_gen ]
103132
104- class LazyChunkSequence :
133+
134+ class LazyChunkSequence (Generic [TensorType ]):
105135 """
106136 A class representing a chunked, and lazily evaluated array,
107137 where the chunking is restricted to the first dimension
@@ -114,12 +144,18 @@ class LazyChunkSequence:
114144 Attributes:
115145 generator_factory: A factory function that returns
116146 a generator. This generator yields chunks of the large array when called.
147+ len_generator: if the number of elements from the generator is
148+ known from the context, this optional parameter can be used to improve
149+ logging by adding a progressbar.
117150 """
118151
119152 def __init__ (
120- self , generator_factory : Callable [[], Generator [TensorType , None , None ]]
153+ self ,
154+ generator_factory : Callable [[], Generator [TensorType , None , None ]],
155+ len_generator : Optional [int ] = None ,
121156 ):
122157 self .generator_factory = generator_factory
158+ self .len_generator = len_generator
123159
124160 @log_duration (log_level = logging .INFO )
125161 def compute (self , aggregator : Optional [SequenceAggregator ] = None ):
@@ -140,7 +176,7 @@ def compute(self, aggregator: Optional[SequenceAggregator] = None):
140176 """
141177 if aggregator is None :
142178 aggregator = ListAggregator ()
143- return aggregator (self . generator_factory () )
179+ return aggregator (self )
144180
145181 @log_duration (log_level = logging .INFO )
146182 def to_zarr (
@@ -171,7 +207,15 @@ def to_zarr(
171207 """
172208 row_idx = 0
173209 z = None
174- for block in self .generator_factory ():
210+
211+ gen = cast (Iterator [TensorType ], self .generator_factory ())
212+
213+ if self .len_generator is not None :
214+ gen = cast (
215+ Iterator [TensorType ], tqdm (gen , total = self .len_generator , desc = "Blocks" )
216+ )
217+
218+ for block in gen :
175219 numpy_block = converter .to_numpy (block )
176220
177221 if z is None :
@@ -204,7 +248,7 @@ def _initialize_zarr_array(block: NDArray, path_or_url: str, overwrite: bool):
204248 )
205249
206250
207- class NestedLazyChunkSequence :
251+ class NestedLazyChunkSequence ( Generic [ TensorType ]) :
208252 """
209253 A class representing chunked, and lazily evaluated array, where the chunking is
210254 restricted to the first two dimensions.
@@ -216,16 +260,21 @@ class NestedLazyChunkSequence:
216260
217261 Attributes:
218262 generator_factory: A factory function that returns a generator of generators.
219- Each inner generator yields chunks.
263+ Each inner generator yields chunks
264+ len_outer_generator: if the number of elements from the outer generator is
265+ known from the context, this optional parameter can be used to improve
266+ logging by adding a progressbar.
220267 """
221268
222269 def __init__ (
223270 self ,
224271 generator_factory : Callable [
225272 [], Generator [Generator [TensorType , None , None ], None , None ]
226273 ],
274+ len_outer_generator : Optional [int ] = None ,
227275 ):
228276 self .generator_factory = generator_factory
277+ self .len_outer_generator = len_outer_generator
229278
230279 @log_duration (log_level = logging .INFO )
231280 def compute (self , aggregator : Optional [NestedSequenceAggregator ] = None ):
@@ -247,7 +296,7 @@ def compute(self, aggregator: Optional[NestedSequenceAggregator] = None):
247296 """
248297 if aggregator is None :
249298 aggregator = NestedListAggregator ()
250- return aggregator (self . generator_factory () )
299+ return aggregator (self )
251300
252301 @log_duration (log_level = logging .INFO )
253302 def to_zarr (
@@ -280,7 +329,17 @@ def to_zarr(
280329 row_idx = 0
281330 z = None
282331 numpy_block = None
283- for row_blocks in self .generator_factory ():
332+ block_generator = cast (Iterator [Iterator [TensorType ]], self .generator_factory ())
333+
334+ if self .len_outer_generator is not None :
335+ block_generator = cast (
336+ Iterator [Iterator [TensorType ]],
337+ tqdm (
338+ block_generator , total = self .len_outer_generator , desc = "Row blocks"
339+ ),
340+ )
341+
342+ for row_blocks in block_generator :
284343 col_idx = 0
285344 for block in row_blocks :
286345 numpy_block = converter .to_numpy (block )
0 commit comments