Skip to content

Commit e4ec46f

Browse files
evhubclaude
andcommitted
Make _coconut_partial a proper class with __get__ for method binding
Replace the old function-based _coconut_partial (which just wrapped functools.partial and set __name__) with a class-based implementation that subclasses functools.partial. The new class adds: - __get__ descriptor support so partials work correctly as methods when assigned as class attributes (resolves #891) - __repr__ that delegates to functools.partial's repr for consistency - Proper __name__ propagation via __new__ The class is defined in root.py (for both Python 2 and 3 headers) and is now also used internally by the compiler modules (compiler.py, grammar.py, header.py, util.py) instead of functools.partial. Also updates remaining uses of _coconut.functools.partial in the runtime header template to use _coconut_partial, and adds tests for partial-as-method with both positional and keyword argument binding. Co-Authored-By: Claude (fennec-v8-fast) <noreply@anthropic.com>
1 parent bca524e commit e4ec46f

File tree

7 files changed

+58
-15
lines changed

7 files changed

+58
-15
lines changed

coconut/compiler/compiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
import os
3535
import re
3636
from contextlib import contextmanager
37-
from functools import partial, update_wrapper
37+
from functools import update_wrapper
3838
from collections import defaultdict
3939
from threading import Lock
4040
from copy import copy
@@ -54,6 +54,7 @@
5454
__version__ as pyparsing_version,
5555
)
5656

