Skip to content

Commit 2fadaf7

Browse files
sanitize infinity values: fix multiindex assignment (#371)
* constraints: fix multiindex assignment * remove duplication
1 parent 3d0275d commit 2fadaf7

File tree

1 file changed

+12
-13
lines changed

1 file changed

+12
-13
lines changed

linopy/constraints.py

Lines changed: 12 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -864,33 +864,32 @@ def sanitize_zeros(self) -> None:
864864
"""
865865
for name in self:
866866
not_zero = abs(self[name].coeffs) > 1e-10
867-
constraint = self[name]
868-
constraint.vars = self[name].vars.where(not_zero, -1)
869-
constraint.coeffs = self[name].coeffs.where(not_zero)
867+
con = self[name]
868+
con.vars = self[name].vars.where(not_zero, -1)
869+
con.coeffs = self[name].coeffs.where(not_zero)
870870

871871
def sanitize_missings(self) -> None:
872872
"""
873873
Set constraints labels to -1 where all variables in the lhs are
874874
missing.
875875
"""
876876
for name in self:
877-
contains_non_missing = (self[name].vars != -1).any(self[name].term_dim)
878-
self[name].data["labels"] = self[name].labels.where(
879-
contains_non_missing, -1
880-
)
877+
con = self[name]
878+
contains_non_missing = (con.vars != -1).any(con.term_dim)
879+
labels = self[name].labels.where(contains_non_missing, -1)
880+
con._data = assign_multiindex_safe(con.data, labels=labels)
881881

882882
def sanitize_infinities(self) -> None:
883883
"""
884884
Replace infinite values in the constraints with a large value.
885885
"""
886886
for name in self:
887-
constraint = self[name]
888-
valid_infinity_values = (
889-
(constraint.sign == LESS_EQUAL) & (constraint.rhs == np.inf)
890-
) | ((constraint.sign == GREATER_EQUAL) & (constraint.rhs == -np.inf))
891-
self[name].data["labels"] = self[name].labels.where(
892-
~valid_infinity_values, -1
887+
con = self[name]
888+
valid_infinity_values = ((con.sign == LESS_EQUAL) & (con.rhs == np.inf)) | (
889+
(con.sign == GREATER_EQUAL) & (con.rhs == -np.inf)
893890
)
891+
labels = con.labels.where(~valid_infinity_values, -1)
892+
con._data = assign_multiindex_safe(con.data, labels=labels)
894893

895894
def get_name_by_label(self, label: Union[int, float]) -> str:
896895
"""

0 commit comments

Comments
 (0)