Skip to content

Commit 1922c28

Browse files
committed
C3 algorithm
1 parent d31e463 commit 1922c28

File tree

2 files changed

+219
-18
lines changed

2 files changed

+219
-18
lines changed

src/singledispatch/core.rs

Lines changed: 39 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,16 @@ use pyo3::prelude::*;
88
use crate::singledispatch::builtins::Builtins;
99
use pyo3::types::{PyDict, PyTuple, PyType};
1010
use pyo3::{
11-
pyclass, pyfunction, pymethods, Bound, IntoPyObjectExt, Py, PyAny, PyObject, PyResult, Python,
11+
intern, pyclass, pyfunction, pymethods, Bound, IntoPyObjectExt, Py, PyAny, PyObject, PyResult,
12+
Python,
1213
};
1314
use std::collections::HashMap;
1415
use std::sync::Mutex;
1516

1617
fn get_abc_cache_token(py: Python) -> Bound<'_, PyAny> {
17-
py.import("abc")
18+
py.import(intern!(py, "abc"))
1819
.unwrap()
19-
.getattr("get_cache_token")
20+
.getattr(intern!(py, "get_cache_token"))
2021
.unwrap()
2122
.call0()
2223
.unwrap()
@@ -74,8 +75,12 @@ struct SingleDispatchState {
7475
impl SingleDispatchState {
7576
fn find_impl(&mut self, py: Python, cls: Bound<'_, PyAny>) -> PyResult<PyObject> {
7677
let cls_mro = get_obj_mro(&cls.clone()).unwrap();
77-
let mro = compose_mro(py, cls.clone(), self.registry.keys())?;
78+
let mro = match compose_mro(py, cls.clone(), self.registry.keys()) {
79+
Ok(mro) => mro,
80+
Err(e) => return Err(e),
81+
};
7882
let mut mro_match: Option<PyTypeReference> = None;
83+
eprintln!("Finding impl for {cls}");
7984
for typ in mro.iter() {
8085
if self.registry.contains_key(typ) {
8186
mro_match = Some(typ.clone_ref(py));
@@ -95,31 +100,50 @@ impl SingleDispatchState {
95100
)));
96101
}
97102
mro_match = Some(m.clone_ref(py));
103+
eprintln!("MRO match: {m}");
98104
break;
99105
}
100106
}
101-
match mro_match {
107+
let impl_fn = match mro_match {
102108
Some(_) => match self.registry.get(&mro_match.unwrap()) {
103-
Some(&ref it) => Ok(it.clone_ref(py)),
104-
None => Ok(py.None()),
109+
Some(&ref it) => Some(it.clone_ref(py)),
110+
None => None,
105111
},
106-
None => Ok(py.None()),
112+
None => None,
113+
};
114+
match impl_fn {
115+
Some(f) => Ok(f),
116+
None => {
117+
let obj_type = PyTypeReference::new(Builtins::cached(py).object_type.clone_ref(py));
118+
eprintln!("Found impl for {cls}: {obj_type}");
119+
match self.registry.get(&obj_type) {
120+
Some(it) => Ok(it.clone_ref(py)),
121+
None => Err(PyRuntimeError::new_err(format!(
122+
"No dispatch function found for {cls}!"
123+
))),
124+
}
125+
}
107126
}
108127
}
109128

