Skip to content

Commit 0a2c853

Browse files
committed
Update pyo3 and rust-numpy
1 parent ca1dba6 commit 0a2c853

File tree

3 files changed

+20
-27
lines changed

3 files changed

+20
-27
lines changed

Cargo.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,10 @@ crate-type = ["cdylib"]
1414

1515
[dependencies]
1616
approx = "^0.5.1"
17-
ndarray = "^0.16.1"
17+
ndarray = "^0.16.1" # numpy supports only >= 0.15, < 0.17
1818
num-traits = "^0.2.19"
19-
pyo3 = { version = "^0.22.3", features = ["extension-module", "abi3-py38"], optional = true }
20-
numpy = { version = "^0.22.0", optional = true }
19+
pyo3 = { version = "^0.23.2", features = ["extension-module", "abi3-py38"], optional = true }
20+
numpy = { version = "^0.23.0", optional = true }
2121

2222
[features]
2323
python = ["dep:pyo3", "dep:numpy"]

src/bradley_terry.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,10 @@ pub fn newman(
8585

8686
v = one_nan_to_num(v_new, tolerance);
8787

88-
let broadcast_scores_t = scores.clone().into_shape_with_order((1, scores.len())).unwrap();
88+
let broadcast_scores_t = scores
89+
.clone()
90+
.into_shape_with_order((1, scores.len()))
91+
.unwrap();
8992
let sqrt_scores_outer =
9093
(&broadcast_scores_t * &broadcast_scores_t.t()).mapv_into(f64::sqrt);
9194
let sum_scores = &broadcast_scores_t + &broadcast_scores_t.t();

src/python.rs

Lines changed: 13 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -39,13 +39,12 @@ unsafe impl Element for Winner {
3939
Clone::clone(self)
4040
}
4141

42-
fn get_dtype_bound(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
43-
numpy::dtype_bound::<u8>(py)
42+
fn get_dtype(py: Python<'_>) -> Bound<'_, PyArrayDescr> {
43+
numpy::dtype::<u8>(py)
4444
}
4545
}
4646

4747
create_exception!(evalica, LengthMismatchError, PyValueError);
48-
4948
#[pyfunction]
5049
fn matrices_pyo3<'py>(
5150
py: Python<'py>,
@@ -63,8 +62,8 @@ fn matrices_pyo3<'py>(
6362
total,
6463
) {
6564
Ok((wins, ties)) => Ok((
66-
wins.into_pyarray_bound(py).unbind(),
67-
ties.into_pyarray_bound(py).unbind(),
65+
wins.into_pyarray(py).unbind(),
66+
ties.into_pyarray(py).unbind(),
6867
)),
6968
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
7069
}
@@ -77,7 +76,7 @@ fn pairwise_scores_pyo3<'py>(
7776
) -> PyResult<Py<PyArray2<f64>>> {
7877
let pairwise = pairwise_scores(&scores.as_array());
7978

80-
Ok(pairwise.into_pyarray_bound(py).unbind())
79+
Ok(pairwise.into_pyarray(py).unbind())
8180
}
8281

8382
#[pyfunction]
@@ -100,7 +99,7 @@ fn counting_pyo3<'py>(
10099
win_weight,
101100
tie_weight,
102101
) {
103-
Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()),
102+
Ok(scores) => Ok(scores.into_pyarray(py).unbind()),
104103
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
105104
}
106105
}
@@ -125,7 +124,7 @@ fn average_win_rate_pyo3<'py>(
125124
win_weight,
126125
tie_weight,
127126
) {
128-
Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()),
127+
Ok(scores) => Ok(scores.into_pyarray(py).unbind()),
129128
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
130129
}
131130
}
@@ -160,9 +159,7 @@ fn bradley_terry_pyo3<'py>(
160159
);
161160

162161
match bradley_terry(&matrix.view(), tolerance, limit) {
163-
Ok((scores, iterations)) => {
164-
Ok((scores.into_pyarray_bound(py).unbind(), iterations))
165-
}
162+
Ok((scores, iterations)) => Ok((scores.into_pyarray(py).into(), iterations)),
166163
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
167164
}
168165
}
@@ -206,7 +203,7 @@ fn newman_pyo3<'py>(
206203
limit,
207204
) {
208205
Ok((scores, v, iterations)) => {
209-
Ok((scores.into_pyarray_bound(py).unbind(), v, iterations))
206+
Ok((scores.into_pyarray(py).unbind(), v, iterations))
210207
}
211208
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
212209
}
@@ -243,7 +240,7 @@ fn elo_pyo3<'py>(
243240
win_weight,
244241
tie_weight,
245242
) {
246-
Ok(scores) => Ok(scores.into_pyarray_bound(py).unbind()),
243+
Ok(scores) => Ok(scores.into_pyarray(py).unbind()),
247244
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
248245
}
249246
}
@@ -278,9 +275,7 @@ fn eigen_pyo3<'py>(
278275
);
279276

280277
match eigen(&matrix.view(), tolerance, limit) {
281-
Ok((scores, iterations)) => {
282-
Ok((scores.into_pyarray_bound(py).unbind(), iterations))
283-
}
278+
Ok((scores, iterations)) => Ok((scores.into_pyarray(py).unbind(), iterations)),
284279
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
285280
}
286281
}
@@ -319,9 +314,7 @@ fn pagerank_pyo3<'py>(
319314
);
320315

321316
match pagerank(&matrix.view(), damping, tolerance, limit) {
322-
Ok((scores, iterations)) => {
323-
Ok((scores.into_pyarray_bound(py).unbind(), iterations))
324-
}
317+
Ok((scores, iterations)) => Ok((scores.into_pyarray(py).unbind(), iterations)),
325318
Err(_) => Err(LengthMismatchError::new_err("mismatching input shapes")),
326319
}
327320
}
@@ -332,10 +325,7 @@ fn pagerank_pyo3<'py>(
332325
#[pymodule]
333326
fn evalica(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
334327
m.add("__version__", env!("CARGO_PKG_VERSION"))?;
335-
m.add(
336-
"LengthMismatchError",
337-
py.get_type_bound::<LengthMismatchError>(),
338-
)?;
328+
m.add("LengthMismatchError", py.get_type::<LengthMismatchError>())?;
339329
m.add_function(wrap_pyfunction!(matrices_pyo3, m)?)?;
340330
m.add_function(wrap_pyfunction!(pairwise_scores_pyo3, m)?)?;
341331
m.add_function(wrap_pyfunction!(counting_pyo3, m)?)?;

0 commit comments

Comments
 (0)