Skip to content

Commit 47ab30f

Browse files
author
Robbie Muir
committed
add to_polars improvements
1 parent 1eddef9 commit 47ab30f

File tree

3 files changed

+69
-4
lines changed

3 files changed

+69
-4
lines changed

linopy/expressions.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -441,6 +441,11 @@ def __repr__(self) -> str:
441441

442442
return "\n".join(lines)
443443

444+
@property
445+
def is_constant(self) -> bool:
446+
"""This is true if the expression contains no variables"""
447+
return self.data["coeffs"].values.size == 0
448+
444449
def print(self, display_max_rows: int = 20, display_max_terms: int = 20) -> None:
445450
"""
446451
Print the linear expression.
@@ -1439,12 +1444,18 @@ def to_polars(self) -> pl.DataFrame:
14391444
14401445
The resulting DataFrame represents a long table format of the all
14411446
non-masked expressions with non-zero coefficients. It contains the
1442-
columns `coeffs`, `vars`.
1447+
columns `coeffs`, `vars`, `const`. The coeffs and vars columns will be null if the expression is constant.
14431448
14441449
Returns
14451450
-------
14461451
df : polars.DataFrame
14471452
"""
1453+
if self.is_constant:
1454+
df = pl.DataFrame(
1455+
{"const": self.data["const"].values.reshape(-1)}
1456+
).with_columns(pl.lit(None).alias("coeffs"), pl.lit(None).alias("vars"))
1457+
return df.select(["vars", "coeffs", "const"])
1458+
14481459
df = to_polars(self.data)
14491460
df = filter_nulls_polars(df)
14501461
df = group_terms_polars(df)
@@ -1855,12 +1866,22 @@ def to_polars(self, **kwargs: Any) -> pl.DataFrame:
18551866
18561867
The resulting DataFrame represents a long table format of the all
18571868
non-masked expressions with non-zero coefficients. It contains the
1858-
columns `coeffs`, `vars`.
1869+
columns `vars1`, `vars2`, `coeffs`, `const`. If the expression is constant, the `vars1` and `vars2` and `coeffs` columns will be null.
18591870
18601871
Returns
18611872
-------
18621873
df : polars.DataFrame
18631874
"""
1875+
if self.is_constant:
1876+
df = pl.DataFrame(
1877+
{"const": self.data["const"].values.reshape(-1)}
1878+
).with_columns(
1879+
pl.lit(None).alias("coeffs"),
1880+
pl.lit(None).alias("vars1"),
1881+
pl.lit(None).alias("vars2"),
1882+
)
1883+
return df.select(["vars1", "vars2", "coeffs", "const"])
1884+
18641885
vars = self.data.vars.assign_coords(
18651886
{FACTOR_DIM: ["vars1", "vars2"]}
18661887
).to_dataset(FACTOR_DIM)

test/test_linear_expression.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1125,17 +1125,45 @@ def test_linear_expression_from_tuples_bad_calls(
11251125

11261126
def test_linear_expression_from_constant_scalar(m: Model) -> None:
11271127
expr = LinearExpression.from_constant(model=m, constant=10)
1128+
assert expr.is_constant
11281129
assert isinstance(expr, LinearExpression)
11291130
assert (expr.const == 10).all()
11301131

11311132

1132-
def test_linear_expression_from_constant_array(m: Model) -> None:
1133+
def test_linear_expression_from_constant_1D(m: Model) -> None:
11331134
arr = pd.Series(index=pd.Index([0, 1], name="t"), data=[10, 20])
11341135
expr = LinearExpression.from_constant(model=m, constant=arr)
11351136
assert isinstance(expr, LinearExpression)
11361137
assert list(expr.coords.keys())[0] == "t"
11371138
assert expr.nterm == 0
11381139
assert (expr.const.values == [10, 20]).all()
1140+
assert expr.is_constant
1141+
1142+
exp_polars = expr.to_polars()
1143+
assert exp_polars.columns == ["const", "coeffs", "vars"]
1144+
assert exp_polars["const"].to_list() == [10, 20]
1145+
assert exp_polars["coeffs"].to_list() == [None, None]
1146+
assert exp_polars["vars"].to_list() == [None, None]
1147+
1148+
1149+
def test_linear_expression_two_dimensional_from_constant_2D(m: Model) -> None:
1150+
index_a = pd.Index([0, 1], name="a")
1151+
index_b = pd.Index([0, 1, 2], name="b")
1152+
arr = np.array([[10, 20, 30], [40, 50, 60]])
1153+
const = xr.DataArray(data=arr, coords=[index_a, index_b])
1154+
1155+
le_variable = m.add_variables(name="var", coords=[index_a, index_b]) * 1 + const
1156+
assert not le_variable.is_constant
1157+
le_const = LinearExpression.from_constant(model=m, constant=const)
1158+
assert le_const.is_constant
1159+
1160+
var_pol = le_variable.to_polars()
1161+
const_pol = le_const.to_polars()
1162+
assert var_pol.shape == const_pol.shape
1163+
assert var_pol.columns == const_pol.columns
1164+
assert all(const_pol["const"] == var_pol["const"])
1165+
assert all(const_pol["coeffs"].is_null())
1166+
assert all(const_pol["vars"].is_null())
11391167

11401168

11411169
def test_linear_expression_sanitize(x: Variable, y: Variable, z: Variable) -> None:

test/test_quadratic_expression.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def test_quadratic_expression_flat(x: Variable, y: Variable) -> None:
287287
assert len(expr.flat) == 2
288288

289289

290-
def test_linear_expression_to_polars(x: Variable, y: Variable) -> None:
290+
def test_quadratic_expression_to_polars(x: Variable, y: Variable) -> None:
291291
expr = x * y + x + 5
292292
df = expr.to_polars()
293293
assert isinstance(df, pl.DataFrame)
@@ -296,6 +296,22 @@ def test_linear_expression_to_polars(x: Variable, y: Variable) -> None:
296296
assert len(df) == expr.nterm * 2
297297

298298

299+
def test_quadratic_expression_constant_to_polars() -> None:
300+
m = Model()
301+
arr = pd.Series(index=pd.Index([0, 1], name="t"), data=[10, 20])
302+
lin_expr = LinearExpression.from_constant(model=m, constant=arr)
303+
quad_expr = lin_expr.to_quadexpr()
304+
305+
assert quad_expr.is_constant
306+
df = quad_expr.to_polars()
307+
assert isinstance(df, pl.DataFrame)
308+
assert df.columns == ["vars1", "vars2", "coeffs", "const"]
309+
assert all(df["vars1"].is_null())
310+
assert all(df["vars2"].is_null())
311+
assert all(df["coeffs"].is_null())
312+
assert all(arr.to_numpy() == df["const"].to_numpy())
313+
314+
299315
def test_quadratic_expression_to_matrix(model: Model, x: Variable, y: Variable) -> None:
300316
expr: QuadraticExpression = x * y + x + 5 # type: ignore
301317

0 commit comments

Comments
 (0)