diff --git a/tests/test_protocol_binary.py b/tests/test_protocol_binary.py index beb03cc..7ff51b3 100644 --- a/tests/test_protocol_binary.py +++ b/tests/test_protocol_binary.py @@ -2,8 +2,10 @@ from io import BytesIO +import pytest + from thriftpy._compat import u -from thriftpy.thrift import TType, TPayload +from thriftpy.thrift import TType, TPayload, TDecodeException from thriftpy.utils import hexlify from thriftpy.protocol import binary as proto @@ -111,6 +113,33 @@ def test_write_message_begin_not_strict(): hexlify(b.getvalue()) +def test_write_decode_error(): + b = BytesIO() + p = proto.TBinaryProtocol(b) + + class T(TPayload): + thrift_spec = { + 1: (TType.I32, "id", False), + 2: (TType.LIST, "phones", TType.STRING, False), + 3: (TType.STRUCT, "item", TItem, False), + 4: (TType.MAP, "mm", (TType.STRING, (TType.STRUCT, TItem)), False) + } + default_spec = [("id", None), ("phones", None), ("item", None), + ("mm", None)] + + cases = [ + (T(id="hello"), "Field 'id(1)' of 'T' needs type 'I32', but the value is `'hello'`"), # noqa + (T(phones=[90, 12]), "Field 'phones(2)' of 'T' needs type 'LIST', but the value is `[90, 12]`"), # noqa + (T(item=12), "Field 'item(3)' of 'T' needs type 'TItem', but the value is `12`"), # noqa + (T(mm=[45, 56]), "Field 'mm(4)' of 'T' needs type 'MAP', but the value is `[45, 56]`") # noqa + ] + + for obj, res in cases: + with pytest.raises(TDecodeException) as exc: + p.write_struct(obj) + assert str(exc.value) == res + + def test_read_message_begin(): b = BytesIO(b"\x80\x01\x00\x0b\x00\x00\x00\x04test\x00\x00\x00\x01") res = proto.TBinaryProtocol(b).read_message_begin() diff --git a/thriftpy/protocol/binary.py b/thriftpy/protocol/binary.py index 7ba21f7..eb4531e 100644 --- a/thriftpy/protocol/binary.py +++ b/thriftpy/protocol/binary.py @@ -4,7 +4,7 @@ import struct -from ..thrift import TType +from ..thrift import TType, TDecodeException from .exc import TProtocolException @@ -86,6 +86,60 @@ def write_map_begin(outbuf, ktype, vtype, size): outbuf.write(pack_i8(ktype) + pack_i8(vtype) + pack_i32(size)) +def write_struct(outbuf, val): + for fid in iter(val.thrift_spec): + f_spec = val.thrift_spec[fid] + if len(f_spec) == 3: + f_type, f_name, f_req = f_spec + f_container_spec = None + else: + f_type, f_name, f_container_spec, f_req = f_spec + + v = getattr(val, f_name) + if v is None: + continue + + write_field_begin(outbuf, f_type, fid) + try: + write_val(outbuf, f_type, v, f_container_spec) + except (TypeError, AttributeError, AssertionError, OverflowError, struct.error): + raise TDecodeException(val.__class__.__name__, fid, f_name, v, + f_type, f_container_spec) + + write_field_stop(outbuf) + + +def write_dict(outbuf, val, spec): + if isinstance(spec[0], int): + k_type = spec[0] + k_spec = None + else: + k_type, k_spec = spec[0] + + if isinstance(spec[1], int): + v_type = spec[1] + v_spec = None + else: + v_type, v_spec = spec[1] + + write_map_begin(outbuf, k_type, v_type, len(val)) + for k in iter(val): + write_val(outbuf, k_type, k, k_spec) + write_val(outbuf, v_type, val[k], v_spec) + + +def write_list(outbuf, val, spec): + if isinstance(spec, tuple): + e_type, t_spec = spec[0], spec[1] + else: + e_type, t_spec = spec, None + + val_len = len(val) + write_list_begin(outbuf, e_type, val_len) + for e_val in val: + write_val(outbuf, e_type, e_val, t_spec) + + def write_val(outbuf, ttype, val, spec=None): if ttype == TType.BOOL: if val: @@ -114,50 +168,13 @@ def write_val(outbuf, ttype, val, spec=None): outbuf.write(pack_string(val)) elif ttype == TType.SET or ttype == TType.LIST: - if isinstance(spec, tuple): - e_type, t_spec = spec[0], spec[1] - else: - e_type, t_spec = spec, None - - val_len = len(val) - write_list_begin(outbuf, e_type, val_len) - for e_val in val: - write_val(outbuf, e_type, e_val, t_spec) + write_list(outbuf, val, spec) elif ttype == TType.MAP: - if isinstance(spec[0], int): - k_type = spec[0] - k_spec = None - else: - k_type, k_spec = spec[0] - - if isinstance(spec[1], int): - v_type = spec[1] - v_spec = None - else: - v_type, v_spec = spec[1] - - write_map_begin(outbuf, k_type, v_type, len(val)) - for k in iter(val): - write_val(outbuf, k_type, k, k_spec) - write_val(outbuf, v_type, val[k], v_spec) + write_dict(outbuf, val, spec) elif ttype == TType.STRUCT: - for fid in iter(val.thrift_spec): - f_spec = val.thrift_spec[fid] - if len(f_spec) == 3: - f_type, f_name, f_req = f_spec - f_container_spec = None - else: - f_type, f_name, f_container_spec, f_req = f_spec - - v = getattr(val, f_name) - if v is None: - continue - - write_field_begin(outbuf, f_type, fid) - write_val(outbuf, f_type, v, f_container_spec) - write_field_stop(outbuf) + write_struct(outbuf, val) def read_message_begin(inbuf, strict=True):