Skip to content
This repository was archived by the owner on Dec 10, 2018. It is now read-only.

Commit 1a3b1bd

Browse files
author
misakwa
committed
Rewrite class during creation with supported slots
During the first creation of the loaded class with slots, a new class is inserted into the inheritance chain to ensure that the slot fields are defined. This is done because of the need to create an empty struct during the parsing phase and fill it in later - slots require the fields to be known before hand. All checks on the new replacement class will have it looking like the original except an equality comparison between the replaced class and its replacement. >>> import thriftpy >>> ab = thriftpy.load('addressbook.thrift', use_slots=True) >>> ab_inst = ab.AddressBook() >>> ab_inst.__class__ == ab.AddressBook # will return False >>> # all other checks should work as expected >>> isinstance(ab_inst, ab.AddressBook) # will return True >>> issubclass(ab_inst.__class__, ab.AddressBook) # will return True In order to get pickling to work as expected, a new extension type is registered with copyreg (copy_reg for py2x) to avoid pickling errors.
1 parent 9aa7a61 commit 1a3b1bd

File tree

3 files changed

+66
-94
lines changed

3 files changed

+66
-94
lines changed

tests/test_hook.py

Lines changed: 38 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -82,38 +82,54 @@ def test_load_slots():
8282
)
8383

8484
# normal structs will have slots
85-
# assert set(thrift.PhoneNumber.__slots__) == set(['type', 'number', 'mix_item'])
86-
# assert set(thrift.Person.__slots__) == set(['name', 'phones', 'created_at'])
87-
# assert set(thrift.AddressBook.__slots__) == set(['people'])
88-
# assert set(thrift.AddressBookService.__slots__) == set()
85+
assert thrift.PhoneNumber.__slots__ == ['type', 'number', 'mix_item']
86+
assert thrift.Person.__slots__ == ['name', 'phones', 'created_at']
87+
assert thrift.AddressBook.__slots__ == ['people']
8988

90-
# one cannot get/set undefined attributes
91-
# person = thrift.Person()
92-
# with pytest.raises(AttributeError):
93-
# person.attr_not_exist = "Does not work"
89+
# get/set undefined attributes
90+
person = thrift.Person()
91+
with pytest.raises(AttributeError):
92+
person.attr_not_exist = "Does not work"
9493

95-
# with pytest.raises(AttributeError):
96-
# person.attr_not_exist
94+
with pytest.raises(AttributeError):
95+
person.attr_not_exist
9796

98-
# pn = thrift.PhoneNumber()
99-
# with pytest.raises(AttributeError):
100-
# pn.attr_not_exist = "Does not work"
97+
pn = thrift.PhoneNumber()
98+
with pytest.raises(AttributeError):
99+
pn.attr_not_exist = "Does not work"
101100

102-
# with pytest.raises(AttributeError):
103-
# pn.attr_not_exist
101+
with pytest.raises(AttributeError):
102+
pn.attr_not_exist
104103

105-
# ab = thrift.AddressBook()
106-
# with pytest.raises(AttributeError):
107-
# ab.attr_not_exist = "Does not work"
104+
ab = thrift.AddressBook()
105+
with pytest.raises(AttributeError):
106+
ab.attr_not_exist = "Does not work"
108107

109-
# with pytest.raises(AttributeError):
110-
# ab.attr_not_exist
108+
with pytest.raises(AttributeError):
109+
ab.attr_not_exist
110+
# eo: get/set
111111

112112
# exceptions will not have slots
113-
# assert not hasattr(thrift.PersonNotExistsError, '__slots__')
113+
assert not hasattr(thrift.PersonNotExistsError, '__slots__')
114114

