Skip to content
This repository was archived by the owner on Dec 10, 2018. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 66 additions & 0 deletions tests/test_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,69 @@ def test_tpayload_pickle():
person_2 = pickle.loads(PICKLED_BYTES)

assert person == person_2


def test_load_slots():
thrift = thriftpy.load(
'addressbook.thrift',
use_slots=True,
module_name='addressbook_thrift'
)

# normal structs will have slots
assert thrift.PhoneNumber.__slots__ == ['type', 'number', 'mix_item']
assert thrift.Person.__slots__ == ['name', 'phones', 'created_at']
assert thrift.AddressBook.__slots__ == ['people']

# get/set undefined attributes
person = thrift.Person()
with pytest.raises(AttributeError):
person.attr_not_exist = "Does not work"

with pytest.raises(AttributeError):
person.attr_not_exist

pn = thrift.PhoneNumber()
with pytest.raises(AttributeError):
pn.attr_not_exist = "Does not work"

with pytest.raises(AttributeError):
pn.attr_not_exist

ab = thrift.AddressBook()
with pytest.raises(AttributeError):
ab.attr_not_exist = "Does not work"

with pytest.raises(AttributeError):
ab.attr_not_exist
# eo: get/set

# exceptions will not have slots
assert not hasattr(thrift.PersonNotExistsError, '__slots__')

# enums will not have slots
assert not hasattr(thrift.PhoneType, '__slots__')

# service itself will not be created with slots
assert not hasattr(thrift.AddressBookService, '__slots__')

# service args will have slots
args_slots = thrift.AddressBookService.get_phonenumbers_args.__slots__
assert args_slots == ['name', 'count']

result_slots = thrift.AddressBookService.get_phonenumbers_result.__slots__
assert result_slots == ['success']

# should be able to pickle slotted objects - if load with module_name
bob = thrift.Person(name="Bob")
p_str = pickle.dumps(bob)

assert pickle.loads(p_str) == bob

# works for recursive types too
rec = thriftpy.load('parser-cases/recursive_union.thrift', use_slots=True)
rec_slots = rec.Dynamic.__slots__
assert rec_slots == ['boolean', 'integer', 'doubl', 'str', 'arr', 'object']
dyn = rec.Dynamic()
with pytest.raises(AttributeError):
dyn.attr_not_exist = "shouldn't work"
8 changes: 4 additions & 4 deletions thriftpy/parser/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .parser import parse, parse_fp


def load(path, module_name=None, include_dirs=None, include_dir=None):
def load(path, module_name=None, include_dirs=None, include_dir=None, use_slots=False):
"""Load thrift file as a module.

The module loaded and objects inside may only be pickled if module_name
Expand All @@ -27,17 +27,17 @@ def load(path, module_name=None, include_dirs=None, include_dir=None):
"""
real_module = bool(module_name)
thrift = parse(path, module_name, include_dirs=include_dirs,
include_dir=include_dir)
include_dir=include_dir, use_slots=use_slots)

if real_module:
sys.modules[module_name] = thrift
return thrift


def load_fp(source, module_name):
def load_fp(source, module_name, use_slots=False):
"""Load thrift file like object as a module.
"""
thrift = parse_fp(source, module_name)
thrift = parse_fp(source, module_name, use_slots=use_slots)
sys.modules[module_name] = thrift
return thrift

Expand Down
51 changes: 39 additions & 12 deletions thriftpy/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .lexer import * # noqa
from .exc import ThriftParserError, ThriftGrammerError
from thriftpy._compat import urlopen, urlparse
from ..thrift import gen_init, TType, TPayload, TException
from ..thrift import gen_init, TType, TPayload, TSPayload, TException


def p_error(p):
Expand Down Expand Up @@ -215,7 +215,9 @@ def p_struct(p):

def p_seen_struct(p):
'''seen_struct : STRUCT IDENTIFIER '''
val = _make_empty_struct(p[2])
use_slots = p.parser.__use_slots__
base_cls = TSPayload if use_slots else TPayload
val = _make_empty_struct(p[2], base_cls=base_cls)
setattr(thrift_stack[-1], p[2], val)
p[0] = val

Expand All @@ -228,7 +230,9 @@ def p_union(p):

