Skip to content

Commit 4dba46a

Browse files
committed
io: write out lp file with sliced variables and constraints
1 parent 03f3cc7 commit 4dba46a

File tree

12 files changed

+394
-162
lines changed

12 files changed

+394
-162
lines changed

doc/release_notes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Release Notes
44
Upcoming Version
55
----------------
66

7+
* When writing out an LP file, large variables and constraints are now chunked to avoid memory issues. This is especially useful for large models with constraints with many terms. The chunk size can be set with the `slice_size` argument in the `solve` function.
8+
79
Version 0.3.15
810
--------------
911

linopy/common.py

Lines changed: 87 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,14 @@
55
This module contains commonly used functions.
66
"""
77

8+
from __future__ import annotations
9+
810
import operator
911
import os
10-
from collections.abc import Hashable, Iterable, Mapping, Sequence
12+
from collections.abc import Generator, Hashable, Iterable, Mapping, Sequence
1113
from functools import reduce, wraps
1214
from pathlib import Path
13-
from typing import Any, Callable, Union, overload
15+
from typing import TYPE_CHECKING, Any, Callable, overload
1416
from warnings import warn
1517

1618
import numpy as np
@@ -30,6 +32,11 @@
3032
sign_replace_dict,
3133
)
3234

35+
if TYPE_CHECKING:
36+
from linopy.constraints import Constraint
37+
from linopy.expressions import LinearExpression
38+
from linopy.variables import Variable
39+
3340

3441
def maybe_replace_sign(sign: str) -> str:
3542
"""
@@ -86,7 +93,7 @@ def format_string_as_variable_name(name: Hashable):
8693
return str(name).replace(" ", "_").replace("-", "_")
8794

8895

