Skip to content

Commit 7e1df1e

Browse files
authored
feat: add is_pareto (#17)
1 parent b2ba0a7 commit 7e1df1e

File tree

4 files changed

+130
-2
lines changed

4 files changed

+130
-2
lines changed

python/rapidstats/_polars/_numeric.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,57 @@ def sum_horizontal(
169169
raise ValueError(
170170
f"Invalid `null_strategy` {null_method}, must be one of `kleene`, `ignore`, or `propagate`"
171171
)
172+
173+
174+
def is_pareto(*exprs: IntoExpr | Iterable[IntoExpr]) -> pl.Expr:
175+
"""Identifies whether each point lies on the Pareto frontier. A point is considered
176+
Pareto-optimal (non-dominated) if there is no other point that is at least as large
177+
in all dimensions and strictly larger in at least one dimension. All dimensions are
178+
assumed to be maximized.
179+
180+
!!! warning
181+
182+
Currently, only 2 dimensions are supported.
183+
184+
Returns
185+
-------
186+
pl.Expr
187+
A boolean expression indicating whether each row is pareto. Rows where there are
188+
any nulls or NaNs are null.
189+
190+
Examples
191+
--------
192+
``` py
193+
import polars as pl
194+
import rapidstats.polars as prs
195+
196+
df = pl.DataFrame({"x": [5, 1, 3, 2], "y": [1, 5, 3, 2]})
197+
df.select(prs.is_pareto("x", "y"))
198+
```
199+
``` title="output"
200+
shape: (4, 1)
201+
┌───────┐
202+
│ x │
203+
│ --- │
204+
│ bool │
205+
╞═══════╡
206+
│ true │
207+
│ true │
208+
│ true │
209+
│ false │
210+
└───────┘
211+
```
212+
213+
Added in version 0.4.0
214+
----------------------
215+
"""
216+
parsed_exprs = _parse_into_list_of_exprs(*exprs)
217+
218+
if len(parsed_exprs) != 2:
219+
raise NotImplementedError("Only 2 dimensions are currently supported")
220+
221+
return pl.plugins.register_plugin_function(
222+
plugin_path=_PLUGIN_PATH,
223+
function_name="pl_pareto_2d",
224+
args=[e.cast(pl.Float64) for e in parsed_exprs],
225+
)

python/rapidstats/polars.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# ruff: noqa: F401
22

33
from ._polars._format import format
4-
from ._polars._numeric import auc, is_close, sum_horizontal
4+
from ._polars._numeric import auc, is_close, is_pareto, sum_horizontal

src/general.rs

Lines changed: 61 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,15 @@
1+
use std::f64;
2+
13
use polars::prelude::*;
2-
use pyo3_polars::derive::polars_expr;
4+
use pyo3_polars::{
5+
derive::polars_expr,
6+
export::{
7+
polars_arrow::{
8+
array::BooleanArray, bitmap::MutableBitmap, datatypes::ArrowDataType::Boolean,
9+
},
10+
polars_core::utils::Container,
11+
},
12+
};
313

414
pub fn trapezoidal_auc(x: &[f64], y: &[f64]) -> f64 {
515
x.windows(2)
@@ -47,3 +57,53 @@ fn pl_auc(inputs: &[Series]) -> PolarsResult<Series> {
4757

4858
Ok(Series::from_vec("auc".into(), vec![res]))
4959
}
60+
61+
#[polars_expr(output_type=Boolean)]
62+
fn pl_pareto_2d(inputs: &[Series]) -> PolarsResult<Series> {
63+
let x = &inputs[0];
64+
let y = &inputs[1];
65+
66+
let df = df!("x" => x, "y" => y)?
67+
.with_row_index("index".into(), Some(0))?
68+
.sort(
69+
["x", "y"],
70+
SortMultipleOptions::default().with_order_descending_multi([true, true]),
71+
)?;
72+
73+
let index = df["index"].u32()?;
74+
let x_sorted = df["x"].f64()?;
75+
let y_sorted = df["y"].f64()?;
76+
77+
let mut res: Vec<bool> = vec![false; x.len()];
78+
let mut validity = MutableBitmap::with_capacity(x.len());
79+
validity.extend_constant(x.len(), true);
80+
81+
let mut best_y = -f64::INFINITY;
82+
for ((i, x), y) in index.into_no_null_iter().zip(x_sorted).zip(y_sorted) {
83+
let i_u = i as usize;
84+
85+
let (x, y) = match (x, y) {
86+
(Some(x), Some(y)) => (x, y),
87+
_ => {
88+
validity.set(i_u, false);
89+
continue;
90+
}
91+
};
92+
93+
if x.is_nan() || y.is_nan() {
94+
validity.set(i_u, false);
95+
continue;
96+
}
97+
98+
if y > best_y {
99+
res[i_u] = true;
100+
best_y = y;
101+
} else {
102+
res[i_u] = false;
103+
}
104+
}
105+
106+
let arr = BooleanArray::new(Boolean, res.into(), validity.into());
107+
108+
Ok(BooleanChunked::with_chunk(PlSmallStr::EMPTY, arr).into_series())
109+
}

tests/test_polars.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,3 +146,17 @@ def test_sum_horizontal():
146146
).to_series(),
147147
pl.Series("x", [3, None, 1]),
148148
)
149+
150+
151+
def test_is_pareto():
152+
df = pl.DataFrame({"x": [5, 1, 3, 2], "y": [1, 5, 3, 2]}).with_columns(
153+
prs.is_pareto("x", "y").alias("is_pareto")
154+
)
155+
assert df["is_pareto"].to_list() == [True, True, True, False]
156+
157+
158+
def test_is_pareto_with_nulls():
159+
df = pl.DataFrame({"x": [5, 1, 3, 2], "y": [1, 5, None, 2]}).with_columns(
160+
prs.is_pareto("x", "y").alias("is_pareto")
161+
)
162+
assert df["is_pareto"].to_list() == [True, True, None, True]

0 commit comments

Comments
 (0)