115115
# enums will not have slots
116-
# assert not hasattr(thrift.PhoneType, '__slots__')
116+
assert not hasattr(thrift.PhoneType, '__slots__')
117+
118+
# service itself will not be created with slots
119+
assert not hasattr(thrift.AddressBookService, '__slots__')
120+
121+
# service args will have slots
122+
args_slots = thrift.AddressBookService.get_phonenumbers_args.__slots__
123+
assert args_slots == ['name', 'count']
124+
125+
# XXX: service result will have their slot list empty until after they are
126+
# created. This is because the success field is inserted after calling
127+
# _make_struct. We could hardcode a check in the metaclass for names ending
128+
# with _result, but I'm unsure if its a good idea. This should usually not
129+
# be an issue since this object is only used internally, but we can revisit
130+
# the need to have it available when required.
131+
result_obj = thrift.AddressBookService.get_phonenumbers_result()
132+
assert result_obj.__slots__ == ['success']
117133

118134
# should be able to pickle slotted objects - if load with module_name
119135
bob = thrift.Person(name="Bob")

thriftpy/parser/parser.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -767,6 +767,8 @@ def _make_enum(name, kvs):
767767

768768
def _make_empty_struct(name, ttype=TType.STRUCT, base_cls=TPayload):
769769
attrs = {'__module__': thrift_stack[-1].__name__, '_ttype': ttype}
770+
if issubclass(base_cls, TSPayload):
771+
attrs['__slots__'] = []
770772
return type(name, (base_cls, ), attrs)
771773

772774

@@ -787,10 +789,9 @@ def _fill_in_struct(cls, fields, _gen_init=True):
787789
setattr(cls, 'thrift_spec', thrift_spec)
788790
setattr(cls, 'default_spec', default_spec)
789791
setattr(cls, '_tspec', _tspec)
790-
# add __slots__ so we can check way before the class is created, even though
791-
# it really does nothing here, the real work is done during instantiation
792+
# add __slots__ for easy introspection
792793
if issubclass(cls, TSPayload):
793-
cls.__slots__ = tuple(field for field, _ in default_spec)
794+
cls.__slots__ = [field for field, _ in default_spec]
794795
if _gen_init:
795796
gen_init(cls, thrift_spec, default_spec)
796797
return cls
@@ -808,8 +809,7 @@ def _make_service(name, funcs, extends, use_slots=False):
808809

809810
attrs = {'__module__': thrift_stack[-1].__name__}
810811
base_cls = TSPayload if use_slots else TPayload
811-
if use_slots:
812-
attrs['__slots__'] = tuple()
812+
# service class itself will not be created with slots
813813
cls = type(name, (extends, ), attrs)
814814
thrift_services = []
815815

thriftpy/thrift.py

Lines changed: 23 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,13 @@
99

1010
from __future__ import absolute_import
1111

12+
try:
13+
import copy_reg as copyreg
14+
except ImportError:
15+
import copyreg
16+
1217
import functools
13-
import itertools
1418
import linecache
15-
import operator
1619
import types
1720

1821
from ._compat import with_metaclass
@@ -136,57 +139,27 @@ def __new__(cls, name, bases, attrs):
136139
attrs["__init__"] = init_func_generator(cls, spec)
137140
return super(TPayloadMeta, cls).__new__(cls, name, bases, attrs)
138141

