Skip to content

Commit 1c46210

Browse files
committed
More stuff
1 parent 7bbd342 commit 1c46210

File tree

4 files changed

+59
-55
lines changed

4 files changed

+59
-55
lines changed

src/singledispatch/core.rs

Lines changed: 30 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -80,32 +80,26 @@ impl SingleDispatchState {
8080
mro_match = Some(typ.clone_ref(py));
8181
}
8282

83-
match mro_match {
84-
Some(m) => {
85-
let m = &m.clone_ref(py);
86-
if self.registry.contains_key(typ)
87-
&& !cls_mro.contains(typ)
88-
&& !cls_mro.contains(m)
89-
&& Builtins::cached(py)
90-
.issubclass(py, m.wrapped().bind(py), typ.wrapped().bind(py))
91-
.is_ok_and(|res| res)
92-
{
93-
return Err(PyRuntimeError::new_err(format!(
94-
"Ambiguous dispatch: {m} or {typ}"
95-
)));
96-
}
97-
mro_match = Some(m.clone_ref(py));
98-
eprintln!("MRO match: {m}");
99-
break;
83+
if let Some(m) = mro_match {
84+
let m = &m.clone_ref(py);
85+
if self.registry.contains_key(typ)
86+
&& !cls_mro.contains(typ)
87+
&& !cls_mro.contains(m)
88+
&& Builtins::cached(py)
89+
.issubclass(py, m.wrapped().bind(py), typ.wrapped().bind(py))
90+
.is_ok_and(|res| res)
91+
{
92+
return Err(PyRuntimeError::new_err(format!(
93+
"Ambiguous dispatch: {m} or {typ}"
94+
)));
10095
}
101-
_ => {}
96+
mro_match = Some(m.clone_ref(py));
97+
eprintln!("MRO match: {m}");
98+
break;
10299
}
103100
}
104101
let impl_fn = match mro_match {
105-
Some(v) => match self.registry.get(&v) {
106-
Some(&ref it) => Some(it.clone_ref(py)),
107-
None => None,
108-
},
102+
Some(v) => self.registry.get(&v).map(|it| it.clone_ref(py)),
109103
None => None,
110104
};
111105
match impl_fn {
@@ -178,10 +172,12 @@ impl SingleDispatch {
178172
unbound_func.clone_ref(py),
179173
);
180174
}
181-
if state.cache_token.is_none() {
182-
if let Ok(_) = unbound_func.getattr(py, intern!(py, "__abstractmethods__")) {
183-
state.cache_token = Some(get_abc_cache_token(py)?.unbind());
184-
}
175+
if state.cache_token.is_none()
176+
&& unbound_func
177+
.getattr(py, intern!(py, "__abstractmethods__"))
178+
.is_ok()
179+
{
180+
state.cache_token = Some(get_abc_cache_token(py)?.unbind());
185181
}
186182
state.cache.clear();
187183
Ok(unbound_func)
@@ -253,18 +249,15 @@ impl SingleDispatch {
253249
eprintln!("Dispatching");
254250
match self.lock.lock() {
255251
Ok(mut state) => {
256-
match &state.cache_token {
257-
Some(cache_token) => {
258-
let current_token = get_abc_cache_token(py)?;
259-
match current_token.rich_compare(cache_token.bind(py), CompareOp::Eq) {
260-
Ok(_) => {
261-
state.cache.clear();
262-
state.cache_token = Some(current_token.unbind());
263-
}
264-
_ => (),
265-
}
252+
if let Some(cache_token) = &state.cache_token {
253+
let current_token = get_abc_cache_token(py)?;
254+
if current_token
255+
.rich_compare(cache_token.bind(py), CompareOp::Eq)
256+
.is_ok()
257+
{
258+
state.cache.clear();
259+
state.cache_token = Some(current_token.unbind());
266260
}
267-
_ => (),
268261
}
269262

270263
state.get_or_find_impl(py, cls)

src/singledispatch/mro.rs

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,7 @@ fn find_merge_candidate(py: Python, seqs: &[&mut Vec<PyTypeReference>]) -> Optio
6060
break;
6161
}
6262
}
63-
match candidate {
64-
Some(c) => Some(c.clone_ref(py)),
65-
None => None,
66-
}
63+
candidate.map(|c| c.clone_ref(py))
6764
}
6865

6966
struct C3Mro<'a> {
@@ -146,28 +143,28 @@ fn c3_mro(
146143
) -> PyResult<Vec<PyTypeReference>> {
147144
let bases = match get_obj_bases(cls) {
148145
Ok(b) => {
149-
if b.len() > 0 {
146+
if !b.is_empty() {
150147
b
151148
} else {
152149
return Ok(Vec::new());
153150
}
154151
}
155152
Err(e) => return Err(e),
156153
};
154+
eprintln!("bases = {bases:#?}");
157155
let boundary = c3_boundary(py, &bases)?;
158156
eprintln!("boundary = {boundary}");
159-
let base = &bases[boundary];
160157

161158
let (explicit_bases, other_bases) = bases.split_at(boundary);
162159
let abstract_bases: Vec<_> = abcs
163160
.iter()
164161
.flat_map(|abc| {
165162
if Builtins::cached(py)
166-
.issubclass(py, cls, base.wrapped().bind(py))
163+
.issubclass(py, cls, abc.wrapped().bind(py))
167164
.unwrap()
168165
&& !bases.iter().any(|b| {
169166
Builtins::cached(py)
170-
.issubclass(py, b.wrapped().bind(py), base.wrapped().bind(py))
167+
.issubclass(py, b.wrapped().bind(py), abc.wrapped().bind(py))
171168
.unwrap()
172169
})
173170
{
@@ -177,6 +174,9 @@ fn c3_mro(
177174
}
178175
})
179176
.collect();
177+
eprintln!("explict_bases = {explicit_bases:#?}");
178+
eprintln!("other_bases = {other_bases:#?}");
179+
eprintln!("abstract_bases = {abstract_bases:#?}");
180180

181181
let new_abcs: Vec<_> = abcs.iter().filter(|c| abstract_bases.contains(c)).collect();
182182

@@ -189,6 +189,7 @@ fn c3_mro(
189189
mros.extend(&mut explicit_bases_mro);
190190

191191
let mut abstract_bases_mro = sub_c3_mro(py, abstract_bases.iter().map(|v| *v), &new_abcs)?;
192+
eprintln!("abstract_bases_mro = {abstract_bases_mro:#?}");
192193
mros.extend(&mut abstract_bases_mro);
193194

194195
let mut other_bases_mro = sub_c3_mro(py, other_bases.iter(), &new_abcs)?;
@@ -215,7 +216,9 @@ pub(crate) fn compose_mro(
215216
let typing = TypingModule::cached(py);
216217

217218
let bases: HashSet<_> = get_obj_mro(&cls)?;
219+
eprintln!("bases = {bases:#?}");
218220
let registered_types: HashSet<_> = types.collect();
221+
eprintln!("registered_types = {registered_types:#?}");
219222
let eligible_types: HashSet<_> = registered_types
220223
.iter()
221224
.filter(|&tref| {
@@ -236,8 +239,9 @@ pub(crate) fn compose_mro(
236239
*tref != other && other_mro.contains(tref)
237240
})
238241
})
239-
.map(|tref| *tref)
242+
.copied()
240243
.collect();
244+
eprintln!("eligible_types = {eligible_types:#?}");
241245
let mut mro: Vec<PyTypeReference> = Vec::new();
242246
eligible_types.iter().for_each(|&tref| {
243247
// Subclasses of the ABCs in *types* which are also implemented by
@@ -247,6 +251,7 @@ pub(crate) fn compose_mro(
247251
.unwrap()
248252
.iter()
249253
.filter(|subclass| {
254+
eprintln!("subclass = {subclass:#?}");
250255
let typ = subclass.wrapped();
251256
let tref = PyTypeReference::new(typ.clone_ref(py));
252257
!bases.contains(&tref)
@@ -270,12 +275,15 @@ pub(crate) fn compose_mro(
270275
} else {
271276
found_subclasses.sort_by_key(|s| Reverse(s.len()));
272277
found_subclasses.iter().flatten().for_each(|tref| {
273-
if !mro.contains(&tref) {
278+
if !mro.contains(tref) {
274279
mro.push(tref.clone_ref(py));
275280
}
276281
});
277282
}
278283
});
284+
eprintln!("Pre-mro candidates {mro:#?}");
279285

280-
c3_mro(py, &cls, mro)
286+
let final_rmo = c3_mro(py, &cls, mro);
287+
eprintln!("MRO for {cls}: {final_rmo:#?}");
288+
final_rmo
281289
}

src/singledispatch/typeref.rs

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
use pyo3::{PyObject, Python};
2-
use std::fmt::{Display, Formatter};
2+
use std::fmt::{Debug, Display, Formatter};
33
use std::hash::{Hash, Hasher};
44

55
pub struct PyTypeReference {
@@ -22,6 +22,12 @@ impl PyTypeReference {
2222
}
2323
}
2424

25+
impl Debug for PyTypeReference {
26+
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
27+
std::fmt::Display::fmt(&self.wrapped, f)
28+
}
29+
}
30+
2531
impl Display for PyTypeReference {
2632
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
2733
std::fmt::Display::fmt(&self.wrapped, f)
@@ -38,10 +44,6 @@ impl PartialEq for PyTypeReference {
3844
fn eq(&self, other: &Self) -> bool {
3945
self.wrapped.is(&other.wrapped)
4046
}
41-
42-
fn ne(&self, other: &Self) -> bool {
43-
!self.wrapped.is(&other.wrapped)
44-
}
4547
}
4648

4749
impl Eq for PyTypeReference {}

tests/test_singledispatch_native.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from collections.abc import Sequence
22

33
import pytest
4+
#from functools import singledispatch
45
from singledispatch_native import singledispatch
56

67
from typing import Any
@@ -16,7 +17,7 @@ def _some_fun_str(o: str) -> str:
1617

1718

1819
@some_fun.register(int)
19-
def _some_fun_str(o: int) -> str:
20+
def _some_fun_int(o: int) -> str:
2021
return "It's an int!"
2122

2223

@@ -26,7 +27,7 @@ def _some_fun_sequence(l: Sequence) -> str:
2627

2728

2829
@some_fun.register(tuple)
29-
def _some_fun_sequence(l: tuple) -> str:
30+
def _some_fun_tuple(l: tuple) -> str:
3031
return "tuple: " + ", ".join(l)
3132

3233

0 commit comments

Comments
 (0)