Skip to content

Commit e53addf

Browse files
committed
C3 algorithm
1 parent 9820fc6 commit e53addf

File tree

2 files changed

+185
-3
lines changed

2 files changed

+185
-3
lines changed

src/singledispatch/core.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ impl SingleDispatchState {
7474
let cls_mro = get_obj_mro(&cls.clone())?;
7575
let mro = compose_mro(py, cls.clone(), self.registry.keys())?;
7676
let mut mro_match: Option<PyTypeReference> = None;
77+
eprintln!("Finding impl for {cls}");
7778
for typ in mro.iter() {
7879
if self.registry.contains_key(typ) {
7980
mro_match = Some(typ.clone_ref(py));
@@ -94,6 +95,7 @@ impl SingleDispatchState {
9495
)));
9596
}
9697
mro_match = Some(m.clone_ref(py));
98+
eprintln!("MRO match: {m}");
9799
break;
98100
}
99101
_ => {}
@@ -110,6 +112,7 @@ impl SingleDispatchState {
110112
Some(f) => Ok(f),
111113
None => {
112114
let obj_type = PyTypeReference::new(Builtins::cached(py).object_type.clone_ref(py));
115+
eprintln!("Found impl for {cls}: {obj_type}");
113116
match self.registry.get(&obj_type) {
114117
Some(it) => Ok(it.clone_ref(py)),
115118
None => Err(PyRuntimeError::new_err(format!(
@@ -123,6 +126,7 @@ impl SingleDispatchState {
123126
fn get_or_find_impl(&mut self, py: Python, cls: Bound<'_, PyAny>) -> PyResult<PyObject> {
124127
let free_cls = cls.unbind();
125128
let type_reference = PyTypeReference::new(free_cls.clone_ref(py));
129+
eprintln!("Finding impl {type_reference}");
126130

127131
match self.cache.get(&type_reference) {
128132
Some(handler) => Ok(handler.clone_ref(py)),
@@ -133,6 +137,7 @@ impl SingleDispatchState {
133137
};
134138
self.cache
135139
.insert(type_reference, handler_for_cls.clone_ref(py));
140+
eprintln!("Found new handler {handler_for_cls}");
136141
Ok(handler_for_cls)
137142
}
138143
}
@@ -228,6 +233,7 @@ impl SingleDispatch {
228233
args: &Bound<'_, PyTuple>,
229234
kwargs: Option<&Bound<'_, PyDict>>,
230235
) -> PyResult<Py<PyAny>> {
236+
eprintln!("Calling");
231237
match obj.getattr(intern!(py, "__class__")) {
232238
Ok(cls) => {
233239
let mut all_args = Vec::with_capacity(1 + args.len());
@@ -244,6 +250,7 @@ impl SingleDispatch {
244250
}
245251

246252
fn dispatch(&self, py: Python<'_>, cls: Bound<'_, PyAny>) -> PyResult<PyObject> {
253+
eprintln!("Dispatching");
247254
match self.lock.lock() {
248255
Ok(mut state) => {
249256
match &state.cache_token {

src/singledispatch/mro.rs

Lines changed: 178 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
use crate::singledispatch::builtins::Builtins;
22
use crate::singledispatch::typeref::PyTypeReference;
33
use crate::singledispatch::typing::TypingModule;
4+
use pyo3::exceptions::PyRuntimeError;
45
use pyo3::prelude::*;
56
use pyo3::types::PyTuple;
67
use pyo3::{intern, Bound, PyObject, PyResult, Python};
@@ -18,6 +19,20 @@ pub(crate) fn get_obj_mro(cls: &Bound<'_, PyAny>) -> PyResult<HashSet<PyTypeRefe
1819
Ok(mro)
1920
}
2021

22+
fn get_obj_bases(cls: &Bound<'_, PyAny>) -> PyResult<Vec<PyTypeReference>> {
23+
match cls.getattr_opt(intern!(cls.py(), "__bases__")) {
24+
Ok(opt) => match opt {
25+
Some(b) => Ok(b
26+
.downcast::<PyTuple>()?
27+
.iter()
28+
.map(|item| PyTypeReference::new(item.unbind()))
29+
.collect()),
30+
None => Ok(Vec::new()),
31+
},
32+
Err(e) => Err(e),
33+
}
34+
}
35+
2136
fn get_obj_subclasses(cls: &Bound<'_, PyAny>) -> PyResult<HashSet<PyTypeReference>> {
2237
let subclasses: HashSet<_> = cls
2338
.call_method0(intern!(cls.py(), "__subclasses__"))?
@@ -28,8 +43,168 @@ fn get_obj_subclasses(cls: &Bound<'_, PyAny>) -> PyResult<HashSet<PyTypeReferenc
2843
Ok(subclasses)
2944
}
3045

31-
fn c3_mro(py: Python, cls: Bound<'_, PyAny>, abcs: Vec<PyTypeReference>) -> PyResult<Vec<PyTypeReference>> {
32-
Ok(abcs)
46+
fn find_merge_candidate(py: Python, seqs: &[&mut Vec<PyTypeReference>]) -> Option<PyTypeReference> {
47+
let mut candidate: Option<&PyTypeReference> = None;
48+
for i1 in 0..seqs.len() {
49+
let s1 = &seqs[i1];
50+
candidate = Some(&s1[0]);
51+
for i2 in 0..seqs.len() {
52+
let s2 = &seqs[i2];
53+
if s2[1..].contains(candidate.unwrap()) {
54+
candidate = None;
55+
break;
56+
}
57+
}
58+
if candidate.is_some() {
59+
break;
60+
}
61+
}
62+
match candidate {
63+
Some(c) => Some(c.clone_ref(py)),
64+
None => None,
65+
}
66+
}
67+
68+
struct C3Mro<'a> {
69+
seqs: &'a mut Vec<&'a mut Vec<PyTypeReference>>,
70+
}
71+
72+
impl C3Mro<'_> {
73+
fn for_abcs<'a>(
74+
py: Python,
75+
abcs: &'a mut Vec<&'a mut Vec<PyTypeReference>>,
76+
) -> PyResult<Vec<PyTypeReference>> {
77+
C3Mro { seqs: abcs }.merge(py)
78+
}
79+
80+
fn merge(&mut self, py: Python) -> PyResult<Vec<PyTypeReference>> {
81+
let mut result: Vec<PyTypeReference> = Vec::new();
82+
loop {
83+
let seqs = &mut self.seqs;
84+
seqs.retain(|seq| !seq.is_empty());
85+
if seqs.is_empty() {
86+
return Ok(result);
87+
}
88+
match find_merge_candidate(py, seqs.as_slice()) {
89+
Some(c) => {
90+
for i in 0..seqs.len() {
91+
let seq = &mut self.seqs[i];
92+
if seq[0].eq(&c) {
93+
seq.remove(0);
94+
}
95+
}
96+
result.push(c);
97+
}
98+
None => return Err(PyRuntimeError::new_err("Inconsistent hierarchy")),
99+
}
100+
}
101+
}
102+
}
103+
104+
fn c3_boundary(py: Python, bases: &[PyTypeReference]) -> usize {
105+
let mut boundary = 0;
106+
107+
for (i, base) in bases.iter().rev().enumerate() {
108+
if base
109+
.wrapped()
110+
.bind(py)
111+
.hasattr(intern!(py, "__abstractmethods__"))
112+
.unwrap()
113+
{
114+
boundary = bases.len() - i;
115+
break;
116+
}
117+
}
118+
119+
boundary
120+
}
121+
122+
fn c3_mro(
123+
py: Python,
124+
cls: &Bound<'_, PyAny>,
125+
abcs: Vec<PyTypeReference>,
126+
) -> PyResult<Vec<PyTypeReference>> {
127+
let bases = match get_obj_bases(cls) {
128+
Ok(b) => {
129+
if b.len() > 0 {
130+
b
131+
} else {
132+
return Ok(Vec::new());
133+
}
134+
}
135+
Err(e) => return Err(e),
136+
};
137+
let boundary = c3_boundary(py, &bases);
138+
eprintln!("boundary = {boundary}");
139+
let base = &bases[boundary];
140+
141+
let (explicit_bases, other_bases) = bases.split_at(boundary);
142+
let abstract_bases: Vec<_> = abcs
143+
.iter()
144+
.flat_map(|abc| {
145+
if Builtins::cached(py)
146+
.issubclass(py, cls, base.wrapped().bind(py))
147+
.unwrap()
148+
&& !bases.iter().any(|b| {
149+
Builtins::cached(py)
150+
.issubclass(py, b.wrapped().bind(py), base.wrapped().bind(py))
151+
.unwrap()
152+
})
153+
{
154+
vec![abc]
155+
} else {
156+
vec![]
157+
}
158+
})
159+
.collect();
160+
161+
let new_abcs: Vec<_> = abcs.iter().filter(|c| abstract_bases.contains(c)).collect();
162+
163+
let mut mros: Vec<&mut Vec<PyTypeReference>> = Vec::new();
164+
165+
let mut cls_ref = vec![PyTypeReference::new(cls.clone().unbind())];
166+
mros.push(&mut cls_ref);
167+
168+
let mut explicit_bases_mro = Vec::from_iter(explicit_bases.iter().map(|b| {
169+
c3_mro(
170+
py,
171+
b.wrapped().bind(py),
172+
new_abcs.iter().map(|abc| abc.clone_ref(py)).collect(),
173+
)
174+
.unwrap()
175+
}));
176+
mros.extend(&mut explicit_bases_mro);
177+
178+
let mut abstract_bases_mro = Vec::from_iter(abstract_bases.iter().map(|b| {
179+
c3_mro(
180+
py,
181+
b.wrapped().bind(py),
182+
new_abcs.iter().map(|abc| abc.clone_ref(py)).collect(),
183+
)
184+
.unwrap()
185+
}));
186+
mros.extend(&mut abstract_bases_mro);
187+
188+
let mut other_bases_mro = Vec::from_iter(other_bases.iter().map(|b| {
189+
c3_mro(
190+
py,
191+
b.wrapped().bind(py),
192+
new_abcs.iter().map(|abc| abc.clone_ref(py)).collect(),
193+
)
194+
.unwrap()
195+
}));
196+
mros.extend(&mut other_bases_mro);
197+
198+
let mut explicit_bases_cloned = Vec::from_iter(explicit_bases.iter().map(|b| b.clone_ref(py)));
199+
mros.push(&mut explicit_bases_cloned);
200+
201+
let mut abstract_bases_cloned = Vec::from_iter(abstract_bases.iter().map(|b| b.clone_ref(py)));
202+
mros.push(&mut abstract_bases_cloned);
203+
204+
let mut other_bases_cloned = Vec::from_iter(other_bases.iter().map(|b| b.clone_ref(py)));
205+
mros.push(&mut other_bases_cloned);
206+
207+
C3Mro::for_abcs(py, &mut mros)
33208
}
34209

35210
pub(crate) fn compose_mro(
@@ -103,5 +278,5 @@ pub(crate) fn compose_mro(
103278
}
104279
});
105280

106-
c3_mro(py, cls, mro)
281+
c3_mro(py, &cls, mro)
107282
}

0 commit comments

Comments
 (0)