diff --git a/tests/test_hook.py b/tests/test_hook.py index fe66341..c0d1fc8 100644 --- a/tests/test_hook.py +++ b/tests/test_hook.py @@ -72,3 +72,69 @@ def test_tpayload_pickle(): person_2 = pickle.loads(PICKLED_BYTES) assert person == person_2 + + +def test_load_slots(): + thrift = thriftpy.load( + 'addressbook.thrift', + use_slots=True, + module_name='addressbook_thrift' + ) + + # normal structs will have slots + assert thrift.PhoneNumber.__slots__ == ['type', 'number', 'mix_item'] + assert thrift.Person.__slots__ == ['name', 'phones', 'created_at'] + assert thrift.AddressBook.__slots__ == ['people'] + + # get/set undefined attributes + person = thrift.Person() + with pytest.raises(AttributeError): + person.attr_not_exist = "Does not work" + + with pytest.raises(AttributeError): + person.attr_not_exist + + pn = thrift.PhoneNumber() + with pytest.raises(AttributeError): + pn.attr_not_exist = "Does not work" + + with pytest.raises(AttributeError): + pn.attr_not_exist + + ab = thrift.AddressBook() + with pytest.raises(AttributeError): + ab.attr_not_exist = "Does not work" + + with pytest.raises(AttributeError): + ab.attr_not_exist + # eo: get/set + + # exceptions will not have slots + assert not hasattr(thrift.PersonNotExistsError, '__slots__') + + # enums will not have slots + assert not hasattr(thrift.PhoneType, '__slots__') + + # service itself will not be created with slots + assert not hasattr(thrift.AddressBookService, '__slots__') + + # service args will have slots + args_slots = thrift.AddressBookService.get_phonenumbers_args.__slots__ + assert args_slots == ['name', 'count'] + + result_slots = thrift.AddressBookService.get_phonenumbers_result.__slots__ + assert result_slots == ['success'] + + # should be able to pickle slotted objects - if load with module_name + bob = thrift.Person(name="Bob") + p_str = pickle.dumps(bob) + + assert pickle.loads(p_str) == bob + + # works for recursive types too + rec = thriftpy.load('parser-cases/recursive_union.thrift', use_slots=True) + rec_slots = rec.Dynamic.__slots__ + assert rec_slots == ['boolean', 'integer', 'doubl', 'str', 'arr', 'object'] + dyn = rec.Dynamic() + with pytest.raises(AttributeError): + dyn.attr_not_exist = "shouldn't work" diff --git a/thriftpy/parser/__init__.py b/thriftpy/parser/__init__.py index 7aa7e2e..0aa92f6 100644 --- a/thriftpy/parser/__init__.py +++ b/thriftpy/parser/__init__.py @@ -15,7 +15,7 @@ from .parser import parse, parse_fp -def load(path, module_name=None, include_dirs=None, include_dir=None): +def load(path, module_name=None, include_dirs=None, include_dir=None, use_slots=False): """Load thrift file as a module. The module loaded and objects inside may only be pickled if module_name @@ -27,17 +27,17 @@ def load(path, module_name=None, include_dirs=None, include_dir=None): """ real_module = bool(module_name) thrift = parse(path, module_name, include_dirs=include_dirs, - include_dir=include_dir) + include_dir=include_dir, use_slots=use_slots) if real_module: sys.modules[module_name] = thrift return thrift -def load_fp(source, module_name): +def load_fp(source, module_name, use_slots=False): """Load thrift file like object as a module. """ - thrift = parse_fp(source, module_name) + thrift = parse_fp(source, module_name, use_slots=use_slots) sys.modules[module_name] = thrift return thrift diff --git a/thriftpy/parser/parser.py b/thriftpy/parser/parser.py index d443442..988198f 100644 --- a/thriftpy/parser/parser.py +++ b/thriftpy/parser/parser.py @@ -15,7 +15,7 @@ from .lexer import * # noqa from .exc import ThriftParserError, ThriftGrammerError from thriftpy._compat import urlopen, urlparse -from ..thrift import gen_init, TType, TPayload, TException +from ..thrift import gen_init, TType, TPayload, TSPayload, TException def p_error(p): @@ -215,7 +215,9 @@ def p_struct(p): def p_seen_struct(p): '''seen_struct : STRUCT IDENTIFIER ''' - val = _make_empty_struct(p[2]) + use_slots = p.parser.__use_slots__ + base_cls = TSPayload if use_slots else TPayload + val = _make_empty_struct(p[2], base_cls=base_cls) setattr(thrift_stack[-1], p[2], val) p[0] = val @@ -228,7 +230,9 @@ def p_union(p): def p_seen_union(p): '''seen_union : UNION IDENTIFIER ''' - val = _make_empty_struct(p[2]) + use_slots = p.parser.__use_slots__ + base_cls = TSPayload if use_slots else TPayload + val = _make_empty_struct(p[2], base_cls=base_cls) setattr(thrift_stack[-1], p[2], val) p[0] = val @@ -262,7 +266,8 @@ def p_service(p): else: extends = None - val = _make_service(p[2], p[len(p) - 2], extends) + use_slots = p.parser.__use_slots__ + val = _make_service(p[2], p[len(p) - 2], extends, use_slots=use_slots) setattr(thrift, p[2], val) _add_thrift_meta('services', val) @@ -430,8 +435,12 @@ def p_definition_type(p): thrift_cache = {} +def _get_cache_key(prefix, use_slots=False): + return ('%s:slotted' % prefix) if use_slots else prefix + + def parse(path, module_name=None, include_dirs=None, include_dir=None, - lexer=None, parser=None, enable_cache=True): + lexer=None, parser=None, enable_cache=True, use_slots=False): """Parse a single thrift file to module object, e.g.:: >>> from thriftpy.parser.parser import parse @@ -452,6 +461,7 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None, :param enable_cache: if this is set to be `True`, parsed module will be cached, this is enabled by default. If `module_name` is provided, use it as cache key, else use the `path`. + :param use_slots: if set to `True` uses slots for struct members """ if os.name == 'nt' and sys.version_info < (3, 2): os.path.samefile = lambda f1, f2: os.stat(f1) == os.stat(f2) @@ -464,7 +474,8 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None, global thrift_cache - cache_key = module_name or os.path.normpath(path) + cache_prefix = module_name or os.path.normpath(path) + cache_key = _get_cache_key(cache_prefix, use_slots) if enable_cache and cache_key in thrift_cache: return thrift_cache[cache_key] @@ -474,6 +485,8 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None, if parser is None: parser = yacc.yacc(debug=False, write_tables=0) + parser.__use_slots__ = use_slots + global include_dirs_ if include_dirs is not None: @@ -515,7 +528,7 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None, return thrift -def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True): +def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True, use_slots=False): """Parse a file-like object to thrift module object, e.g.:: >>> from thriftpy.parser.parser import parse_fp @@ -530,13 +543,16 @@ def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True): :param parser: ply parser to use, if not provided, `parse` will new one. :param enable_cache: if this is set to be `True`, parsed module will be cached by `module_name`, this is enabled by default. + :param use_slots: if set to `True` uses slots for struct members """ if not module_name.endswith('_thrift'): raise ThriftParserError('ThriftPy can only generate module with ' '\'_thrift\' suffix') + cache_key = _get_cache_key(module_name, use_slots) + if enable_cache and module_name in thrift_cache: - return thrift_cache[module_name] + return thrift_cache[cache_key] if not hasattr(source, 'read'): raise ThriftParserError('Except `source` to be a file-like object with' @@ -547,6 +563,8 @@ def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True): if parser is None: parser = yacc.yacc(debug=False, write_tables=0) + parser.__use_slots__ = use_slots + data = source.read() thrift = types.ModuleType(module_name) @@ -557,7 +575,7 @@ def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True): thrift_stack.pop() if enable_cache: - thrift_cache[module_name] = thrift + thrift_cache[cache_key] = thrift return thrift @@ -749,6 +767,8 @@ def _make_enum(name, kvs): def _make_empty_struct(name, ttype=TType.STRUCT, base_cls=TPayload): attrs = {'__module__': thrift_stack[-1].__name__, '_ttype': ttype} + if issubclass(base_cls, TSPayload): + attrs['__slots__'] = [] return type(name, (base_cls, ), attrs) @@ -769,6 +789,9 @@ def _fill_in_struct(cls, fields, _gen_init=True): setattr(cls, 'thrift_spec', thrift_spec) setattr(cls, 'default_spec', default_spec) setattr(cls, '_tspec', _tspec) + # add __slots__ for easy introspection + if issubclass(cls, TSPayload): + cls.__slots__ = [field for field, _ in default_spec] if _gen_init: gen_init(cls, thrift_spec, default_spec) return cls @@ -780,11 +803,13 @@ def _make_struct(name, fields, ttype=TType.STRUCT, base_cls=TPayload, return _fill_in_struct(cls, fields, _gen_init=_gen_init) -def _make_service(name, funcs, extends): +def _make_service(name, funcs, extends, use_slots=False): if extends is None: extends = object attrs = {'__module__': thrift_stack[-1].__name__} + base_cls = TSPayload if use_slots else TPayload + # service class itself will not be created with slots cls = type(name, (extends, ), attrs) thrift_services = [] @@ -793,7 +818,7 @@ def _make_service(name, funcs, extends): # args payload cls args_name = '%s_args' % func_name args_fields = func[3] - args_cls = _make_struct(args_name, args_fields) + args_cls = _make_struct(args_name, args_fields, base_cls=base_cls) setattr(cls, args_name, args_cls) # result payload cls result_name = '%s_result' % func_name @@ -801,13 +826,15 @@ def _make_service(name, funcs, extends): result_throws = func[4] result_oneway = func[0] result_cls = _make_struct(result_name, result_throws, - _gen_init=False) + _gen_init=False, base_cls=base_cls) setattr(result_cls, 'oneway', result_oneway) if result_type != TType.VOID: result_cls.thrift_spec[0] = _ttype_spec(result_type, 'success') result_cls.default_spec.insert(0, ('success', None)) gen_init(result_cls, result_cls.thrift_spec, result_cls.default_spec) setattr(cls, result_name, result_cls) + # default spec is modified after making struct so add slots here + result_cls.__slots__ = [f for f, _ in result_cls.default_spec] thrift_services.append(func_name) if extends is not None and hasattr(extends, 'thrift_services'): thrift_services.extend(extends.thrift_services) diff --git a/thriftpy/thrift.py b/thriftpy/thrift.py index bf1db20..7310a29 100644 --- a/thriftpy/thrift.py +++ b/thriftpy/thrift.py @@ -9,6 +9,11 @@ from __future__ import absolute_import +try: + import copy_reg as copyreg +except ImportError: + import copyreg + import functools import linecache import types @@ -126,12 +131,35 @@ class TMessageType(object): class TPayloadMeta(type): + _class_cache = {} + def __new__(cls, name, bases, attrs): if "default_spec" in attrs: spec = attrs.pop("default_spec") attrs["__init__"] = init_func_generator(cls, spec) return super(TPayloadMeta, cls).__new__(cls, name, bases, attrs) + def __call__(cls, *args, **kw): + if not issubclass(cls, TSPayload): + return type.__call__(cls, *args, **kw) + cls_name = cls.__name__.split('.')[-1] + cache_key = '%s:%s' % (cls.__module__, cls_name) + kls = TPayloadMeta._class_cache.get(cache_key) + if not kls: + fields = [field for field, _ in cls.default_spec] + kls = type( + cls_name, + (cls,), + { + '__slots__': fields, + '__module__': cls.__module__, + } + ) + TPayloadMeta._class_cache[cache_key] = kls + fn = lambda obj: (cls, tuple(getattr(obj, f) for f in fields)) + copyreg.pickle(kls, fn) + return type.__call__(kls, *args, **kw) + def gen_init(cls, thrift_spec=None, default_spec=None): if thrift_spec is not None: @@ -167,6 +195,39 @@ def __ne__(self, other): return not self.__eq__(other) +class TSPayload(with_metaclass(TPayloadMeta, object)): + + __slots__ = tuple() + + __hash__ = None + + def read(self, iprot): + iprot.read_struct(self) + + def write(self, oprot): + oprot.write_struct(self) + + def __repr__(self): + keys = self.__slots__ + values = [getattr(self, k) for k in keys] + l = ['%s=%r' % (key, value) for key, value in zip(keys, values)] + return '%s(%s)' % (self.__class__.__name__, ', '.join(l)) + + def __str__(self): + return repr(self) + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + keys = self.__slots__ + this = [getattr(self, k) for k in keys] + other_ = [getattr(other, k) for k in keys] + return this == other_ + + def __ne__(self, other): + return not self.__eq__(other) + + class TClient(object): def __init__(self, service, iprot, oprot=None):