Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions kafka/coordinator/assignors/sticky/sticky_assignor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@
from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements
from kafka.coordinator.assignors.sticky.sorted_set import SortedSet
from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata_v0, ConsumerProtocolMemberAssignment_v0
from kafka.coordinator.protocol import Schema
from kafka.protocol.struct import Struct
from kafka.protocol.types import String, Array, Int32
from kafka.protocol.types import Array, Int32, Schema, String
from kafka.structs import TopicPartition

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -59,7 +58,10 @@ class StickyAssignorUserDataV1(Struct):
"""

SCHEMA = Schema(
("previous_assignment", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))), ("generation", Int32)
("previous_assignment", Array(
("topic", String("utf-8")),
("partitions", Array(Int32)))),
("generation", Int32)
)


Expand Down
2 changes: 2 additions & 0 deletions kafka/protocol/abstract.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@


class AbstractType(object, metaclass=abc.ABCMeta):
@classmethod
@abc.abstractmethod
def encode(cls, value): # pylint: disable=no-self-argument
pass

@classmethod
@abc.abstractmethod
def decode(cls, data): # pylint: disable=no-self-argument
pass
Expand Down
10 changes: 0 additions & 10 deletions kafka/protocol/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,11 +60,6 @@ def API_VERSION(self):
"""Integer of api request version"""
pass

@abc.abstractproperty
def SCHEMA(self):
"""An instance of Schema() representing the request structure"""
pass

@abc.abstractproperty
def RESPONSE_TYPE(self):
"""The Response class associated with the api request"""
Expand Down Expand Up @@ -96,11 +91,6 @@ def API_VERSION(self):
"""Integer of api request/response version"""
pass

@abc.abstractproperty
def SCHEMA(self):
"""An instance of Schema() representing the response structure"""
pass

def to_object(self):
return _to_object(self.SCHEMA, self)

Expand Down
2 changes: 2 additions & 0 deletions kafka/protocol/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ class JoinGroupRequest_v5(Request):
]


# Currently unused -- see kafka.coordinator.protocol
class ProtocolMetadata(Struct):
SCHEMA = Schema(
('version', Int16),
Expand Down Expand Up @@ -250,6 +251,7 @@ class SyncGroupRequest_v3(Request):
]


# Currently unused -- see kafka.coordinator.protocol
class MemberAssignment(Struct):
SCHEMA = Schema(
('version', Int16),
Expand Down
27 changes: 10 additions & 17 deletions kafka/protocol/struct.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import abc
from io import BytesIO

from kafka.protocol.abstract import AbstractType
Expand All @@ -6,11 +7,15 @@
from kafka.util import WeakMethod


class Struct(AbstractType):
SCHEMA = Schema()
class Struct(metaclass=abc.ABCMeta):

@abc.abstractproperty
def SCHEMA(self):
"""An instance of Schema() representing the structure"""
pass

def __init__(self, *args, **kwargs):
if len(args) == len(self.SCHEMA.fields):
if len(args) == len(self.SCHEMA):
for i, name in enumerate(self.SCHEMA.names):
setattr(self, name, args[i])
elif len(args) > 0:
Expand All @@ -23,19 +28,7 @@ def __init__(self, *args, **kwargs):
% (list(self.SCHEMA.names),
', '.join(kwargs.keys())))

# overloading encode() to support both class and instance
# Without WeakMethod() this creates circular ref, which
# causes instances to "leak" to garbage
self.encode = WeakMethod(self._encode_self)

@classmethod
def encode(cls, item): # pylint: disable=E0202
bits = []
for i, field in enumerate(cls.SCHEMA.fields):
bits.append(field.encode(item[i]))
return b''.join(bits)

def _encode_self(self):
def encode(self):
return self.SCHEMA.encode(
[getattr(self, name) for name in self.SCHEMA.names]
)
Expand All @@ -44,7 +37,7 @@ def _encode_self(self):
def decode(cls, data):
if isinstance(data, bytes):
data = BytesIO(data)
return cls(*[field.decode(data) for field in cls.SCHEMA.fields])
return cls(*cls.SCHEMA.decode(data))

def get_item(self, name):
if name not in self.SCHEMA.names:
Expand Down
Loading