110129
fn get_or_find_impl(&mut self, py: Python, cls: Bound<'_, PyAny>) -> PyResult<PyObject> {
111130
let free_cls = cls.unbind();
112131
let type_reference = PyTypeReference::new(free_cls.clone_ref(py));
132+
eprintln!("Finding impl {type_reference}");
113133

114134
match self.cache.get(&type_reference) {
115135
Some(handler) => Ok(handler.clone_ref(py)),
116136
None => {
117137
let handler_for_cls = match self.registry.get(&type_reference) {
118138
Some(handler) => handler.clone_ref(py),
119-
None => self.find_impl(py, free_cls.bind(py).clone())?,
139+
None => match self.find_impl(py, free_cls.bind(py).clone()) {
140+
Ok(handler) => handler,
141+
Err(e) => return Err(e),
142+
},
120143
};
121144
self.cache
122145
.insert(type_reference, handler_for_cls.clone_ref(py));
146+
eprintln!("Found new handler {handler_for_cls}");
123147
Ok(handler_for_cls)
124148
}
125149
}
@@ -161,7 +185,7 @@ impl SingleDispatch {
161185
);
162186
}
163187
if state.cache_token.is_none() {
164-
if let Ok(_) = unbound_func.getattr(py, "__abstractmethods__") {
188+
if let Ok(_) = unbound_func.getattr(py, intern!(py, "__abstractmethods__")) {
165189
state.cache_token = Some(get_abc_cache_token(py).unbind());
166190
}
167191
}
@@ -178,7 +202,7 @@ impl SingleDispatch {
178202
cls: Bound<'_, PyAny>,
179203
func: Bound<'_, PyAny>,
180204
) -> PyResult<PyObject> {
181-
match func.getattr("__annotations__") {
205+
match func.getattr(intern!(_py, "__annotations__")) {
182206
Ok(_annotations) => Err(PyNotImplementedError::new_err("Oops!")),
183207
Err(_) => Err(PyTypeError::new_err(
184208
format!("Invalid first argument to `register()`: {cls}. Use either `@register(some_class)` or plain `@register` on an annotated function."),
@@ -213,22 +237,24 @@ impl SingleDispatch {
213237
args: &Bound<'_, PyTuple>,
214238
kwargs: Option<&Bound<'_, PyDict>>,
215239
) -> PyResult<Py<PyAny>> {
216-
match obj.getattr("__class__") {
240+
eprintln!("Calling");
241+
match obj.getattr(intern!(py, "__class__")) {
217242
Ok(cls) => {
218243
let mut all_args = Vec::with_capacity(1 + args.len());
219244
all_args.insert(0, obj);
220245
all_args.extend(args);
221246

222247
match self.dispatch(py, cls) {
223248
Ok(handler) => handler.call(py, PyTuple::new(py, all_args)?, kwargs),
224-
Err(_) => panic!("no handler for singledispatch"),
249+
Err(e) => Err(e),
225250
}
226251
}
227252
Err(_) => Err(PyTypeError::new_err("expected __class__ attribute for obj")),
228253
}
229254
}
230255

231256
fn dispatch(&self, py: Python<'_>, cls: Bound<'_, PyAny>) -> PyResult<PyObject> {
257+
eprintln!("Dispatching");
232258
match self.lock.lock() {
233259
Ok(mut state) => {
234260
match &state.cache_token {

src/singledispatch/mro.rs

Lines changed: 180 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::singledispatch::builtins::Builtins;
22
use crate::singledispatch::typeref::PyTypeReference;
33
use crate::singledispatch::typing::TypingModule;
4+
use pyo3::exceptions::PyRuntimeError;
45
use pyo3::prelude::*;
56
use pyo3::types::PyTuple;
67
use pyo3::{intern, Bound, PyObject, PyResult, Python};
@@ -18,18 +19,192 @@ pub(crate) fn get_obj_mro(cls: &Bound<'_, PyAny>) -> PyResult<HashSet<PyTypeRefe
1819
Ok(mro)
1920
}
2021

22+
fn get_obj_bases(cls: &Bound<'_, PyAny>) -> PyResult<Vec<PyTypeReference>> {
23+
match cls.getattr_opt(intern!(cls.py(), "__bases__")) {
24+
Ok(opt) => match opt {
25+
Some(b) => Ok(b
26+
.downcast::<PyTuple>()?
27+
.iter()
28+
.map(|item| PyTypeReference::new(item.unbind()))
29+
.collect()),
30+
None => Ok(Vec::new()),
31+
},
32+
Err(e) => Err(e),
33+
}
34+
}
35+
2136
fn get_obj_subclasses(cls: &Bound<'_, PyAny>) -> PyResult<HashSet<PyTypeReference>> {
22-
let mro: HashSet<_> = cls
37+
let subclasses: HashSet<_> = cls
2338
.call_method0(intern!(cls.py(), "__subclasses__"))?
2439
.downcast::<PyTuple>()?
2540
.iter()
2641
.map(|item| PyTypeReference::new(item.unbind()))
2742
.collect();
28-
Ok(mro)
43+
Ok(subclasses)
44+
}
45+
46+
fn find_merge_candidate(py: Python, seqs: &[&mut Vec<PyTypeReference>]) -> Option<PyTypeReference> {
47+
let mut candidate: Option<&PyTypeReference> = None;
48+
for i1 in 0..seqs.len() {
49+
let s1 = &seqs[i1];
50+
candidate = Some(&s1[0]);
51+
for i2 in 0..seqs.len() {
52+
let s2 = &seqs[i2];
53+
if s2[1..].contains(candidate.unwrap()) {
54+
candidate = None;
55+
break;
56+
}
57+
}
58+
if candidate.is_some() {
59+
break;
60+
}
61+
}
62+
match candidate {
63+
Some(c) => Some(c.clone_ref(py)),
64+
None => None,
65+
}
66+
}
67+
68+
struct C3Mro<'a> {
69+
seqs: &'a mut Vec<&'a mut Vec<PyTypeReference>>,
2970
}
3071

31-
fn c3_mro(py: Python, cls: Bound<'_, PyAny>, abcs: Vec<PyTypeReference>) -> PyResult<Vec<PyTypeReference>> {
32-
Ok(abcs)
72+
impl C3Mro<'_> {
73+
fn for_abcs<'a>(
74+
py: Python,
75+
abcs: &'a mut Vec<&'a mut Vec<PyTypeReference>>,
76+
) -> PyResult<Vec<PyTypeReference>> {
77+
C3Mro { seqs: abcs }.merge(py)
78+
}
79+
80+
fn merge(&mut self, py: Python) -> PyResult<Vec<PyTypeReference>> {
81+
let mut result: Vec<PyTypeReference> = Vec::new();
82+
loop {
83+
let seqs = &mut self.seqs;
84+
seqs.retain(|seq| !seq.is_empty());
85+
if seqs.is_empty() {
86+
return Ok(result);
87+
}
88+
match find_merge_candidate(py, seqs.as_slice()) {
89+
Some(c) => {
90+
for i in 0..seqs.len() {
91+
let seq = &mut self.seqs[i];
92+
if seq[0].eq(&c) {
93+
seq.remove(0);
94+
}
95+
}
96+
result.push(c);
97+
}
98+
None => return Err(PyRuntimeError::new_err("Inconsistent hierarchy")),
99+
}
100+
}
101+
}
102+
}
103+
104+
fn c3_boundary(py: Python, bases: &[PyTypeReference]) -> usize {
105+
let mut boundary = 0;
106+
107+
for (i, base) in bases.iter().rev().enumerate() {
108+
if base
109+
.wrapped()
110+
.bind(py)
111+
.hasattr(intern!(py, "__abstractmethods__"))
112+
.unwrap()
113+
{
114+
boundary = bases.len() - i;
115+
break;
116+
}
117+
}
118+
119+
boundary
120+
}
121+
122+
fn c3_mro(
123+
py: Python,
124+
cls: &Bound<'_, PyAny>,
125+
abcs: Vec<PyTypeReference>,
126+
) -> PyResult<Vec<PyTypeReference>> {
127+
let bases = match get_obj_bases(cls) {
128+
Ok(b) => {
129+
if b.len() > 0 {
130+
b
131+
} else {
132+
return Ok(Vec::new());
133+
}
134+
}
135+
Err(e) => return Err(e),
136+
};
137+
let boundary = c3_boundary(py, &bases);
138+
eprintln!("boundary = {boundary}");
139+
let base = &bases[boundary];
140+
141+
let (explicit_bases, other_bases) = bases.split_at(boundary);
142+
let abstract_bases: Vec<_> = abcs
143+
.iter()
144+
.flat_map(|abc| {
145+
if Builtins::cached(py)
146+
.issubclass(py, cls, base.wrapped().bind(py))
147+
.unwrap()
148+
&& !bases.iter().any(|b| {
149+
Builtins::cached(py)
150+
.issubclass(py, b.wrapped().bind(py), base.wrapped().bind(py))
151+
.unwrap()
152+
})
153+
{
154+
vec![abc]
155+
} else {
156+
vec![]
157+
}
158+
})
159+
.collect();
160+
161+
let new_abcs: Vec<_> = abcs.iter().filter(|c| abstract_bases.contains(c)).collect();
162+
163+
let mut mros: Vec<&mut Vec<PyTypeReference>> = Vec::new();
164+
165+
let mut cls_ref = vec![PyTypeReference::new(cls.clone().unbind())];
166+
mros.push(&mut cls_ref);
167+
168+
let mut explicit_bases_mro = Vec::from_iter(explicit_bases.iter().map(|b| {
169+
c3_mro(
170+
py,
171+
b.wrapped().bind(py),
172+
new_abcs.iter().map(|abc| abc.clone_ref(py)).collect(),
173+
)
174+
.unwrap()
175+
}));
176+
mros.extend(&mut explicit_bases_mro);
177+
178+
let mut abstract_bases_mro = Vec::from_iter(abstract_bases.iter().map(|b| {
179+
c3_mro(
180+
py,
181+
b.wrapped().bind(py),
182+
new_abcs.iter().map(|abc| abc.clone_ref(py)).collect(),
183+
)
184+
.unwrap()
185+
}));
186+
mros.extend(&mut abstract_bases_mro);
187+
188+
let mut other_bases_mro = Vec::from_iter(other_bases.iter().map(|b| {
189+
c3_mro(
190+
py,
191+
b.wrapped().bind(py),
192+
new_abcs.iter().map(|abc| abc.clone_ref(py)).collect(),
193+
)
194+
.unwrap()
195+
}));
196+
mros.extend(&mut other_bases_mro);
197+
198+
let mut explicit_bases_cloned = Vec::from_iter(explicit_bases.iter().map(|b| b.clone_ref(py)));
199+
mros.push(&mut explicit_bases_cloned);
200+
201+
let mut abstract_bases_cloned = Vec::from_iter(abstract_bases.iter().map(|b| b.clone_ref(py)));
202+
mros.push(&mut abstract_bases_cloned);
203+
204+
let mut other_bases_cloned = Vec::from_iter(other_bases.iter().map(|b| b.clone_ref(py)));
205+
mros.push(&mut other_bases_cloned);
206+
207+
C3Mro::for_abcs(py, &mut mros)
33208
}
34209

35210
pub(crate) fn compose_mro(
@@ -103,5 +278,5 @@ pub(crate) fn compose_mro(
103278
}
104279
});
105280

106-
c3_mro(py, cls, mro)
281+
c3_mro(py, &cls, mro)
107282
}

0 commit comments

Comments
 (0)