@@ -294,27 +294,47 @@ def _relative_path(self, path: str) -> str:
         return pathlib.Path(path).relative_to(
             self.partition_properties.dir).as_posix()
 
+    def _normalize_partitions(self,
+                              partitions: Iterable[str]) -> Iterable[str]:
+        """Normalize the provided partitions to their full paths.
+
+        Args:
+            partitions: The list of partitions to normalize.
+
+        Returns:
+            The normalized partition paths.
+        """
+        return filter(
+            self.fs.exists,
+            map(
+                lambda partition: self.fs.sep.join(
+                    (self.partition_properties.dir, partition)),
+                sorted(set(partitions))))
+
     def partitions(
         self,
         *,
-        cache: Iterable[str] | None = None,
-        lock: bool = False,
         filters: PartitionFilter = None,
+        indexer: Indexer | None = None,
+        selected_partitions: Iterable[str] | None = None,
         relative: bool = False,
+        lock: bool = False,
     ) -> Iterator[str]:
         """List the partitions of the collection.
 
         Args:
-            cache: The list of partitions to use. If None, the partitions are
-                listed.
-            lock: Whether to lock the collection or not to avoid listing
-                partitions while the collection is being modified.
             filters: The predicate used to filter the partitions to load. If
                 the predicate is a string, it is a valid python expression to
                 filter the partitions, using the partitioning scheme as
                 variables. If the predicate is a function, it is a function that
                 takes the partition scheme as input and returns a boolean.
+            indexer: The indexer to apply.
+            selected_partitions: A list of partitions to load (using the
+                partition relative path).
             relative: Whether to return the relative path.
+            lock: Whether to lock the collection or not to avoid listing
+                partitions while the collection is being modified.
 
         Returns:
             The list of partitions.
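Note: a hedged usage sketch of the revised `partitions()` keywords documented above; `collection` stands for an already opened collection object and the `year=.../month=...` partition names are hypothetical, chosen to match the docstring example of `load()` further down.

```python
# Assumes ``collection`` is an already opened collection; the partition
# relative paths below are hypothetical (``year=YYYY/month=MM`` layout).
relative_paths = list(
    collection.partitions(
        filters=lambda keys: keys["year"] == 2019,
        selected_partitions=["year=2019/month=03", "year=2019/month=04"],
        relative=True,
    ))
print(relative_paths)
```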
@@ -336,8 +356,9 @@ def partitions(
 
         base_dir: str = self.partition_properties.dir
         sep: str = self.fs.sep
-        if cache is not None:
-            partitions: Iterable[str] = cache
+        if selected_partitions is not None:
+            partitions: Iterable[str] = self._normalize_partitions(
+                partitions=selected_partitions)
         else:
             if lock:
                 with self.synchronizer:
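Note: the normalization invoked above boils down to joining the relative partition names onto the collection's base directory, deduplicating, sorting, and dropping paths that do not exist. A minimal standalone sketch of that pipeline, using `os.path` on the local filesystem instead of the collection's fsspec filesystem (all names are illustrative):

```python
import os
from typing import Iterable, Iterator


def normalize_partitions(base_dir: str,
                         partitions: Iterable[str]) -> Iterator[str]:
    """Yield the full paths of the given partitions that exist on disk."""
    return filter(
        os.path.exists,
        map(lambda partition: os.path.join(base_dir, partition),
            sorted(set(partitions))))


# Hypothetical layout: <base_dir>/year=2019/month=03, etc.
print(list(normalize_partitions("/tmp/collection",
                                ["year=2019/month=03",
                                 "year=2019/month=04"])))
```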
@@ -347,6 +368,17 @@ def partitions(
                 partitions = self.partitioning.list_partitions(
                     self.fs, base_dir)
 
+        if indexer is not None:
+            # Keep only the partitions present in both the indexer and the
+            # listed partitions.
+            partitions = list(partitions)
+            partitions = [
+                p for p in list_partitions_from_indexer(
+                    indexer=indexer,
+                    partition_handler=self.partitioning,
+                    base_dir=self.partition_properties.dir,
+                    sep=self.fs.sep) if p in partitions
+            ]
+
         yield from (self._relative_path(item) if relative else item
                     for item in partitions
                     if (item != self._immutable and self._is_selected(
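Note: the block added above keeps only the indexer-derived partitions that are also present in the listed partitions. A standalone illustration of that intersection; building a set for the membership test, as below, is a possible micro-optimisation over the `p in partitions` list scan (all names are illustrative):

```python
from typing import Iterable


def intersect_partitions(listed: Iterable[str],
                         from_indexer: Iterable[str]) -> list[str]:
    """Return the indexer-derived partitions that were actually listed."""
    listed_set = set(listed)
    return [path for path in from_indexer if path in listed_set]


listed = ["/data/year=2019/month=03", "/data/year=2019/month=04"]
from_indexer = ["/data/year=2019/month=03", "/data/year=2020/month=01"]
assert intersect_partitions(listed, from_indexer) == [
    "/data/year=2019/month=03"
]
```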
@@ -553,8 +585,11 @@ def load(
         filters: PartitionFilter = None,
         indexer: Indexer | None = None,
         selected_variables: Iterable[str] | None = None,
+        selected_partitions: Iterable[str] | None = None,
+        distributed: bool = True,
     ) -> dataset.Dataset | None:
-        """Load the selected partitions.
+        """Load the collection's data, respecting the filters, indexer, and
+        selected partitions constraints.
 
         Args:
             delayed: Whether to load data in a dask array or not.
@@ -564,6 +599,9 @@ def load(
             indexer: The indexer to apply.
             selected_variables: A list of variables to retain from the
                 collection. If None, all variables are kept.
+            selected_partitions: A list of partitions to load (using the
+                partition relative path).
+            distributed: Whether to use Dask or not. Defaults to True.
 
         Returns:
             The dataset containing the selected partitions, or None if no
@@ -582,46 +620,149 @@ def load(
             ...     filters=lambda keys: keys["year"] == 2019 and
             ...     keys["month"] == 3 and keys["day"] % 2 == 0)
         """
-        client: dask.distributed.Client = dask_utils.get_client()
-        arrays: list[dataset.Dataset]
+        # Delayed must be False when Dask is disabled.
+        if not distributed:
+            delayed = False
+
         if indexer is None:
-            selected_partitions = tuple(self.partitions(filters=filters))
-            if len(selected_partitions) == 0:
-                return None
+            arrays = self._load_partitions(
+                delayed=delayed,
+                filters=filters,
+                selected_variables=selected_variables,
+                selected_partitions=selected_partitions,
+                distributed=distributed)
+        else:
+            arrays = self._load_partitions_indexer(
+                indexer=indexer,
+                delayed=delayed,
+                filters=filters,
+                selected_variables=selected_variables,
+                selected_partitions=selected_partitions,
+                distributed=distributed)
+
+        if arrays is None:
+            return None
 
-            # No indexer, so the dataset is loaded directly for each
-            # selected partition.
+        array: dataset.Dataset = arrays.pop(0)
+        if arrays:
+            array = array.concat(arrays, self.partition_properties.dim)
+        if self._immutable:
+            array.merge(
+                storage.open_zarr_group(self._immutable,
+                                        self.fs,
+                                        delayed=delayed,
+                                        selected_variables=selected_variables))
+        array.fill_attrs(self.metadata)
+        return array
+
+    def _load_partitions(
+        self,
+        *,
+        delayed: bool = True,
+        filters: PartitionFilter = None,
+        selected_variables: Iterable[str] | None = None,
+        selected_partitions: Iterable[str] | None = None,
+        distributed: bool = True,
+    ) -> list[dataset.Dataset] | None:
+        """Load the collection's partitions, respecting the filters and
+        selected partitions constraints.
+
+        Args:
+            delayed: Whether to load data in a dask array or not.
+            filters: The predicate used to filter the partitions to load. To
+                get more information on the predicate, see the documentation
+                of the :meth:`partitions` method.
+            selected_variables: A list of variables to retain from the
+                collection. If None, all variables are kept.
+            selected_partitions: A list of partitions to load (using the
+                partition relative path).
+            distributed: Whether to use Dask or not. Defaults to True.
+
+        Returns:
+            The list of datasets, one per partition, or None if no partitions
+            were selected.
+        """
+        # No indexer, so the dataset is loaded directly for each
+        # selected partition.
+        selected_partitions = tuple(
+            self.partitions(filters=filters,
+                            selected_partitions=selected_partitions))
+
+        if len(selected_partitions) == 0:
+            return None
+
+        if distributed:
+            client = dask_utils.get_client()
             bag: dask.bag.core.Bag = dask.bag.core.from_sequence(
-                self.partitions(filters=filters),
+                selected_partitions,
                 npartitions=dask_utils.dask_workers(client, cores_only=True))
             arrays = bag.map(storage.open_zarr_group,
                              delayed=delayed,
                              fs=self.fs,
                              selected_variables=selected_variables).compute()
         else:
-            # We're going to reuse the indexer variable, so ensure it is
-            # an iterable not a generator.
-            indexer = tuple(indexer)
-
-            # Build the indexer arguments.
-            partitions = self.partitions(filters=filters,
-                                         cache=list_partitions_from_indexer(
-                                             indexer, self.partitioning,
-                                             self.partition_properties.dir,
-                                             self.fs.sep))
-            args = tuple(
-                build_indexer_args(self,
-                                   filters,
-                                   indexer,
-                                   partitions=partitions))
-            if len(args) == 0:
-                return None
+            arrays = [
+                storage.open_zarr_group(dirname=partition,
+                                        delayed=delayed,
+                                        fs=self.fs,
+                                        selected_variables=selected_variables)
+                for partition in selected_partitions
+            ]
+
+        return arrays
+
+    def _load_partitions_indexer(
+        self,
+        *,
+        indexer: Indexer,
+        delayed: bool = True,
+        filters: PartitionFilter = None,
+        selected_variables: Iterable[str] | None = None,
+        selected_partitions: Iterable[str] | None = None,
+        distributed: bool = True,
+    ) -> list[dataset.Dataset] | None:
+        """Load the collection's partitions, respecting the filters, indexer,
+        and selected partitions constraints.
+
+        Args:
+            indexer: The indexer to apply.
+            delayed: Whether to load data in a dask array or not.
+            filters: The predicate used to filter the partitions to load. To
+                get more information on the predicate, see the documentation
+                of the :meth:`partitions` method.
+            selected_variables: A list of variables to retain from the
+                collection. If None, all variables are kept.
+            selected_partitions: A list of partitions to load (using the
+                partition relative path).
+            distributed: Whether to use Dask or not. Defaults to True.
 
+        Returns:
+            The list of datasets, one per partition, or None if no partitions
+            were selected.
+        """
+        # We're going to reuse the indexer variable, so ensure it is
+        # an iterable, not a generator.
+        indexer = tuple(indexer)
+
+        # Build the indexer arguments.
+        partitions = self.partitions(selected_partitions=selected_partitions,
+                                     filters=filters,
+                                     indexer=indexer)
+        args = tuple(
+            build_indexer_args(collection=self,
+                               filters=filters,
+                               indexer=indexer,
+                               partitions=partitions))
+        if len(args) == 0:
+            return None
+
+        # Finally, load the selected partitions and apply the indexer.
+        if distributed:
+            client = dask_utils.get_client()
             bag = dask.bag.core.from_sequence(
                 args,
                 npartitions=dask_utils.dask_workers(client, cores_only=True))
 
-            # Finally, load the selected partitions and apply the indexer.
             arrays = list(
                 itertools.chain.from_iterable(
                     bag.map(
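Note: the `distributed` flag added in `_load_partitions` above selects between a dask bag and a plain list comprehension around the same per-partition loader. A standalone sketch of that dispatch pattern; `open_partition` is an illustrative stand-in for `storage.open_zarr_group`, while `dask.bag.from_sequence`, `map` and `compute` are standard Dask APIs (Dask must be installed):

```python
import dask.bag


def open_partition(path: str, *, delayed: bool = True) -> dict:
    """Illustrative stand-in for the real zarr-group loader."""
    return {"path": path, "delayed": delayed}


def load_partitions(paths: list[str],
                    *,
                    distributed: bool = True,
                    delayed: bool = True) -> list[dict]:
    # Same loader, two execution strategies: a dask bag or a plain loop.
    if distributed:
        bag = dask.bag.from_sequence(paths, npartitions=2)
        return bag.map(open_partition, delayed=delayed).compute()
    return [open_partition(path, delayed=delayed) for path in paths]


print(load_partitions(["year=2019/month=03", "year=2019/month=04"],
                      distributed=False,
                      delayed=False))
```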
@@ -632,18 +773,19 @@ def load(
                         partition_properties=self.partition_properties,
                         selected_variables=selected_variables,
                     ).compute()))
+        else:
+            arrays = list(
+                itertools.chain.from_iterable([
+                    _load_and_apply_indexer(
+                        args=a,
+                        delayed=delayed,
+                        fs=self.fs,
+                        partition_handler=self.partitioning,
+                        partition_properties=self.partition_properties,
+                        selected_variables=selected_variables) for a in args
+                ]))
 
-        array: dataset.Dataset = arrays.pop(0)
-        if arrays:
-            array = array.concat(arrays, self.partition_properties.dim)
-        if self._immutable:
-            array.merge(
-                storage.open_zarr_group(self._immutable,
-                                        self.fs,
-                                        delayed=delayed,
-                                        selected_variables=selected_variables))
-        array.fill_attrs(self.metadata)
-        return array
+        return arrays
 
     def _bag_from_partitions(
         self,
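Note: putting the pieces together, a hedged usage sketch of `load()` with the keywords introduced in this diff; `collection` stands for an already opened collection and the partition names assume a hypothetical `year=YYYY/month=MM` layout. Because `distributed=False` bypasses Dask, it also forces `delayed=False`, as implemented at the top of `load()`.

```python
# Assumes ``collection`` is an already opened collection object.
ds = collection.load(
    filters=lambda keys: keys["year"] == 2019,
    selected_partitions=["year=2019/month=03", "year=2019/month=04"],
    distributed=False,
)
if ds is None:
    print("no partition matched the selection")
```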