def p_seen_union(p):
'''seen_union : UNION IDENTIFIER '''
val = _make_empty_struct(p[2])
use_slots = p.parser.__use_slots__
base_cls = TSPayload if use_slots else TPayload
val = _make_empty_struct(p[2], base_cls=base_cls)
setattr(thrift_stack[-1], p[2], val)
p[0] = val

Expand Down Expand Up @@ -262,7 +266,8 @@ def p_service(p):
else:
extends = None

val = _make_service(p[2], p[len(p) - 2], extends)
use_slots = p.parser.__use_slots__
val = _make_service(p[2], p[len(p) - 2], extends, use_slots=use_slots)
setattr(thrift, p[2], val)
_add_thrift_meta('services', val)

Expand Down Expand Up @@ -430,8 +435,12 @@ def p_definition_type(p):
thrift_cache = {}


def _get_cache_key(prefix, use_slots=False):
return ('%s:slotted' % prefix) if use_slots else prefix


def parse(path, module_name=None, include_dirs=None, include_dir=None,
lexer=None, parser=None, enable_cache=True):
lexer=None, parser=None, enable_cache=True, use_slots=False):
"""Parse a single thrift file to module object, e.g.::

>>> from thriftpy.parser.parser import parse
Expand All @@ -452,6 +461,7 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None,
:param enable_cache: if this is set to be `True`, parsed module will be
cached, this is enabled by default. If `module_name`
is provided, use it as cache key, else use the `path`.
:param use_slots: if set to `True` uses slots for struct members
"""
if os.name == 'nt' and sys.version_info < (3, 2):
os.path.samefile = lambda f1, f2: os.stat(f1) == os.stat(f2)
Expand All @@ -464,7 +474,8 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None,

global thrift_cache

cache_key = module_name or os.path.normpath(path)
cache_prefix = module_name or os.path.normpath(path)
cache_key = _get_cache_key(cache_prefix, use_slots)

if enable_cache and cache_key in thrift_cache:
return thrift_cache[cache_key]
Expand All @@ -474,6 +485,8 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None,
if parser is None:
parser = yacc.yacc(debug=False, write_tables=0)

parser.__use_slots__ = use_slots

global include_dirs_

if include_dirs is not None:
Expand Down Expand Up @@ -515,7 +528,7 @@ def parse(path, module_name=None, include_dirs=None, include_dir=None,
return thrift


def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True):
def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True, use_slots=False):
"""Parse a file-like object to thrift module object, e.g.::

>>> from thriftpy.parser.parser import parse_fp
Expand All @@ -530,13 +543,16 @@ def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True):
:param parser: ply parser to use, if not provided, `parse` will new one.
:param enable_cache: if this is set to be `True`, parsed module will be
cached by `module_name`, this is enabled by default.
:param use_slots: if set to `True` uses slots for struct members
"""
if not module_name.endswith('_thrift'):
raise ThriftParserError('ThriftPy can only generate module with '
'\'_thrift\' suffix')

cache_key = _get_cache_key(module_name, use_slots)

if enable_cache and module_name in thrift_cache:
return thrift_cache[module_name]
return thrift_cache[cache_key]

if not hasattr(source, 'read'):
raise ThriftParserError('Except `source` to be a file-like object with'
Expand All @@ -547,6 +563,8 @@ def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True):
if parser is None:
parser = yacc.yacc(debug=False, write_tables=0)

parser.__use_slots__ = use_slots

data = source.read()

thrift = types.ModuleType(module_name)
Expand All @@ -557,7 +575,7 @@ def parse_fp(source, module_name, lexer=None, parser=None, enable_cache=True):
thrift_stack.pop()

if enable_cache:
thrift_cache[module_name] = thrift
thrift_cache[cache_key] = thrift
return thrift


Expand Down Expand Up @@ -749,6 +767,8 @@ def _make_enum(name, kvs):

def _make_empty_struct(name, ttype=TType.STRUCT, base_cls=TPayload):
attrs = {'__module__': thrift_stack[-1].__name__, '_ttype': ttype}
if issubclass(base_cls, TSPayload):
attrs['__slots__'] = []
return type(name, (base_cls, ), attrs)


