diff --git a/src/singledispatch/core.rs b/src/singledispatch/core.rs index d393098..3f4a8ae 100644 --- a/src/singledispatch/core.rs +++ b/src/singledispatch/core.rs @@ -74,6 +74,7 @@ impl SingleDispatchState { let cls_mro = get_obj_mro(&cls.clone())?; let mro = compose_mro(py, cls.clone(), self.registry.keys())?; let mut mro_match: Option = None; + eprintln!("Finding impl for {cls}"); for typ in mro.iter() { if self.registry.contains_key(typ) { mro_match = Some(typ.clone_ref(py)); @@ -86,13 +87,14 @@ impl SingleDispatchState { && !cls_mro.contains(m) && Builtins::cached(py) .issubclass(py, m.wrapped().bind(py), typ.wrapped().bind(py)) - .is_ok_and(|res| res) + .is_ok_and(|res| !res) { return Err(PyRuntimeError::new_err(format!( "Ambiguous dispatch: {m} or {typ}" ))); } mro_match = Some(m.clone_ref(py)); + eprintln!("MRO match: {m}"); break; } } @@ -104,6 +106,7 @@ impl SingleDispatchState { Some(f) => Ok(f), None => { let obj_type = PyTypeReference::new(Builtins::cached(py).object_type.clone_ref(py)); + eprintln!("Found impl for {cls}: {obj_type}"); match self.registry.get(&obj_type) { Some(it) => Ok(it.clone_ref(py)), None => Err(PyRuntimeError::new_err(format!( @@ -117,6 +120,7 @@ impl SingleDispatchState { fn get_or_find_impl(&mut self, py: Python, cls: Bound<'_, PyAny>) -> PyResult { let free_cls = cls.unbind(); let type_reference = PyTypeReference::new(free_cls.clone_ref(py)); + eprintln!("Finding impl {type_reference}"); match self.cache.get(&type_reference) { Some(handler) => Ok(handler.clone_ref(py)), @@ -127,6 +131,7 @@ impl SingleDispatchState { }; self.cache .insert(type_reference, handler_for_cls.clone_ref(py)); + eprintln!("Found new handler {handler_for_cls}"); Ok(handler_for_cls) } } @@ -224,6 +229,7 @@ impl SingleDispatch { args: &Bound<'_, PyTuple>, kwargs: Option<&Bound<'_, PyDict>>, ) -> PyResult> { + eprintln!("Calling"); match obj.getattr(intern!(py, "__class__")) { Ok(cls) => { let mut all_args = Vec::with_capacity(1 + args.len()); @@ -240,6 +246,7 @@ impl SingleDispatch { } fn dispatch(&self, py: Python<'_>, cls: Bound<'_, PyAny>) -> PyResult { + eprintln!("Dispatching"); match self.lock.lock() { Ok(mut state) => { if let Some(cache_token) = &state.cache_token { diff --git a/src/singledispatch/mro.rs b/src/singledispatch/mro.rs index 6085fcb..ee7c48e 100644 --- a/src/singledispatch/mro.rs +++ b/src/singledispatch/mro.rs @@ -1,9 +1,11 @@ use crate::singledispatch::builtins::Builtins; use crate::singledispatch::typeref::PyTypeReference; use crate::singledispatch::typing::TypingModule; +use pyo3::exceptions::PyRuntimeError; use pyo3::prelude::*; -use pyo3::types::PyTuple; +use pyo3::types::{PyList, PyTuple}; use pyo3::{intern, Bound, PyObject, PyResult, Python}; +use std::borrow::Borrow; use std::cmp::Reverse; use std::collections::hash_map::Keys; use std::collections::HashSet; @@ -18,22 +20,188 @@ pub(crate) fn get_obj_mro(cls: &Bound<'_, PyAny>) -> PyResult) -> PyResult> { + match cls.getattr_opt(intern!(cls.py(), "__bases__"))? { + Some(b) => Ok(b + .downcast::()? + .iter() + .map(|item| PyTypeReference::new(item.unbind())) + .collect()), + None => Ok(Vec::new()), + } +} + fn get_obj_subclasses(cls: &Bound<'_, PyAny>) -> PyResult> { let subclasses: HashSet<_> = cls .call_method0(intern!(cls.py(), "__subclasses__"))? - .downcast::()? + .downcast::()? .iter() .map(|item| PyTypeReference::new(item.unbind())) .collect(); Ok(subclasses) } +fn find_merge_candidate(py: Python, seqs: &[&mut Vec]) -> Option { + let mut candidate: Option<&PyTypeReference> = None; + for i1 in 0..seqs.len() { + let s1 = &seqs[i1]; + candidate = Some(&s1[0]); + for i2 in 0..seqs.len() { + let s2 = &seqs[i2]; + if s2[1..].contains(candidate.unwrap()) { + candidate = None; + break; + } + } + if candidate.is_some() { + break; + } + } + candidate.map(|c| c.clone_ref(py)) +} + +fn merge_mro( + seqs: &mut Vec<&mut Vec>, + py: Python, +) -> PyResult> { + let mut result: Vec = Vec::new(); + loop { + //let seqs = seqs; + seqs.retain(|seq| !seq.is_empty()); + if seqs.is_empty() { + return Ok(result); + } + match find_merge_candidate(py, seqs.as_slice()) { + Some(c) => { + for i in 0..seqs.len() { + let seq = &mut seqs[i]; + if seq[0].eq(&c) { + seq.remove(0); + } + } + result.push(c); + } + None => return Err(PyRuntimeError::new_err("Inconsistent hierarchy")), + } + } +} + +fn c3_boundary(py: Python, bases: &[PyTypeReference]) -> PyResult { + let mut boundary = 0; + + for (i, base) in bases.iter().rev().enumerate() { + if base + .wrapped() + .bind(py) + .hasattr(intern!(py, "__abstractmethods__"))? + { + boundary = bases.len() - i; + break; + } + } + + Ok(boundary) +} + +fn sub_c3_mro( + py: Python, + bases: I, + abcs: &Vec<&PyTypeReference>, +) -> PyResult>> +where + G: Borrow, + I: Iterator, +{ + let mut v: Vec> = Vec::new(); + for b in bases { + v.push(c3_mro( + py, + b.borrow().wrapped().bind(py), + abcs.iter().map(|abc| abc.clone_ref(py)).collect(), + )?); + } + Ok(v) +} + fn c3_mro( py: Python, - cls: Bound<'_, PyAny>, + cls: &Bound<'_, PyAny>, abcs: Vec, ) -> PyResult> { - Ok(abcs) + eprintln!("cls = {cls:#?}"); + eprintln!("abcs = {abcs:#?}"); + let bases = match get_obj_bases(cls) { + Ok(b) => { + if !b.is_empty() { + b + } else { + return Ok(Vec::new()); + } + } + Err(e) => return Err(e), + }; + eprintln!("bases = {bases:#?}"); + let boundary = c3_boundary(py, &bases)?; + eprintln!("boundary = {boundary}"); + + let (explicit_bases, other_bases) = bases.split_at(boundary); + let abstract_bases: Vec<_> = abcs + .iter() + .flat_map(|abc| { + if Builtins::cached(py) + .issubclass(py, cls, abc.wrapped().bind(py)) + .unwrap() + && !bases.iter().any(|b| { + Builtins::cached(py) + .issubclass(py, b.wrapped().bind(py), abc.wrapped().bind(py)) + .unwrap() + }) + { + vec![abc.clone_ref(py)] + } else { + vec![] + } + }) + .collect(); + eprintln!("explict_bases = {explicit_bases:#?}"); + eprintln!("other_bases = {other_bases:#?}"); + eprintln!("abstract_bases = {abstract_bases:#?}"); + + let new_abcs: Vec<_> = abcs + .iter() + .filter(|&c| !abstract_bases.contains(c)) + .collect(); + eprintln!("new_abcs = {new_abcs:#?}"); + + let mut mros: Vec<&mut Vec> = Vec::new(); + + let mut cls_ref = vec![PyTypeReference::new(cls.clone().unbind())]; + mros.push(&mut cls_ref); + + let mut explicit_bases_mro = sub_c3_mro(py, explicit_bases.iter(), &new_abcs)?; + mros.extend(&mut explicit_bases_mro); + + let mut abstract_bases_mro = sub_c3_mro( + py, + abstract_bases.iter().map(|v| v.clone_ref(py)), + &new_abcs, + )?; + eprintln!("abstract_bases_mro = {abstract_bases_mro:#?}"); + mros.extend(&mut abstract_bases_mro); + + let mut other_bases_mro = sub_c3_mro(py, other_bases.iter(), &new_abcs)?; + mros.extend(&mut other_bases_mro); + + let mut explicit_bases_cloned = Vec::from_iter(explicit_bases.iter().map(|b| b.clone_ref(py))); + mros.push(&mut explicit_bases_cloned); + + let mut abstract_bases_cloned = Vec::from_iter(abstract_bases.iter().map(|b| b.clone_ref(py))); + mros.push(&mut abstract_bases_cloned); + + let mut other_bases_cloned = Vec::from_iter(other_bases.iter().map(|b| b.clone_ref(py))); + mros.push(&mut other_bases_cloned); + + merge_mro(&mut mros, py) } pub(crate) fn compose_mro( @@ -45,7 +213,9 @@ pub(crate) fn compose_mro( let typing = TypingModule::cached(py); let bases: HashSet<_> = get_obj_mro(&cls)?; + eprintln!("bases = {bases:#?}"); let registered_types: HashSet<_> = types.collect(); + eprintln!("registered_types = {registered_types:#?}"); let eligible_types: HashSet<_> = registered_types .iter() .filter(|&tref| { @@ -68,6 +238,7 @@ pub(crate) fn compose_mro( }) .copied() .collect(); + eprintln!("eligible_types = {eligible_types:#?}"); let mut mro: Vec = Vec::new(); eligible_types.iter().for_each(|&tref| { // Subclasses of the ABCs in *types* which are also implemented by @@ -77,6 +248,7 @@ pub(crate) fn compose_mro( .unwrap() .iter() .filter(|subclass| { + eprintln!("subclass = {subclass:#?}"); let typ = subclass.wrapped(); let tref = PyTypeReference::new(typ.clone_ref(py)); !bases.contains(&tref) @@ -106,6 +278,9 @@ pub(crate) fn compose_mro( }); } }); + eprintln!("Pre-mro candidates {mro:#?}"); - c3_mro(py, cls, mro) + let final_rmo = c3_mro(py, &cls, mro); + eprintln!("MRO for {cls}: {final_rmo:#?}"); + final_rmo } diff --git a/src/singledispatch/typeref.rs b/src/singledispatch/typeref.rs index dbae827..ed5eced 100644 --- a/src/singledispatch/typeref.rs +++ b/src/singledispatch/typeref.rs @@ -1,5 +1,5 @@ use pyo3::{PyObject, Python}; -use std::fmt::{Display, Formatter}; +use std::fmt::{Debug, Display, Formatter}; use std::hash::{Hash, Hasher}; pub struct PyTypeReference { @@ -22,6 +22,12 @@ impl PyTypeReference { } } +impl Debug for PyTypeReference { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + std::fmt::Display::fmt(&self.wrapped, f) + } +} + impl Display for PyTypeReference { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { std::fmt::Display::fmt(&self.wrapped, f) diff --git a/tests/test_singledispatch_native.py b/tests/test_singledispatch_native.py index 016c0a6..9356e4a 100644 --- a/tests/test_singledispatch_native.py +++ b/tests/test_singledispatch_native.py @@ -1,4 +1,7 @@ +from collections.abc import Sequence + import pytest +#from functools import singledispatch from singledispatch_native import singledispatch from typing import Any @@ -14,17 +17,31 @@ def _some_fun_str(o: str) -> str: @some_fun.register(int) -def _some_fun_str(o: int) -> str: +def _some_fun_int(o: int) -> str: return "It's an int!" +@some_fun.register(Sequence) +def _some_fun_sequence(l: Sequence) -> str: + return "Sequence: " + ", ".join(l) + + +@some_fun.register(tuple) +def _some_fun_tuple(l: tuple) -> str: + return "tuple: " + ", ".join(l) + + @pytest.mark.parametrize( "v,ret", [ (None, "Got None "), ("val", "It's a string!"), (1, "It's an int!"), - # (True, "It's an int!"), + (True, "It's an int!"), + ([], "Sequence: "), + (["1"], "Sequence: 1"), + (["1", "2", "3"], "Sequence: 1, 2, 3"), + (("1", "2", "3"), "tuple: 1, 2, 3"), ] ) def test_singledispatch(v, ret):