55This module contains commonly used functions.
66"""
77
8+ from __future__ import annotations
9+
810import operator
911import os
10- from collections .abc import Hashable , Iterable , Mapping , Sequence
12+ from collections .abc import Generator , Hashable , Iterable , Mapping , Sequence
1113from functools import reduce , wraps
1214from pathlib import Path
13- from typing import Any , Callable , Union , overload
15+ from typing import TYPE_CHECKING , Any , Callable , overload
1416from warnings import warn
1517
1618import numpy as np
3032 sign_replace_dict ,
3133)
3234
35+ if TYPE_CHECKING :
36+ from linopy .constraints import Constraint
37+ from linopy .expressions import LinearExpression
38+ from linopy .variables import Variable
39+
3340
3441def maybe_replace_sign (sign : str ) -> str :
3542 """
@@ -86,7 +93,7 @@ def format_string_as_variable_name(name: Hashable):
8693 return str (name ).replace (" " , "_" ).replace ("-" , "_" )
8794
8895
89- def get_from_iterable (lst : Union [ str , Iterable [Hashable ], None ] , index : int ):
96+ def get_from_iterable (lst : str | Iterable [Hashable ] | None , index : int ):
9097 """
9198 Returns the element at the specified index of the list, or None if the index
9299 is out of bounds.
@@ -99,9 +106,9 @@ def get_from_iterable(lst: Union[str, Iterable[Hashable], None], index: int):
99106
100107
101108def pandas_to_dataarray (
102- arr : Union [ pd .DataFrame , pd .Series ] ,
103- coords : Union [ Sequence [Union [ Sequence , pd .Index , DataArray ]], Mapping , None ] = None ,
104- dims : Union [ Iterable [Hashable ], None ] = None ,
109+ arr : pd .DataFrame | pd .Series ,
110+ coords : Sequence [Sequence | pd .Index | DataArray ] | Mapping | None = None ,
111+ dims : Iterable [Hashable ] | None = None ,
105112 ** kwargs ,
106113) -> DataArray :
107114 """
@@ -156,8 +163,8 @@ def pandas_to_dataarray(
156163
157164def numpy_to_dataarray (
158165 arr : np .ndarray ,
159- coords : Union [ Sequence [Union [ Sequence , pd .Index , DataArray ]], Mapping , None ] = None ,
160- dims : Union [ str , Iterable [Hashable ], None ] = None ,
166+ coords : Sequence [Sequence | pd .Index | DataArray ] | Mapping | None = None ,
167+ dims : str | Iterable [Hashable ] | None = None ,
161168 ** kwargs ,
162169) -> DataArray :
163170 """
@@ -195,8 +202,8 @@ def numpy_to_dataarray(
195202
196203def as_dataarray (
197204 arr ,
198- coords : Union [ Sequence [Union [ Sequence , pd .Index , DataArray ]], Mapping , None ] = None ,
199- dims : Union [ str , Iterable [Hashable ], None ] = None ,
205+ coords : Sequence [Sequence | pd .Index | DataArray ] | Mapping | None = None ,
206+ dims : str | Iterable [Hashable ] | None = None ,
200207 ** kwargs ,
201208) -> DataArray :
202209 """
@@ -246,7 +253,7 @@ def as_dataarray(
246253
247254
248255# TODO: rename to to_pandas_dataframe
249- def to_dataframe (ds : Dataset , mask_func : Union [ Callable , None ] = None ):
256+ def to_dataframe (ds : Dataset , mask_func : Callable | None = None ):
250257 """
251258 Convert an xarray Dataset to a pandas DataFrame.
252259
@@ -467,6 +474,65 @@ def fill_missing_coords(ds, fill_helper_dims: bool = False):
467474 return ds
468475
469476
477+ def iterate_slices (
478+ ds : Dataset | Variable | LinearExpression | Constraint ,
479+ slice_size : int | None = 10_000 ,
480+ slice_dims : list | None = None ,
481+ ) -> Generator [Dataset | Variable | LinearExpression | Constraint , None , None ]:
482+ """
483+ Generate slices of an xarray Dataset or DataArray with a specified soft maximum size.
484+
485+ The slicing is performed on the largest dimension of the input object.
486+ If the maximum size is larger than the total size of the object, the function yields
487+ the original object.
488+
489+ Parameters
490+ ----------
491+ ds : xarray.Dataset or xarray.DataArray
492+ The input xarray Dataset or DataArray to be sliced.
493+ slice_size : int
494+ The maximum number of elements in each slice. If the maximum size is too small to accommodate any slice,
495+ the function splits the largest dimension.
496+ slice_dims : list, optional
497+ The dimensions to slice along. If None, all dimensions in `coord_dims` are used if
498+ `coord_dims` is an attribute of the input object. Otherwise, all dimensions are used.
499+
500+ Yields
501+ ------
502+ xarray.Dataset or xarray.DataArray
503+ A slice of the input Dataset or DataArray.
504+
505+ """
506+ if slice_dims is None :
507+ slice_dims = list (getattr (ds , "coord_dims" , ds .dims ))
508+
509+ # Calculate the total number of elements in the dataset
510+ size = np .prod ([ds .sizes [dim ] for dim in ds .dims ], dtype = int )
511+
512+ if slice_size is None or size <= slice_size :
513+ yield ds
514+ return
515+
516+ # number of slices
517+ n_slices = max (size // slice_size , 1 )
518+
519+ # leading dimension (the dimension with the largest size)
520+ leading_dim = max (ds .sizes , key = ds .sizes .get ) # type: ignore
521+ size_of_leading_dim = ds .sizes [leading_dim ]
522+
523+ if size_of_leading_dim < n_slices :
524+ n_slices = size_of_leading_dim
525+
526+ chunk_size = ds .sizes [leading_dim ] // n_slices
527+
528+ # Iterate over the Cartesian product of slice indices
529+ for i in range (n_slices ):
530+ start = i * chunk_size
531+ end = start + chunk_size
532+ slice_dict = {leading_dim : slice (start , end )}
533+ yield ds .isel (slice_dict )
534+
535+
470536def _remap (array , mapping ):
471537 return mapping [array .ravel ()].reshape (array .shape )
472538
@@ -484,7 +550,7 @@ def replace_by_map(ds, mapping):
484550 )
485551
486552
487- def to_path (path : Union [ str , Path , None ] ) -> Union [ Path , None ] :
553+ def to_path (path : str | Path | None ) -> Path | None :
488554 """
489555 Convert a string to a Path object.
490556 """
@@ -526,7 +592,7 @@ def generate_indices_for_printout(dim_sizes, max_lines):
526592 yield tuple (np .unravel_index (i , dim_sizes ))
527593
528594
529- def align_lines_by_delimiter (lines : list [str ], delimiter : Union [ str , list [str ] ]):
595+ def align_lines_by_delimiter (lines : list [str ], delimiter : str | list [str ]):
530596 # Determine the maximum position of the delimiter
531597 if isinstance (delimiter , str ):
532598 delimiter = [delimiter ]
@@ -548,17 +614,18 @@ def align_lines_by_delimiter(lines: list[str], delimiter: Union[str, list[str]])
548614
549615
550616def get_label_position (
551- obj , values : Union [int , np .ndarray ]
552- ) -> Union [
553- Union [tuple [str , dict ], tuple [None , None ]],
554- list [Union [tuple [str , dict ], tuple [None , None ]]],
555- list [list [Union [tuple [str , dict ], tuple [None , None ]]]],
556- ]:
617+ obj , values : int | np .ndarray
618+ ) -> (
619+ tuple [str , dict ]
620+ | tuple [None , None ]
621+ | list [tuple [str , dict ] | tuple [None , None ]]
622+ | list [list [tuple [str , dict ] | tuple [None , None ]]]
623+ ):
557624 """
558625 Get tuple of name and coordinate for variable labels.
559626 """
560627
561- def find_single (value : int ) -> Union [ tuple [str , dict ], tuple [None , None ] ]:
628+ def find_single (value : int ) -> tuple [str , dict ] | tuple [None , None ]:
562629 if value == - 1 :
563630 return None , None
564631 for name , val in obj .items ():
0 commit comments