89-
def get_from_iterable(lst: Union[str, Iterable[Hashable], None], index: int):
96+
def get_from_iterable(lst: str | Iterable[Hashable] | None, index: int):
9097
"""
9198
Returns the element at the specified index of the list, or None if the index
9299
is out of bounds.
@@ -99,9 +106,9 @@ def get_from_iterable(lst: Union[str, Iterable[Hashable], None], index: int):
99106

100107

101108
def pandas_to_dataarray(
102-
arr: Union[pd.DataFrame, pd.Series],
103-
coords: Union[Sequence[Union[Sequence, pd.Index, DataArray]], Mapping, None] = None,
104-
dims: Union[Iterable[Hashable], None] = None,
109+
arr: pd.DataFrame | pd.Series,
110+
coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
111+
dims: Iterable[Hashable] | None = None,
105112
**kwargs,
106113
) -> DataArray:
107114
"""
@@ -156,8 +163,8 @@ def pandas_to_dataarray(
156163

157164
def numpy_to_dataarray(
158165
arr: np.ndarray,
159-
coords: Union[Sequence[Union[Sequence, pd.Index, DataArray]], Mapping, None] = None,
160-
dims: Union[str, Iterable[Hashable], None] = None,
166+
coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
167+
dims: str | Iterable[Hashable] | None = None,
161168
**kwargs,
162169
) -> DataArray:
163170
"""
@@ -195,8 +202,8 @@ def numpy_to_dataarray(
195202

196203
def as_dataarray(
197204
arr,
198-
coords: Union[Sequence[Union[Sequence, pd.Index, DataArray]], Mapping, None] = None,
199-
dims: Union[str, Iterable[Hashable], None] = None,
205+
coords: Sequence[Sequence | pd.Index | DataArray] | Mapping | None = None,
206+
dims: str | Iterable[Hashable] | None = None,
200207
**kwargs,
201208
) -> DataArray:
202209
"""
@@ -246,7 +253,7 @@ def as_dataarray(
246253

247254

248255
# TODO: rename to to_pandas_dataframe
249-
def to_dataframe(ds: Dataset, mask_func: Union[Callable, None] = None):
256+
def to_dataframe(ds: Dataset, mask_func: Callable | None = None):
250257
"""
251258
Convert an xarray Dataset to a pandas DataFrame.
252259
@@ -467,6 +474,65 @@ def fill_missing_coords(ds, fill_helper_dims: bool = False):
467474
return ds
468475

469476

477+
def iterate_slices(
478+
ds: Dataset | Variable | LinearExpression | Constraint,
479+
slice_size: int | None = 10_000,
480+
slice_dims: list | None = None,
481+
) -> Generator[Dataset | Variable | LinearExpression | Constraint, None, None]:
482+
"""
483+
Generate slices of an xarray Dataset or DataArray with a specified soft maximum size.
484+
485+
The slicing is performed on the largest dimension of the input object.
486+
If the maximum size is larger than the total size of the object, the function yields
487+
the original object.
488+
489+
Parameters
490+
----------
491+
ds : xarray.Dataset or xarray.DataArray
492+
The input xarray Dataset or DataArray to be sliced.
493+
slice_size : int
494+
The maximum number of elements in each slice. If the maximum size is too small to accommodate any slice,
495+
the function splits the largest dimension.
496+
slice_dims : list, optional
497+
The dimensions to slice along. If None, all dimensions in `coord_dims` are used if
498+
`coord_dims` is an attribute of the input object. Otherwise, all dimensions are used.
499+
500+
Yields
501+
------
502+
xarray.Dataset or xarray.DataArray
503+
A slice of the input Dataset or DataArray.
504+
505+
"""
506+
if slice_dims is None:
507+
slice_dims = list(getattr(ds, "coord_dims", ds.dims))
508+
509+
# Calculate the total number of elements in the dataset
510+
size = np.prod([ds.sizes[dim] for dim in ds.dims], dtype=int)
511+
512+
if slice_size is None or size <= slice_size:
513+
yield ds
514+
return
515+
516+
# number of slices
517+
n_slices = max(size // slice_size, 1)
518+
519+
# leading dimension (the dimension with the largest size)
520+
leading_dim = max(ds.sizes, key=ds.sizes.get) # type: ignore
521+
size_of_leading_dim = ds.sizes[leading_dim]
522+
523+
if size_of_leading_dim < n_slices:
524+
n_slices = size_of_leading_dim
525+
526+
chunk_size = ds.sizes[leading_dim] // n_slices
527+
528+
# Iterate over the Cartesian product of slice indices
529+
for i in range(n_slices):
530+
start = i * chunk_size
531+
end = start + chunk_size
532+
slice_dict = {leading_dim: slice(start, end)}
533+
yield ds.isel(slice_dict)
534+
535+
470536
def _remap(array, mapping):
471537
return mapping[array.ravel()].reshape(array.shape)
472538

@@ -484,7 +550,7 @@ def replace_by_map(ds, mapping):
484550
)
485551

486552

487-
def to_path(path: Union[str, Path, None]) -> Union[Path, None]:
553+
def to_path(path: str | Path | None) -> Path | None:
488554
"""
489555
Convert a string to a Path object.
490556
"""
@@ -526,7 +592,7 @@ def generate_indices_for_printout(dim_sizes, max_lines):
526592
yield tuple(np.unravel_index(i, dim_sizes))
527593

528594

529-
def align_lines_by_delimiter(lines: list[str], delimiter: Union[str, list[str]]):
595+
def align_lines_by_delimiter(lines: list[str], delimiter: str | list[str]):
530596
# Determine the maximum position of the delimiter
531597
if isinstance(delimiter, str):
532598
delimiter = [delimiter]
@@ -548,17 +614,18 @@ def align_lines_by_delimiter(lines: list[str], delimiter: Union[str, list[str]])
548614

549615

550616
def get_label_position(
551-
obj, values: Union[int, np.ndarray]
552-
) -> Union[
553-
Union[tuple[str, dict], tuple[None, None]],
554-
list[Union[tuple[str, dict], tuple[None, None]]],
555-
list[list[Union[tuple[str, dict], tuple[None, None]]]],
556-
]:
617+
obj, values: int | np.ndarray
618+
) -> (
619+
tuple[str, dict]
620+
| tuple[None, None]
621+
| list[tuple[str, dict] | tuple[None, None]]
622+
| list[list[tuple[str, dict] | tuple[None, None]]]
623+
):
557624
"""
558625
Get tuple of name and coordinate for variable labels.
559626
"""
560627

561-
def find_single(value: int) -> Union[tuple[str, dict], tuple[None, None]]:
628+
def find_single(value: int) -> tuple[str, dict] | tuple[None, None]:
562629
if value == -1:
563630
return None, None
564631
for name, val in obj.items():

linopy/constraints.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
has_optimized_model,
4141
infer_schema_polars,
4242
is_constant,
43+
iterate_slices,
4344
maybe_replace_signs,
4445
print_coord,
4546
print_single_constraint,
@@ -658,6 +659,8 @@ def to_polars(self):
658659

659660
stack = conwrap(Dataset.stack)
660661

662+
iterate_slices = iterate_slices
663+
661664

662665
@dataclass(repr=False)
663666
class Constraints:

linopy/expressions.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
get_index_map,
5353
group_terms_polars,
5454
has_optimized_model,
55+
iterate_slices,
5556
print_single_expression,
5657
to_dataframe,
5758
to_polars,
@@ -1457,6 +1458,8 @@ def to_polars(self) -> pl.DataFrame:
14571458

14581459
stack = exprwrap(Dataset.stack)
14591460

1461+
iterate_slices = iterate_slices
1462+
14601463

14611464
class QuadraticExpression(LinearExpression):
14621465
"""

0 commit comments

Comments
 (0)