139-
# checks: instance and subclass checks
140-
def __instancecheck__(cls, inst):
141-
mod = inst.__class__.__module__
142-
name = inst.__class__.__name__
143-
key = "%s:%s" % (mod, name)
144-
repl_cls = TPayloadMeta._class_cache.get(key)
145-
if repl_cls is None:
146-
return type.__instancecheck__(cls, inst)
147-
return inst.__class__ is repl_cls
148-
149-
def __subclasscheck__(cls, subcls):
150-
mod = cls.__module__
151-
name = cls.__name__
152-
key = "%s:%s" % (mod, name)
153-
repl_cls = TPayloadMeta._class_cache.get(key)
154-
if repl_cls is None:
155-
return type.__subclasscheck__(cls, subcls)
156-
if cls == repl_cls:
157-
return True
158-
# the first class in __mro__ is replaced by the replacement class so we
159-
# can only look up from the second position
160-
return cls.__mro__[1] in repl_cls.__mro__
161-
# eo: checks
162-
163142
def __call__(cls, *args, **kw):
164-
if not cls.__mro__[1] == TSPayload:
165-
return super(TPayloadMeta, cls).__call__(cls, *args, **kw)
166-
# XXX: replaces class with new class using slot list from default_spec
143+
# if issubclass(cls, TSPayload):
144+
if not issubclass(cls, TSPayload):
145+
return type.__call__(cls, *args, **kw)
167146
cls_name = cls.__name__.split('.')[-1]
168147
cache_key = '%s:%s' % (cls.__module__, cls_name)
169-
# XXX: we trust default_spec more than __slots__ for this
170-
fields = tuple(field for field, _ in cls.default_spec)
171-
cls_obj = TPayloadMeta._class_cache.get(cache_key)
172-
if not cls_obj:
173-
cls_obj = type(
148+
kls = TPayloadMeta._class_cache.get(cache_key)
149+
if not kls:
150+
fields = [field for field, _ in cls.default_spec]
151+
kls = type(
174152
cls_name,
175-
cls.__mro__,
153+
(cls,),
176154
{
177155
'__slots__': fields,
178156
'__module__': cls.__module__,
179157
}
180158
)
181-
# XXX: need a better way to do this; its a dupe from parser.py
182-
cls_obj._ttype = cls._ttype
183-
cls_obj._tspec = cls._tspec
184-
cls_obj.default_spec = cls.default_spec
185-
cls_obj.thrift_spec = cls.thrift_spec
186-
# cls.__init__ is already bound to cls
187-
cls_obj.__init__ = init_func_generator(cls_obj, cls.default_spec)
188-
TPayloadMeta._class_cache[cache_key] = cls_obj
189-
return type.__call__(cls_obj, *args, **kw)
159+
TPayloadMeta._class_cache[cache_key] = kls
160+
fn = lambda obj: (cls, tuple(getattr(obj, f) for f in fields))
161+
copyreg.pickle(kls, fn)
162+
return type.__call__(kls, *args, **kw)
190163

191164

192165
def gen_init(cls, thrift_spec=None, default_spec=None):
@@ -236,11 +209,8 @@ def write(self, oprot):
236209
oprot.write_struct(self)
237210

238211
def __repr__(self):
239-
keys = itertools.chain.from_iterable(
240-
getattr(cls, '__slots__', tuple()) for cls in type(self).__mro__
241-
)
242-
keys = list(keys)
243-
values = operator.attrgetter(*keys)(self)
212+
keys = self.__slots__
213+
values = [getattr(self, k) for k in keys]
244214
l = ['%s=%r' % (key, value) for key, value in zip(keys, values)]
245215
return '%s(%s)' % (self.__class__.__name__, ', '.join(l))
246216

@@ -250,28 +220,14 @@ def __str__(self):
250220
def __eq__(self, other):
251221
if not isinstance(other, self.__class__):
252222
return False
253-
keys = itertools.chain.from_iterable(
254-
getattr(cls, '__slots__', tuple()) for cls in type(self).__mro__
255-
)
256-
keys = list(keys)
257-
getter = operator.attrgetter(*keys)
258-
return getter(self) == getter(other)
223+
keys = self.__slots__
224+
vals1 = [getattr(self, k) for k in keys]
225+
vals2 = [getattr(self, k) for k in keys]
226+
return vals1 == vals2
259227

260228
def __ne__(self, other):
261229
return not self.__eq__(other)
262230

263-
def __getstate__(self):
264-
keys = itertools.chain.from_iterable(
265-
getattr(cls, '__slots__', tuple()) for cls in type(self).__mro__
266-
)
267-
keys = list(keys)
268-
values = operator.attrgetter(*keys)(self)
269-
return tuple(zip(keys, values))
270-
271-
def __setstate__(self, state):
272-
for k, v in state:
273-
setattr(self, k, v)
274-
275231

276232
class TClient(object):
277233

0 commit comments

Comments
 (0)