Expand All @@ -769,6 +789,9 @@ def _fill_in_struct(cls, fields, _gen_init=True):
setattr(cls, 'thrift_spec', thrift_spec)
setattr(cls, 'default_spec', default_spec)
setattr(cls, '_tspec', _tspec)
# add __slots__ for easy introspection
if issubclass(cls, TSPayload):
cls.__slots__ = [field for field, _ in default_spec]
if _gen_init:
gen_init(cls, thrift_spec, default_spec)
return cls
Expand All @@ -780,11 +803,13 @@ def _make_struct(name, fields, ttype=TType.STRUCT, base_cls=TPayload,
return _fill_in_struct(cls, fields, _gen_init=_gen_init)


def _make_service(name, funcs, extends):
def _make_service(name, funcs, extends, use_slots=False):
if extends is None:
extends = object

attrs = {'__module__': thrift_stack[-1].__name__}
base_cls = TSPayload if use_slots else TPayload
# service class itself will not be created with slots
cls = type(name, (extends, ), attrs)
thrift_services = []

Expand All @@ -793,21 +818,23 @@ def _make_service(name, funcs, extends):
# args payload cls
args_name = '%s_args' % func_name
args_fields = func[3]
args_cls = _make_struct(args_name, args_fields)
args_cls = _make_struct(args_name, args_fields, base_cls=base_cls)
setattr(cls, args_name, args_cls)
# result payload cls
result_name = '%s_result' % func_name
result_type = func[1]
result_throws = func[4]
result_oneway = func[0]
result_cls = _make_struct(result_name, result_throws,
_gen_init=False)
_gen_init=False, base_cls=base_cls)
setattr(result_cls, 'oneway', result_oneway)
if result_type != TType.VOID:
result_cls.thrift_spec[0] = _ttype_spec(result_type, 'success')
result_cls.default_spec.insert(0, ('success', None))
gen_init(result_cls, result_cls.thrift_spec, result_cls.default_spec)
setattr(cls, result_name, result_cls)
# default spec is modified after making struct so add slots here
result_cls.__slots__ = [f for f, _ in result_cls.default_spec]
thrift_services.append(func_name)
if extends is not None and hasattr(extends, 'thrift_services'):
thrift_services.extend(extends.thrift_services)
Expand Down
61 changes: 61 additions & 0 deletions thriftpy/thrift.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@

from __future__ import absolute_import

try:
import copy_reg as copyreg
except ImportError:
import copyreg

import functools
import linecache
import types
Expand Down Expand Up @@ -126,12 +131,35 @@ class TMessageType(object):

class TPayloadMeta(type):

_class_cache = {}

def __new__(cls, name, bases, attrs):
if "default_spec" in attrs:
spec = attrs.pop("default_spec")
attrs["__init__"] = init_func_generator(cls, spec)
return super(TPayloadMeta, cls).__new__(cls, name, bases, attrs)

def __call__(cls, *args, **kw):
if not issubclass(cls, TSPayload):
return type.__call__(cls, *args, **kw)
cls_name = cls.__name__.split('.')[-1]
cache_key = '%s:%s' % (cls.__module__, cls_name)
kls = TPayloadMeta._class_cache.get(cache_key)
if not kls:
fields = [field for field, _ in cls.default_spec]
kls = type(
cls_name,
(cls,),
{
'__slots__': fields,
'__module__': cls.__module__,
}
)
TPayloadMeta._class_cache[cache_key] = kls
fn = lambda obj: (cls, tuple(getattr(obj, f) for f in fields))
copyreg.pickle(kls, fn)
return type.__call__(kls, *args, **kw)


def gen_init(cls, thrift_spec=None, default_spec=None):
if thrift_spec is not None:
Expand Down Expand Up @@ -167,6 +195,39 @@ def __ne__(self, other):
return not self.__eq__(other)


class TSPayload(with_metaclass(TPayloadMeta, object)):

__slots__ = tuple()

__hash__ = None

def read(self, iprot):
iprot.read_struct(self)

def write(self, oprot):
oprot.write_struct(self)

def __repr__(self):
keys = self.__slots__
values = [getattr(self, k) for k in keys]
l = ['%s=%r' % (key, value) for key, value in zip(keys, values)]
return '%s(%s)' % (self.__class__.__name__, ', '.join(l))

def __str__(self):
return repr(self)

def __eq__(self, other):
if not isinstance(other, self.__class__):
return False
keys = self.__slots__
this = [getattr(self, k) for k in keys]
other_ = [getattr(other, k) for k in keys]
return this == other_

def __ne__(self, other):
return not self.__eq__(other)


class TClient(object):

def __init__(self, service, iprot, oprot=None):
Expand Down