Skip to content

Full C3 MRO algorithm #1

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 8 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/singledispatch/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ impl SingleDispatchState {
let cls_mro = get_obj_mro(&cls.clone())?;
let mro = compose_mro(py, cls.clone(), self.registry.keys())?;
let mut mro_match: Option<PyTypeReference> = None;
eprintln!("Finding impl for {cls}");
for typ in mro.iter() {
if self.registry.contains_key(typ) {
mro_match = Some(typ.clone_ref(py));
Expand All @@ -86,13 +87,14 @@ impl SingleDispatchState {
&& !cls_mro.contains(m)
&& Builtins::cached(py)
.issubclass(py, m.wrapped().bind(py), typ.wrapped().bind(py))
.is_ok_and(|res| res)
.is_ok_and(|res| !res)
{
return Err(PyRuntimeError::new_err(format!(
"Ambiguous dispatch: {m} or {typ}"
)));
}
mro_match = Some(m.clone_ref(py));
eprintln!("MRO match: {m}");
break;
}
}
Expand All @@ -104,6 +106,7 @@ impl SingleDispatchState {
Some(f) => Ok(f),
None => {
let obj_type = PyTypeReference::new(Builtins::cached(py).object_type.clone_ref(py));
eprintln!("Found impl for {cls}: {obj_type}");
match self.registry.get(&obj_type) {
Some(it) => Ok(it.clone_ref(py)),
None => Err(PyRuntimeError::new_err(format!(
Expand All @@ -117,6 +120,7 @@ impl SingleDispatchState {
fn get_or_find_impl(&mut self, py: Python, cls: Bound<'_, PyAny>) -> PyResult<PyObject> {
let free_cls = cls.unbind();
let type_reference = PyTypeReference::new(free_cls.clone_ref(py));
eprintln!("Finding impl {type_reference}");

match self.cache.get(&type_reference) {
Some(handler) => Ok(handler.clone_ref(py)),
Expand All @@ -127,6 +131,7 @@ impl SingleDispatchState {
};
self.cache
.insert(type_reference, handler_for_cls.clone_ref(py));
eprintln!("Found new handler {handler_for_cls}");
Ok(handler_for_cls)
}
}
Expand Down Expand Up @@ -224,6 +229,7 @@ impl SingleDispatch {
args: &Bound<'_, PyTuple>,
kwargs: Option<&Bound<'_, PyDict>>,
) -> PyResult<Py<PyAny>> {
eprintln!("Calling");
match obj.getattr(intern!(py, "__class__")) {
Ok(cls) => {
let mut all_args = Vec::with_capacity(1 + args.len());
Expand All @@ -240,6 +246,7 @@ impl SingleDispatch {
}

fn dispatch(&self, py: Python<'_>, cls: Bound<'_, PyAny>) -> PyResult<PyObject> {
eprintln!("Dispatching");
match self.lock.lock() {
Ok(mut state) => {
if let Some(cache_token) = &state.cache_token {
Expand Down
185 changes: 180 additions & 5 deletions src/singledispatch/mro.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
use crate::singledispatch::builtins::Builtins;
use crate::singledispatch::typeref::PyTypeReference;
use crate::singledispatch::typing::TypingModule;
use pyo3::exceptions::PyRuntimeError;
use pyo3::prelude::*;
use pyo3::types::PyTuple;
use pyo3::types::{PyList, PyTuple};
use pyo3::{intern, Bound, PyObject, PyResult, Python};
use std::borrow::Borrow;
use std::cmp::Reverse;
use std::collections::hash_map::Keys;
use std::collections::HashSet;
Expand All @@ -18,22 +20,188 @@ pub(crate) fn get_obj_mro(cls: &Bound<'_, PyAny>) -> PyResult<HashSet<PyTypeRefe
Ok(mro)
}

/// Returns the entries of `cls.__bases__` wrapped as [`PyTypeReference`]s, in order.
///
/// An object without a `__bases__` attribute yields an empty vector.
///
/// # Errors
/// Propagates any error raised by the attribute lookup, and fails if
/// `__bases__` is present but is not a tuple.
fn get_obj_bases(cls: &Bound<'_, PyAny>) -> PyResult<Vec<PyTypeReference>> {
    let Some(raw_bases) = cls.getattr_opt(intern!(cls.py(), "__bases__"))? else {
        return Ok(Vec::new());
    };
    let bases_tuple = raw_bases.downcast::<PyTuple>()?;
    let bases = bases_tuple
        .iter()
        .map(|base| PyTypeReference::new(base.unbind()))
        .collect();
    Ok(bases)
}

/// Returns the result of `cls.__subclasses__()` as a set of [`PyTypeReference`]s.
///
/// # Errors
/// Propagates any error raised by calling `__subclasses__`, and fails if the
/// returned object is not a list.
fn get_obj_subclasses(cls: &Bound<'_, PyAny>) -> PyResult<HashSet<PyTypeReference>> {
    let raw_subclasses = cls.call_method0(intern!(cls.py(), "__subclasses__"))?;
    // `type.__subclasses__()` returns a list, so downcast to `PyList` only; a
    // preceding tuple downcast would reject every valid result.
    let subclass_list = raw_subclasses.downcast::<PyList>()?;
    Ok(subclass_list
        .iter()
        .map(|item| PyTypeReference::new(item.unbind()))
        .collect())
}

/// Picks the next element for a C3 merge step.
///
/// Scans the heads of `seqs` in order and returns a new reference to the first
/// head that does not appear in the tail (everything after the first element)
/// of any sequence. Returns `None` when every head is blocked by some tail.
///
/// Empty sequences have no head and an empty tail, so they are ignored rather
/// than causing an out-of-bounds panic.
fn find_merge_candidate(
    py: Python,
    seqs: &[&mut Vec<PyTypeReference>],
) -> Option<PyTypeReference> {
    seqs.iter()
        .filter_map(|seq| seq.first())
        .find(|&head| {
            // A head is only eligible if no sequence still needs it later on.
            !seqs
                .iter()
                .any(|other| other.iter().skip(1).any(|later| later == head))
        })
        .map(|head| head.clone_ref(py))
}

/// Merges `seqs` into a single linearization, consuming the sequences in place.
///
/// Each round drops exhausted sequences, asks [`find_merge_candidate`] for the
/// next element, removes that element from the front of every sequence that
/// starts with it, and appends it to the result. The loop ends once every
/// sequence is empty.
///
/// # Errors
/// Returns a `PyRuntimeError` ("Inconsistent hierarchy") when sequences remain
/// but no candidate can be chosen.
fn merge_mro(
    seqs: &mut Vec<&mut Vec<PyTypeReference>>,
    py: Python,
) -> PyResult<Vec<PyTypeReference>> {
    let mut linearization: Vec<PyTypeReference> = Vec::new();
    loop {
        seqs.retain(|seq| !seq.is_empty());
        if seqs.is_empty() {
            break;
        }
        let Some(next) = find_merge_candidate(py, seqs.as_slice()) else {
            return Err(PyRuntimeError::new_err("Inconsistent hierarchy"));
        };
        // Every remaining sequence is non-empty here, so indexing the head is safe.
        for seq in seqs.iter_mut() {
            if seq[0].eq(&next) {
                seq.remove(0);
            }
        }
        linearization.push(next);
    }
    Ok(linearization)
}

/// Returns the index one past the last entry of `bases` that has an
/// `__abstractmethods__` attribute, or `0` when no entry has one.
///
/// # Errors
/// Propagates any error raised by the attribute check.
fn c3_boundary(py: Python, bases: &[PyTypeReference]) -> PyResult<usize> {
    // Walk from the end so the first hit is the last matching base.
    for (offset_from_end, base) in bases.iter().rev().enumerate() {
        let has_abstract_methods = base
            .wrapped()
            .bind(py)
            .hasattr(intern!(py, "__abstractmethods__"))?;
        if has_abstract_methods {
            return Ok(bases.len() - offset_from_end);
        }
    }
    Ok(0)
}

/// Runs [`c3_mro`] on every item of `bases` and collects the resulting
/// linearizations in the same order.
///
/// Each call receives its own freshly cloned copy of `abcs`.
///
/// # Errors
/// Stops at, and returns, the first error produced by [`c3_mro`].
fn sub_c3_mro<I, G>(
    py: Python,
    bases: I,
    abcs: &Vec<&PyTypeReference>,
) -> PyResult<Vec<Vec<PyTypeReference>>>
where
    G: Borrow<PyTypeReference>,
    I: Iterator<Item = G>,
{
    bases
        .map(|base| {
            c3_mro(
                py,
                base.borrow().wrapped().bind(py),
                abcs.iter().map(|abc| abc.clone_ref(py)).collect(),
            )
        })
        .collect()
}

/// Computes a C3 linearization of `cls` that also places the ABCs in `abcs`.
///
/// The bases of `cls` are split at [`c3_boundary`] into "explicit" bases (up to
/// and including the last base with `__abstractmethods__`) and "other" bases.
/// An entry of `abcs` is inserted as an extra "abstract" base when `cls` is a
/// subclass of it but none of the direct bases is. The linearizations of all
/// three groups, followed by the three groups themselves, are then combined
/// with [`merge_mro`].
///
/// NOTE(review): this appears to follow CPython's `functools._c3_mro`; confirm
/// against that implementation when changing the ordering.
///
/// # Errors
/// Propagates errors from reading `__bases__`, from the `issubclass` checks
/// (instead of panicking), from the recursive calls, and the
/// "Inconsistent hierarchy" error from [`merge_mro`].
fn c3_mro(
    py: Python,
    cls: &Bound<'_, PyAny>,
    abcs: Vec<PyTypeReference>,
) -> PyResult<Vec<PyTypeReference>> {
    let bases = get_obj_bases(cls)?;
    if bases.is_empty() {
        // NOTE(review): an object without bases yields an empty result, so
        // `cls` itself is not part of it — confirm this is intended.
        return Ok(Vec::new());
    }

    let boundary = c3_boundary(py, &bases)?;
    let (explicit_bases, other_bases) = bases.split_at(boundary);

    // ABCs that `cls` satisfies without inheriting them through a direct base.
    let builtins = Builtins::cached(py);
    let mut abstract_bases: Vec<PyTypeReference> = Vec::new();
    for abc in &abcs {
        let abc_type = abc.wrapped().bind(py);
        if !builtins.issubclass(py, cls, abc_type)? {
            continue;
        }
        let mut inherited_via_base = false;
        for base in &bases {
            if builtins.issubclass(py, base.wrapped().bind(py), abc_type)? {
                inherited_via_base = true;
                break;
            }
        }
        if !inherited_via_base {
            abstract_bases.push(abc.clone_ref(py));
        }
    }

    // ABCs placed at this level must not be placed again further down.
    let remaining_abcs: Vec<_> = abcs
        .iter()
        .filter(|&abc| !abstract_bases.contains(abc))
        .collect();

    let clone_refs = |items: &[PyTypeReference]| -> Vec<PyTypeReference> {
        items.iter().map(|item| item.clone_ref(py)).collect()
    };

    // Merge order: [cls], the linearizations of each group, then the groups.
    let mut sequences: Vec<Vec<PyTypeReference>> =
        vec![vec![PyTypeReference::new(cls.clone().unbind())]];
    sequences.extend(sub_c3_mro(py, explicit_bases.iter(), &remaining_abcs)?);
    sequences.extend(sub_c3_mro(py, abstract_bases.iter(), &remaining_abcs)?);
    sequences.extend(sub_c3_mro(py, other_bases.iter(), &remaining_abcs)?);
    sequences.push(clone_refs(explicit_bases));
    sequences.push(clone_refs(abstract_bases.as_slice()));
    sequences.push(clone_refs(other_bases));

    let mut pending: Vec<&mut Vec<PyTypeReference>> = sequences.iter_mut().collect();
    merge_mro(&mut pending, py)
}

pub(crate) fn compose_mro(
Expand All @@ -45,7 +213,9 @@ pub(crate) fn compose_mro(
let typing = TypingModule::cached(py);

let bases: HashSet<_> = get_obj_mro(&cls)?;
eprintln!("bases = {bases:#?}");
let registered_types: HashSet<_> = types.collect();
eprintln!("registered_types = {registered_types:#?}");
let eligible_types: HashSet<_> = registered_types
.iter()
.filter(|&tref| {
Expand All @@ -68,6 +238,7 @@ pub(crate) fn compose_mro(
})
.copied()
.collect();
eprintln!("eligible_types = {eligible_types:#?}");
let mut mro: Vec<PyTypeReference> = Vec::new();
eligible_types.iter().for_each(|&tref| {
// Subclasses of the ABCs in *types* which are also implemented by
Expand All @@ -77,6 +248,7 @@ pub(crate) fn compose_mro(
.unwrap()
.iter()
.filter(|subclass| {
eprintln!("subclass = {subclass:#?}");
let typ = subclass.wrapped();
let tref = PyTypeReference::new(typ.clone_ref(py));
!bases.contains(&tref)
Expand Down Expand Up @@ -106,6 +278,9 @@ pub(crate) fn compose_mro(
});
}
});
eprintln!("Pre-mro candidates {mro:#?}");

c3_mro(py, cls, mro)
let final_rmo = c3_mro(py, &cls, mro);
eprintln!("MRO for {cls}: {final_rmo:#?}");
final_rmo
}
8 changes: 7 additions & 1 deletion src/singledispatch/typeref.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use pyo3::{PyObject, Python};
use std::fmt::{Display, Formatter};
use std::fmt::{Debug, Display, Formatter};
use std::hash::{Hash, Hasher};

pub struct PyTypeReference {
Expand All @@ -22,6 +22,12 @@ impl PyTypeReference {
}
}

impl Debug for PyTypeReference {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.wrapped, f)
}
}

impl Display for PyTypeReference {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
std::fmt::Display::fmt(&self.wrapped, f)
Expand Down
21 changes: 19 additions & 2 deletions tests/test_singledispatch_native.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from collections.abc import Sequence

import pytest
#from functools import singledispatch
from singledispatch_native import singledispatch

from typing import Any
Expand All @@ -14,17 +17,31 @@ def _some_fun_str(o: str) -> str:


@some_fun.register(int)
def _some_fun_int(o: int) -> str:
    # Handler registered for ``int``; named distinctly so it does not shadow
    # the ``str`` handler defined above.
    return "It's an int!"


@some_fun.register(Sequence)
def _some_fun_sequence(l: Sequence) -> str:
    """Handler registered for ``Sequence``: joins the items with ", "."""
    joined = ", ".join(l)
    return "Sequence: " + joined


@some_fun.register(tuple)
def _some_fun_tuple(l: tuple) -> str:
    """Handler registered for ``tuple``: joins the items with ", "."""
    joined = ", ".join(l)
    return "tuple: " + joined


@pytest.mark.parametrize(
"v,ret",
[
(None, "Got None <class 'NoneType'>"),
("val", "It's a string!"),
(1, "It's an int!"),
# (True, "It's an int!"),
(True, "It's an int!"),
([], "Sequence: "),
(["1"], "Sequence: 1"),
(["1", "2", "3"], "Sequence: 1, 2, 3"),
(("1", "2", "3"), "tuple: 1, 2, 3"),
]
)
def test_singledispatch(v, ret):
Expand Down