Skip to content

Commit 0d966c7

Browse files
authored
tests: test different optimization directions for pareto (#21)
1 parent b869d1d commit 0d966c7

File tree

1 file changed

+34
-0
lines changed

1 file changed

+34
-0
lines changed

tests/test_polars.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,3 +160,37 @@ def test_is_pareto_with_nulls():
160160
prs.is_pareto("x", "y").alias("is_pareto")
161161
)
162162
assert df["is_pareto"].to_list() == [True, True, None, True]
163+
164+
165+
def test_is_pareto_min_min():
166+
df = pl.DataFrame(
167+
{
168+
"bias": [1.0, 0.5, 0.5],
169+
"bad_rate": [0.01, 0.02, 0.03],
170+
}
171+
).with_columns(
172+
prs.is_pareto(pl.col("bias").mul(-1), pl.col("bad_rate").mul(-1)).alias(
173+
"is_pareto"
174+
)
175+
)
176+
177+
assert df["is_pareto"].to_list() == [True, True, False]
178+
179+
180+
def test_is_pareto_min_max():
181+
df = pl.DataFrame({"bias": [0.6, 0.5, 0.5], "auc": [0.7, 0.7, 0.6]}).with_columns(
182+
prs.is_pareto(pl.col("bias").mul(-1), "auc").alias("is_pareto")
183+
)
184+
185+
assert df["is_pareto"].to_list() == [False, True, False]
186+
187+
188+
def test_is_pareto_max_min():
189+
df = pl.DataFrame(
190+
{
191+
"air": [0.5, 0.5, 0.6],
192+
"bad_rate": [0.1, 0.05, 0.3],
193+
}
194+
).with_columns(prs.is_pareto(pl.col("bad_rate").mul(-1), "air").alias("is_pareto"))
195+
196+
assert df["is_pareto"].to_list() == [False, True, True]

0 commit comments

Comments
 (0)