Skip to content

Commit c6c64a7

Browse files
FBumannpre-commit-ci[bot]FabianHofmann
authored
Speed up printing (#526)
* Add printing benchmark * Improve benchmark and implement LabelPositionIndex * Improve benchmark and implement LabelPositionIndex * Add test * Split benchmark script into separate files * Update benchmark_printing.py * Update benchmark_printing.py * More improvements * Improve benchmark * Optimize nvars and ncons * Rename methods and only keep old one for benchmarking * Update release notes for printing optimization * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add summary benchmark_scaling.py * Update release_notes.rst * Remove dev-scripts * Remove dev-scripts * Fix mypy type errors in LabelPositionIndex * Add type annotations to test_label_position.py --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Fabian <[email protected]>
1 parent b3a4bfd commit c6c64a7

File tree

5 files changed

+622
-11
lines changed

5 files changed

+622
-11
lines changed

doc/release_notes.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ Release Notes
44
.. Upcoming Version
55
66
* Fix compatibility for xpress versions below 9.6 (regression)
7+
* Performance: Up to 50x faster ``repr()`` for variables/constraints via O(log n) label lookup and direct numpy indexing
8+
* Performance: Up to 46x faster ``ncons`` property by replacing ``.flat.labels.unique()`` with direct counting
79

810
Version 0.5.8
911
--------------

linopy/common.py

Lines changed: 169 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -750,7 +750,118 @@ def get_dims_with_index_levels(
750750
return dims_with_levels
751751

752752

753-
def get_label_position(
753+
class LabelPositionIndex:
    """
    Index for fast O(log n) lookup of label positions using binary search.

    The start label of every non-empty item of the container is kept in a
    sorted array; ``np.searchsorted`` then finds the item (variable or
    constraint) whose label range contains a given label. The index is built
    lazily on the first lookup and has to be reset with ``invalidate()``
    whenever items are added to or removed from the container.

    Parameters
    ----------
    obj : Any
        Container object with items() method returning (name, val) pairs,
        where val has .labels and .range attributes. The container must also
        support ``obj[name]`` item access.
    """

    __slots__ = ("_starts", "_names", "_obj", "_built")

    def __init__(self, obj: Any) -> None:
        self._obj = obj
        self._starts: np.ndarray | None = None
        self._names: list[str] | None = None
        self._built = False

    def _build_index(self) -> None:
        """Build the sorted index of label ranges (no-op if already built)."""
        if self._built:
            return

        ranges = []
        for name, val in self._obj.items():
            start, stop = val.range
            # An empty range (start == stop) owns no label. Indexing it could
            # shadow a non-empty range that shares the same start value and
            # turn a valid lookup into a spurious ValueError.
            if start < stop:
                ranges.append((start, name))

        # Sort by start value so that searchsorted can be applied.
        ranges.sort(key=lambda item: item[0])
        # Explicit integer dtype: an empty list would otherwise become a
        # float64 array.
        self._starts = np.array([start for start, _ in ranges], dtype=np.int64)
        self._names = [name for _, name in ranges]
        self._built = True

    def invalidate(self) -> None:
        """Invalidate the index (call when items are added/removed)."""
        self._built = False
        self._starts = None
        self._names = None

    def _locate(self, value: int) -> tuple[str, dict, tuple[int, ...]]:
        """
        Resolve a label other than -1 to (name, coord, index).

        Shared implementation behind ``find_single`` and
        ``find_single_with_index``.

        Raises
        ------
        ValueError
            If no item of the container has a range containing ``value``.
        """
        self._build_index()
        starts = self._starts
        names = self._names
        # Narrowing for type checkers; _build_index always sets both.
        assert starts is not None and names is not None

        # Position of the last start that is <= value. The result of
        # searchsorted is at most len(starts), so only the lower bound
        # needs to be checked.
        idx = int(np.searchsorted(starts, value, side="right")) - 1
        if idx < 0:
            raise ValueError(f"Label {value} is not existent in the model.")

        name = names[idx]
        val = self._obj[name]
        start, stop = val.range

        # The candidate is only the closest start; the label may still lie
        # in a gap behind the end of that range.
        if value < start or value >= stop:
            raise ValueError(f"Label {value} is not existent in the model.")

        labels = val.labels
        # Offset within the item -> multi-dimensional position in the labels.
        index = np.unravel_index(value - start, labels.shape)
        coord = {dim: labels.indexes[dim][i] for dim, i in zip(labels.dims, index)}
        return name, coord, index

    def find_single(self, value: int) -> tuple[str, dict] | tuple[None, None]:
        """
        Find the name and coordinates for a single label value.

        Returns (None, None) for the label -1 and raises ValueError if the
        label is not part of any item.
        """
        if value == -1:
            return None, None
        name, coord, _ = self._locate(value)
        return name, coord

    def find_single_with_index(
        self, value: int
    ) -> tuple[str, dict, tuple[int, ...]] | tuple[None, None, None]:
        """
        Find name, coordinates, and raw numpy index for a single label value.

        Returns (name, coord, index) where index is a tuple of integers that
        can be used for direct numpy indexing (e.g., arr.values[index]).
        This avoids the overhead of xarray's .sel() method.

        Returns (None, None, None) for the label -1 and raises ValueError if
        the label is not part of any item.
        """
        if value == -1:
            return None, None, None
        return self._locate(value)
862+
863+
864+
def _get_label_position_linear(
754865
obj: Any, values: int | np.ndarray
755866
) -> (
756867
tuple[str, dict]
@@ -760,6 +871,9 @@ def get_label_position(
760871
):
761872
"""
762873
Get tuple of name and coordinate for variable labels.
874+
875+
This is the original O(n) implementation that scans through all items.
876+
Used only for testing/benchmarking comparisons.
763877
"""
764878

765879
def find_single(value: int) -> tuple[str, dict] | tuple[None, None]:
@@ -795,6 +909,53 @@ def find_single(value: int) -> tuple[str, dict] | tuple[None, None]:
795909
raise ValueError("Array's with more than two dimensions is not supported")
796910

797911

912+
def get_label_position(
    obj: Any,
    values: int | np.ndarray,
    index: LabelPositionIndex | None = None,
) -> (
    tuple[str, dict]
    | tuple[None, None]
    | list[tuple[str, dict] | tuple[None, None]]
    | list[list[tuple[str, dict] | tuple[None, None]]]
):
    """
    Get tuple of name and coordinate for variable labels.

    Uses O(log n) binary search with a cached index for fast lookups.

    Parameters
    ----------
    obj : Any
        Container object with items() method (Variables or Constraints).
    values : int or np.ndarray
        Label value(s) to look up.
    index : LabelPositionIndex, optional
        Pre-built index for fast lookups. If None, one will be created for
        this call only.

    Returns
    -------
    tuple or list
        (name, coord) tuple for single values, a list of tuples for
        one-dimensional input and a list of lists for two-dimensional input.
        For two-dimensional input the outer list runs over the columns of
        ``values`` (the array is transposed before iterating).

    Raises
    ------
    ValueError
        If ``values`` has more than two dimensions.
    """
    if index is None:
        index = LabelPositionIndex(obj)
    # Bind the lookup once instead of resolving the attribute per element.
    find = index.find_single

    if isinstance(values, int):
        return find(values)

    # asarray does not copy input that already is an ndarray.
    values = np.asarray(values)
    ndim = values.ndim
    if ndim == 0:
        return find(values.item())
    if ndim == 1:
        return [find(int(v)) for v in values]
    if ndim == 2:
        return [[find(int(v)) for v in col] for col in values.T]
    raise ValueError("Array's with more than two dimensions is not supported")
957+
958+
798959
def print_coord(coord: dict[str, Any] | Iterable[Any]) -> str:
799960
"""
800961
Format coordinates into a string representation.
@@ -838,14 +999,16 @@ def print_single_variable(model: Any, label: int) -> str:
838999
return "None"
8391000

8401001
variables = model.variables
841-
name, coord = variables.get_label_position(label)
1002+
name, coord, index = variables.get_label_position_with_index(label)
8421003

843-
lower = variables[name].lower.sel(coord).item()
844-
upper = variables[name].upper.sel(coord).item()
1004+
var = variables[name]
1005+
# Use direct numpy indexing instead of .sel() for performance
1006+
lower = var.lower.values[index]
1007+
upper = var.upper.values[index]
8451008

846-
if variables[name].attrs["binary"]:
1009+
if var.attrs["binary"]:
8471010
bounds = " ∈ {0, 1}"
848-
elif variables[name].attrs["integer"]:
1011+
elif var.attrs["integer"]:
8491012
bounds = f" ∈ Z ⋂ [{lower:.4g},...,{upper:.4g}]"
8501013
else:
8511014
bounds = f" ∈ [{lower:.4g}, {upper:.4g}]"

linopy/constraints.py

Lines changed: 43 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
from linopy import expressions, variables
3131
from linopy.common import (
32+
LabelPositionIndex,
3233
LocIndexer,
3334
align_lines_by_delimiter,
3435
assign_multiindex_safe,
@@ -696,6 +697,7 @@ class Constraints:
696697

697698
data: dict[str, Constraint]
698699
model: Model
700+
_label_position_index: LabelPositionIndex | None = None
699701

700702
dataset_attrs = ["labels", "coeffs", "vars", "sign", "rhs"]
701703
dataset_names = [
@@ -792,12 +794,19 @@ def add(self, constraint: Constraint) -> None:
792794
Add a constraint to the constraints container.
793795
"""
794796
self.data[constraint.name] = constraint
797+
self._invalidate_label_position_index()
795798

796799
def remove(self, name: str) -> None:
    """
    Drop the constraint called `name` from the container.

    A KeyError is raised if no constraint of that name is stored. After a
    successful removal the cached label position index is invalidated.
    """
    del self.data[name]
    self._invalidate_label_position_index()
805+
806+
def _invalidate_label_position_index(self) -> None:
807+
"""Invalidate the label position index cache."""
808+
if self._label_position_index is not None:
809+
self._label_position_index.invalidate()
801810

802811
@property
803812
def labels(self) -> Dataset:
@@ -869,9 +878,36 @@ def ncons(self) -> int:
869878
"""
870879
Get the number of all constraints effectively used by the model.
871880
872-
These excludes constraints with missing labels.
881+
This excludes constraints with missing labels or where all variables
882+
are masked (vars == -1).
873883
"""
874-
return len(self.flat.labels.unique())
884+
total = 0
885+
for con in self.data.values():
886+
labels = con.labels.values
887+
vars_arr = con.vars.values
888+
889+
# Handle scalar constraint (single constraint, labels is 0-d)
890+
if labels.ndim == 0:
891+
# Scalar: valid if label != -1 and any var != -1
892+
if labels != -1 and (vars_arr != -1).any():
893+
total += 1
894+
continue
895+
896+
# Array constraint: labels has constraint dimensions, vars has
897+
# constraint dimensions + _term dimension
898+
valid_labels = labels != -1
899+
900+
# Check if any variable in each constraint is valid (not -1)
901+
# vars has shape (..., n_terms) where ... matches labels shape
902+
has_valid_var = (vars_arr != -1).any(axis=-1)
903+
904+
active = valid_labels & has_valid_var
905+
906+
if con.mask is not None:
907+
active = active & con.mask.values
908+
909+
total += int(active.sum())
910+
return total
875911

876912
@property
877913
def inequalities(self) -> Constraints:
@@ -957,8 +993,12 @@ def get_label_position(
957993
):
958994
"""
959995
Get tuple of name and coordinate for constraint labels.
996+
997+
Uses an optimized O(log n) binary search implementation with a cached index.
960998
"""
961-
return get_label_position(self, values)
999+
if self._label_position_index is None:
1000+
self._label_position_index = LabelPositionIndex(self)
1001+
return get_label_position(self, values, self._label_position_index)
9621002

9631003
def print_labels(
9641004
self, values: Sequence[int], display_max_terms: int | None = None

linopy/variables.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030

3131
import linopy.expressions as expressions
3232
from linopy.common import (
33+
LabelPositionIndex,
3334
LocIndexer,
3435
as_dataarray,
3536
assign_multiindex_safe,
@@ -1166,6 +1167,7 @@ class Variables:
11661167

11671168
data: dict[str, Variable]
11681169
model: Model
1170+
_label_position_index: LabelPositionIndex | None = None
11691171

11701172
dataset_attrs = ["labels", "lower", "upper"]
11711173
dataset_names = ["Labels", "Lower bounds", "Upper bounds"]
@@ -1256,12 +1258,19 @@ def add(self, variable: Variable) -> None:
12561258
Add a variable to the variables container.
12571259
"""
12581260
self.data[variable.name] = variable
1261+
self._invalidate_label_position_index()
12591262

12601263
def remove(self, name: str) -> None:
    """
    Drop the variable called `name` from the container.

    A KeyError is raised if no variable of that name is stored. After a
    successful removal the cached label position index is invalidated.
    """
    del self.data[name]
    self._invalidate_label_position_index()
1269+
1270+
def _invalidate_label_position_index(self) -> None:
1271+
"""Invalidate the label position index cache."""
1272+
if self._label_position_index is not None:
1273+
self._label_position_index.invalidate()
12651274

12661275
@property
12671276
def attrs(self) -> dict[Any, Any]:
@@ -1321,7 +1330,14 @@ def nvars(self) -> int:
13211330
13221331
This excludes variables with missing labels.
13231332
"""
1324-
return len(self.flat.labels.unique())
1333+
total = 0
1334+
for var in self.data.values():
1335+
labels = var.labels.values
1336+
if var.mask is not None:
1337+
total += int((labels[var.mask.values] != -1).sum())
1338+
else:
1339+
total += int((labels != -1).sum())
1340+
return total
13251341

13261342
@property
13271343
def binaries(self) -> Variables:
@@ -1418,8 +1434,36 @@ def get_label_range(self, name: str) -> tuple[int, int]:
14181434
def get_label_position(self, values: int | ndarray) -> Any:
    """
    Get tuple of name and coordinate for variable labels.

    The lookup is delegated to the module-level ``get_label_position``
    together with a ``LabelPositionIndex`` that is created on first use and
    kept on the container for later calls.
    """
    index = self._label_position_index
    if index is None:
        # First lookup: create the index once and keep it for later calls.
        index = self._label_position_index = LabelPositionIndex(self)
    return get_label_position(self, values, index)
1443+
1444+
def get_label_position_with_index(
    self, label: int
) -> tuple[str, dict, tuple[int, ...]] | tuple[None, None, None]:
    """
    Get name, coordinate, and raw numpy index for a single variable label.

    Compared to ``get_label_position`` the result additionally carries the
    raw index, which allows direct numpy array access and avoids xarray's
    .sel() overhead.

    Parameters
    ----------
    label : int
        The variable label to look up.

    Returns
    -------
    tuple
        (name, coord, index) where index is a tuple for numpy indexing,
        or (None, None, None) if label is -1.
    """
    if (index := self._label_position_index) is None:
        # First lookup: create the index once and keep it for later calls.
        index = self._label_position_index = LabelPositionIndex(self)
    return index.find_single_with_index(label)
14231467

14241468
def print_labels(self, values: list[int]) -> None:
14251469
"""

0 commit comments

Comments
 (0)