Skip to content

Commit 221840b

Browse files
committed
fix: lp-polars api
1 parent e01cfed commit 221840b

File tree

1 file changed

+63
-15
lines changed

1 file changed

+63
-15
lines changed

linopy/io.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
concat_kwargs = dict(dim=CONCAT_DIM, coords="minimal")
4141

4242
TQDM_COLOR = "#80bfff"
43+
COEFF_THRESHOLD = 1e-12
4344

4445

4546
def handle_batch(batch: list[str], f: TextIOWrapper, batch_size: int) -> list[str]:
@@ -550,8 +551,10 @@ def objective_to_file_polars(
550551
f.write(f"{sense}\n\nobj:\n\n".encode())
551552
df = m.objective.to_polars()
552553

554+
# Filter out zero coefficients like the regular LP version does
553555
if m.is_linear:
554-
objective_write_linear_terms_polars(f, df, print_variable)
556+
df_filtered = df.filter(pl.col("coeffs").abs() > COEFF_THRESHOLD)
557+
objective_write_linear_terms_polars(f, df_filtered, print_variable)
555558

556559
elif m.is_quadratic:
557560
linear_terms = df.filter(pl.col("vars1").eq(-1) | pl.col("vars2").eq(-1))
@@ -561,9 +564,13 @@ def objective_to_file_polars(
561564
.otherwise(pl.col("vars1"))
562565
.alias("vars")
563566
)
567+
# Filter out zero coefficients
568+
linear_terms = linear_terms.filter(pl.col("coeffs").abs() > COEFF_THRESHOLD)
564569
objective_write_linear_terms_polars(f, linear_terms, print_variable)
565570

566571
quads = df.filter(pl.col("vars1").ne(-1) & pl.col("vars2").ne(-1))
572+
# Filter out zero coefficients
573+
quads = quads.filter(pl.col("coeffs").abs() > COEFF_THRESHOLD)
567574
objective_write_quadratic_terms_polars(f, quads, print_variable)
568575

569576

@@ -731,28 +738,69 @@ def constraints_to_file_polars(
731738
for con_slice in con.iterate_slices(slice_size):
732739
df = con_slice.to_polars()
733740

734-
# df = df.lazy()
735-
# filter out repeated label values
736-
df = df.with_columns(
737-
pl.when(pl.col("labels").is_first_distinct())
738-
.then(pl.col("labels"))
739-
.otherwise(pl.lit(None))
740-
.alias("labels")
741+
# Filter out rows with zero coefficients or invalid variables - but KEEP RHS rows
742+
# RHS rows have null coeffs/vars but contain the constraint sign/rhs
743+
df = df.filter(
744+
# Keep RHS rows (have sign/rhs but null coeffs/vars)
745+
(pl.col("sign").is_not_null() & pl.col("rhs").is_not_null())
746+
|
747+
# OR keep valid coefficient rows
748+
(
749+
(pl.col("coeffs").abs() > COEFF_THRESHOLD)
750+
& (pl.col("vars").is_not_null())
751+
& (pl.col("vars") >= 0)
752+
)
753+
)
754+
755+
if df.height == 0:
756+
continue
757+
758+
# Ensure each constraint has both coefficient and RHS terms
759+
analysis = df.group_by("labels").agg(
760+
[
761+
pl.col("coeffs").is_not_null().sum().alias("coeff_rows"),
762+
pl.col("sign").is_not_null().sum().alias("rhs_rows"),
763+
]
764+
)
765+
766+
valid = analysis.filter(
767+
(pl.col("coeff_rows") > 0) & (pl.col("rhs_rows") > 0)
768+
)
769+
770+
if valid.height == 0:
771+
continue
772+
773+
df = df.join(valid.select("labels"), on="labels", how="inner")
774+
775+
# Sort by labels for proper grouping and mark first/last occurrences
776+
df = df.sort("labels").with_columns(
777+
[
778+
pl.when(pl.col("labels").is_first_distinct())
779+
.then(pl.col("labels"))
780+
.otherwise(pl.lit(None))
781+
.alias("labels_first"),
782+
(pl.col("labels") != pl.col("labels").shift(-1))
783+
.fill_null(True)
784+
.alias("is_last_in_group"),
785+
]
741786
)
742787

743-
row_labels = print_constraint(pl.col("labels"))
788+
# Build output columns
789+
row_labels = print_constraint(pl.col("labels_first"))
744790
col_labels = print_variable(pl.col("vars"))
745791
columns = [
746-
pl.when(pl.col("labels").is_not_null()).then(row_labels[0]),
747-
pl.when(pl.col("labels").is_not_null()).then(row_labels[1]),
748-
pl.when(pl.col("labels").is_not_null()).then(pl.lit(":\n")).alias(":"),
792+
pl.when(pl.col("labels_first").is_not_null()).then(row_labels[0]),
793+
pl.when(pl.col("labels_first").is_not_null()).then(row_labels[1]),
794+
pl.when(pl.col("labels_first").is_not_null())
795+
.then(pl.lit(":\n"))
796+
.alias(":"),
749797
pl.when(pl.col("coeffs") >= 0).then(pl.lit("+")),
750798
pl.col("coeffs").cast(pl.String),
751799
pl.when(pl.col("vars").is_not_null()).then(col_labels[0]),
752800
pl.when(pl.col("vars").is_not_null()).then(col_labels[1]),
753-
"sign",
754-
pl.lit(" "),
755-
pl.col("rhs").cast(pl.String),
801+
pl.when(pl.col("is_last_in_group")).then(pl.col("sign")),
802+
pl.when(pl.col("is_last_in_group")).then(pl.lit(" ")),
803+
pl.when(pl.col("is_last_in_group")).then(pl.col("rhs").cast(pl.String)),
756804
]
757805

758806
kwargs: Any = dict(

0 commit comments

Comments
 (0)