Skip to content

Commit 04c9e52

Browse files
fix: correctly pass slice_size argument (#379)
* fix: correctly pass `slice_size` argument * fix: correctly use slice dim selection in iterate slices
1 parent 0b2a208 commit 04c9e52

File tree

4 files changed

+30
-17
lines changed

4 files changed

+30
-17
lines changed

doc/release_notes.rst

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
11
Release Notes
22
=============
33

4-
.. Upcoming Version
5-
.. ----------------
4+
Upcoming Version
5+
----------------
6+
7+
* Fix the `slice_size` argument in the `solve` function. The argument was not properly passed to the `to_file` function.
8+
* Fix the slicing of constraints in case the term dimension is larger than the leading constraint coordinate dimension.
69

710
Version 0.4.0
811
--------------

linopy/common.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -506,6 +506,11 @@ def iterate_slices(
506506
if slice_dims is None:
507507
slice_dims = list(getattr(ds, "coord_dims", ds.dims))
508508

509+
if not set(slice_dims).issubset(ds.dims):
510+
raise ValueError(
511+
"Invalid slice dimensions. Must be a subset of the dataset dimensions."
512+
)
513+
509514
# Calculate the total number of elements in the dataset
510515
size = np.prod([ds.sizes[dim] for dim in ds.dims], dtype=int)
511516

@@ -517,7 +522,8 @@ def iterate_slices(
517522
n_slices = max(size // slice_size, 1)
518523

519524
# leading dimension (the dimension with the largest size)
520-
leading_dim = max(ds.sizes, key=ds.sizes.get) # type: ignore
525+
sizes = {dim: ds.sizes[dim] for dim in slice_dims}
526+
leading_dim = max(sizes, key=sizes.get) # type: ignore
521527
size_of_leading_dim = ds.sizes[leading_dim]
522528

523529
if size_of_leading_dim < n_slices:

linopy/model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1124,7 +1124,9 @@ def solve(
11241124
env=env,
11251125
)
11261126
else:
1127-
problem_fn = self.to_file(to_path(problem_fn), io_api)
1127+
problem_fn = self.to_file(
1128+
to_path(problem_fn), io_api, slice_size=slice_size
1129+
)
11281130
result = solver.solve_problem_from_file(
11291131
problem_fn=to_path(problem_fn),
11301132
solution_fn=to_path(solution_fn),

test/test_common.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -475,11 +475,11 @@ def test_iterate_slices_basic():
475475

476476
def test_iterate_slices_with_exclude_dims():
477477
ds = xr.Dataset(
478-
{"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002
479-
coords={"x": np.arange(10), "y": np.arange(10)},
478+
{"var": (("x", "y"), np.random.rand(10, 20))}, # noqa: NPY002
479+
coords={"x": np.arange(10), "y": np.arange(20)},
480480
)
481481
slices = list(iterate_slices(ds, slice_size=20, slice_dims=["x"]))
482-
assert len(slices) == 5
482+
assert len(slices) == 10
483483
for s in slices:
484484
assert isinstance(s, xr.Dataset)
485485
assert set(s.dims) == set(ds.dims)
@@ -499,11 +499,13 @@ def test_iterate_slices_large_max_size():
499499

500500
def test_iterate_slices_small_max_size():
501501
ds = xr.Dataset(
502-
{"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002
503-
coords={"x": np.arange(10), "y": np.arange(10)},
502+
{"var": (("x", "y"), np.random.rand(10, 20))}, # noqa: NPY002
503+
coords={"x": np.arange(10), "y": np.arange(20)},
504504
)
505-
slices = list(iterate_slices(ds, slice_size=8, slice_dims=[]))
506-
assert len(slices) == 10
505+
slices = list(iterate_slices(ds, slice_size=8, slice_dims=["x"]))
506+
assert (
507+
len(slices) == 10
508+
) # goes to the smallest slice possible which is 1 for the x dimension
507509
for s in slices:
508510
assert isinstance(s, xr.Dataset)
509511
assert set(s.dims) == set(ds.dims)
@@ -520,16 +522,16 @@ def test_iterate_slices_slice_size_none():
520522
assert ds.equals(s)
521523

522524

523-
def test_iterate_slices_no_slice_dims():
525+
def test_iterate_slices_invalid_slice_dims():
524526
ds = xr.Dataset(
525527
{"var": (("x", "y"), np.random.rand(10, 10))}, # noqa: NPY002
526528
coords={"x": np.arange(10), "y": np.arange(10)},
527529
)
528-
slices = list(iterate_slices(ds, slice_size=50, slice_dims=[]))
529-
assert len(slices) == 2
530-
for s in slices:
531-
assert isinstance(s, xr.Dataset)
532-
assert set(s.dims) == set(ds.dims)
530+
with pytest.raises(ValueError):
531+
list(iterate_slices(ds, slice_size=50, slice_dims=[]))
532+
533+
with pytest.raises(ValueError):
534+
list(iterate_slices(ds, slice_size=50, slice_dims=["z"]))
533535

534536

535537
def test_get_dims_with_index_levels():

0 commit comments

Comments
 (0)