Skip to content

Commit 8e76a8a

Browse files
committed
Add full support for instance methods
1 parent 2af9879 commit 8e76a8a

File tree

2 files changed

+42
-19
lines changed

2 files changed

+42
-19
lines changed

src/variants/_variants.py

Lines changed: 40 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# -*- coding: utf-8 -*-
22
"""Provides the variant form decorator."""
33

4-
import types
4+
import functools
55

66
__all__ = ['variants']
77

@@ -30,26 +30,60 @@ def myfunc(url):
3030
class VariantFunction:
3131
__doc__ = f.__doc__
3232

33+
def __init__(self):
34+
self._variants = set()
35+
3336
def __call__(self, *args, **kwargs):
3437
return f(*args, **kwargs)
3538

39+
def _add_variant(self, var_name, vfunc):
40+
self._variants.add(var_name)
41+
setattr(self, var_name, vfunc)
42+
3643
def variant(self, func_name):
3744
"""Decorator to add a new variant form to the function."""
3845
def decorator(vfunc):
39-
setattr(self.__class__, func_name, staticmethod(vfunc))
46+
self._add_variant(func_name, vfunc)
47+
4048
return self
4149

4250
return decorator
4351

4452
def __get__(self, instance, owner):
4553
# This is necessary to bind instance methods
46-
if instance is None:
47-
return self
54+
if instance is not None:
55+
return VariantMethod(self, instance)
4856

49-
return types.MethodType(self, instance)
57+
return self
5058

5159
def __repr__(self):
52-
return '<VariantFunction {}>'.format(self.__name__)
60+
return '<{} {}>'.format(self.__class__.__name__, self.__name__)
61+
62+
class VariantMethod(VariantFunction):
63+
def __init__(self, variant_func, instance):
64+
self.__instance = instance
65+
self.__name__ = variant_func.__name__
66+
67+
# Convert existing variants to methods
68+
for vname in variant_func._variants:
69+
vfunc = getattr(variant_func, vname)
70+
vmethod = self._as_bound_method(vfunc)
71+
72+
setattr(self, vname, vmethod)
73+
74+
def __call__(self, *args, **kwargs):
75+
return f(self.__instance, *args, **kwargs)
76+
77+
def _as_bound_method(self, vfunc):
78+
@functools.wraps(vfunc)
79+
def bound_method(*args, **kwargs):
80+
return vfunc(self.__instance, *args, **kwargs)
81+
82+
return bound_method
83+
84+
def _add_variant(self, var_name, vfunc):
85+
self._variants.add(var_name)
86+
setattr(self, var_name, self._as_bound_method(vfunc))
5387

5488
f_out = VariantFunction()
5589
f_out.__name__ = f.__name__

tests/test_instance_methods.py

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -48,41 +48,30 @@ def test_divide(x, y, expected):
4848
assert dv.divide(y) == expected
4949

5050

51-
52-
@pytest.mark.xfail
5351
@pytest.mark.parametrize('x,y,expected', DivisionData.ROUND_VALS)
5452
def test_round(x, y, expected):
5553
dv = DivisionVariants(x)
5654
assert dv.divide.round(y) == expected
5755

5856

59-
60-
61-
@pytest.mark.xfail
6257
@pytest.mark.parametrize('x,y,expected', DivisionData.FLOOR_VALS)
6358
def test_floor(x, y, expected):
6459
dv = DivisionVariants(x)
6560
assert dv.divide.floor(y) == expected
6661

6762

68-
69-
70-
@pytest.mark.xfail
7163
@pytest.mark.parametrize('x,y,expected', DivisionData.CEIL_VALS)
72-
def test_floor(x, y, expected):
64+
def test_ceil(x, y, expected):
7365
dv = DivisionVariants(x)
74-
assert dv.divide.floor(y) == expected
75-
66+
assert dv.divide.ceil(y) == expected
7667

7768

78-
@pytest.mark.xfail
7969
@pytest.mark.parametrize('x,y,expected,mode', DivisionData.MODE_VALS)
8070
def test_mode(x, y, expected, mode):
8171
dv = DivisionVariants(x)
8272
assert dv.divide.mode(y, mode) == expected
8373

8474

85-
@pytest.mark.xfail
8675
@pytest.mark.parametrize('x,y,expected,mode', DivisionData.MODE_VALS)
8776
def test_mode_change_x(x, y, expected, mode):
8877
# Test that with mutable values it still works after x is changed

0 commit comments

Comments
 (0)