57+
from coconut.root import _coconut_partial as partial
5758
from coconut.constants import (
5859
PY35,
5960
specific_targets,

coconut/compiler/grammar.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,7 @@
2929

3030
from collections import defaultdict
3131
from contextlib import contextmanager
32-
from functools import partial
33-
32+
from coconut.root import _coconut_partial as partial
3433
from coconut._pyparsing import (
3534
USE_LINE_BY_LINE,
3635
Forward,

coconut/compiler/header.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
from coconut.root import * # NOQA
2121

2222
import os.path
23-
from functools import partial
2423

25-
from coconut.root import _indent, _get_root_header
24+
from coconut.root import (
25+
_coconut_partial as partial,
26+
_indent,
27+
_get_root_header,
28+
)
2629
from coconut.exceptions import CoconutInternalException
2730
from coconut.terminal import internal_assert
2831
from coconut.constants import (

coconut/compiler/templates/header.py_template

Lines changed: 3 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -66,11 +66,6 @@ class _coconut{object}:{COMMENT.EVERYTHING_HERE_MUST_BE_COPIED_TO_STUB_FILE}
6666
fmappables = list, tuple, dict, set, frozenset, bytes, bytearray
6767
abc.Sequence.register(collections.deque)
6868
Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, min, max, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, super, tuple, type, vars, zip, repr, print{comma_bytearray} = Ellipsis, NotImplemented, NotImplementedError, Exception, AttributeError, ImportError, IndexError, KeyError, NameError, TypeError, ValueError, StopIteration, RuntimeError, all, any, bool, bytes, callable, chr, classmethod, complex, dict, enumerate, filter, float, frozenset, getattr, hasattr, hash, id, int, isinstance, issubclass, iter, len, list, locals, globals, map, {lstatic}min{rstatic}, {lstatic}max{rstatic}, next, object, ord, property, range, reversed, set, setattr, slice, str, sum, {lstatic}super{rstatic}, tuple, type, vars, zip, {lstatic}repr{rstatic}, {lstatic}print{rstatic}{comma_bytearray}
69-
@_coconut_wraps(_coconut.functools.partial)
70-
def _coconut_partial(_coconut_func, *args, **kwargs):
71-
partial_func = _coconut.functools.partial(_coconut_func, *args, **kwargs)
72-
partial_func.__name__ = _coconut.getattr(_coconut_func, "__name__", None)
73-
return partial_func
7469
def _coconut_handle_cls_kwargs(**kwargs):
7570
"""Some code taken from six under the terms of its MIT license."""
7671
metaclass = kwargs.pop("metaclass", None)
@@ -699,7 +694,7 @@ class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_bec
699694
for i in _coconut.reversed(_coconut.range(0 if self.levels is None else self.levels + 1)):
700695
mapper = {_coconut_}reiterable
701696
for _ in _coconut.range(i):
702-
mapper = _coconut.functools.partial({_coconut_}map, mapper)
697+
mapper = _coconut_partial({_coconut_}map, mapper)
703698
self.iter = mapper(self.iter)
704699
self._made_reit = True
705700
return self.iter
@@ -725,7 +720,7 @@ class flatten(_coconut_has_iter):{COMMENT.cant_implement_len_else_list_calls_bec
725720
for i in _coconut.reversed(_coconut.range(self.levels + 1)):
726721
reverser = {_coconut_}reversed
727722
for _ in _coconut.range(i):
728-
reverser = _coconut.functools.partial({_coconut_}map, reverser)
723+
reverser = _coconut_partial({_coconut_}map, reverser)
729724
reversed_iter = reverser(reversed_iter)
730725
return self.__class__(reversed_iter, self.levels)
731726
def __repr__(self):
@@ -2028,7 +2023,7 @@ def mapreduce(key_value_func, iterable, **kwargs):
20282023
def _coconut_parallel_mapreduce(mapreduce_func, map_cls, *args, **kwargs):
20292024
if "map_using" in kwargs:
20302025
raise _coconut.TypeError("redundant map_using argument to process/thread mapreduce/collectby")
2031-
kwargs["map_using"] = _coconut.functools.partial(map_cls, stream=True, ordered=kwargs.pop("ordered", False), chunksize=kwargs.pop("chunksize", 1))
2026+
kwargs["map_using"] = _coconut_partial(map_cls, stream=True, ordered=kwargs.pop("ordered", False), chunksize=kwargs.pop("chunksize", 1))
20322027
with map_cls.multiple_sequential_calls(max_workers=kwargs.pop("max_workers", None)):
20332028
return mapreduce_func(*args, **kwargs)
20342029
mapreduce.using_processes = _coconut_partial(_coconut_parallel_mapreduce, mapreduce, process_map)

coconut/compiler/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import itertools
3737
import weakref
3838
import datetime as dt
39-
from functools import partial, reduce
39+
from functools import reduce
4040
from collections import defaultdict
4141
from contextlib import contextmanager
4242
from pprint import pformat, pprint
@@ -75,6 +75,7 @@
7575
all_parse_elements,
7676
)
7777

78+
from coconut.root import _coconut_partial as partial
7879
from coconut.integrations import embed
7980
from coconut.util import (
8081
pickle,

coconut/root.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
VERSION = "3.2.0"
2727
VERSION_NAME = None
2828
# False for release, int >= 1 for develop
29-
DEVELOP = 7
29+
DEVELOP = 8
3030
ALPHA = False # for pre releases rather than post releases
3131

3232
assert DEVELOP is False or DEVELOP >= 1, "DEVELOP must be False or an int >= 1"
@@ -77,6 +77,25 @@ def wrap(new_func):
7777
py_bytes, py_chr, py_dict, py_hex, py_input, py_int, py_map, py_object, py_oct, py_open, py_print, py_range, py_str, py_super, py_zip, py_filter, py_reversed, py_enumerate, py_repr, py_min, py_max = bytes, chr, dict, hex, input, int, map, object, oct, open, print, range, str, super, zip, filter, reversed, enumerate, repr, min, max
7878
_coconut_py_str, _coconut_py_super, _coconut_py_dict, _coconut_py_min, _coconut_py_max = str, super, dict, min, max
7979
exec("_coconut_exec = exec")
80+
class _coconut_partial(_coconut_functools.partial):
81+
__slots__ = ()
82+
def __new__(cls, func, *args, **kwargs):
83+
self = _coconut_functools.partial.__new__(cls, func, *args, **kwargs)
84+
self.__name__ = _coconut.getattr(func, "__name__", None)
85+
return self
86+
def __get__(self, obj, objtype=None):
87+
if obj is None:
88+
return self
89+
return _coconut.types.MethodType(self, obj)
90+
def __repr__(self):
91+
cls = type(self)
92+
old_qualname, cls.__qualname__ = cls.__qualname__, _coconut_functools.partial.__qualname__
93+
old_module, cls.__module__ = cls.__module__, _coconut_functools.partial.__module__
94+
try:
95+
return _coconut_functools.partial.__repr__(self)
96+
finally:
97+
cls.__qualname__ = old_qualname
98+
cls.__module__ = old_module
8099
'''
81100

82101
# if a new assignment is added below, a new builtins import should be added alongside it
@@ -87,6 +106,25 @@ def wrap(new_func):
87106
from future_builtins import *
88107
chr, str = unichr, unicode
89108
from io import open
109+
class _coconut_partial(_coconut_functools.partial):
110+
__slots__ = ()
111+
def __new__(cls, func, *args, **kwargs):
112+
self = _coconut_functools.partial.__new__(cls, func, *args, **kwargs)
113+
self.__name__ = _coconut.getattr(func, "__name__", None)
114+
return self
115+
def __get__(self, obj, objtype=None):
116+
if obj is None:
117+
return self
118+
return _coconut.types.MethodType(self, obj, objtype)
119+
def __repr__(self):
120+
cls = type(self)
121+
old_name, cls.__name__ = cls.__name__, _coconut_functools.partial.__name__
122+
old_module, cls.__module__ = cls.__module__, _coconut_functools.partial.__module__
123+
try:
124+
return _coconut_functools.partial.__repr__(self)
125+
finally:
126+
cls.__name__ = old_name
127+
cls.__module__ = old_module
90128
class object(object):
91129
__slots__ = ()
92130
def __ne__(self, other):
@@ -487,4 +525,7 @@ def _get_root_header(version="universal"):
487525
import os as _os
488526
_coconut.os = _os
489527

528+
import types as _types
529+
_coconut.types = _types
530+
490531
exec(_get_root_header())

coconut/tests/src/cocotest/agnostic/primary_2.coco

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,8 +367,11 @@ def primary_test_2() -> bool:
367367
class HasPartial:
368368
def f(self, x) = (self, x)
369369
g = f$(?, 1)
370+
def h(self, x, y) = (self, x, y)
371+
i = h$(y=42)
370372
has_partial = HasPartial()
371373
assert has_partial.g() == (has_partial, 1)
374+
assert has_partial.i(10) == (has_partial, 10, 42)
372375
xs = zip([1, 2], [3, 4])
373376
py_xs = py_zip([1, 2], [3, 4])
374377
assert list(xs) == [(1, 3), (2, 4)] == list(xs)

0 commit comments

Comments
 (0)