diff --git a/kafka/coordinator/assignors/sticky/sticky_assignor.py b/kafka/coordinator/assignors/sticky/sticky_assignor.py index 243c26709..5e29d5894 100644 --- a/kafka/coordinator/assignors/sticky/sticky_assignor.py +++ b/kafka/coordinator/assignors/sticky/sticky_assignor.py @@ -6,9 +6,8 @@ from kafka.coordinator.assignors.sticky.partition_movements import PartitionMovements from kafka.coordinator.assignors.sticky.sorted_set import SortedSet from kafka.coordinator.protocol import ConsumerProtocolMemberMetadata_v0, ConsumerProtocolMemberAssignment_v0 -from kafka.coordinator.protocol import Schema from kafka.protocol.struct import Struct -from kafka.protocol.types import String, Array, Int32 +from kafka.protocol.types import Array, Int32, Schema, String from kafka.structs import TopicPartition log = logging.getLogger(__name__) @@ -59,7 +58,10 @@ class StickyAssignorUserDataV1(Struct): """ SCHEMA = Schema( - ("previous_assignment", Array(("topic", String("utf-8")), ("partitions", Array(Int32)))), ("generation", Int32) + ("previous_assignment", Array( + ("topic", String("utf-8")), + ("partitions", Array(Int32)))), + ("generation", Int32) ) diff --git a/kafka/protocol/abstract.py b/kafka/protocol/abstract.py index e0d89433b..5817673cb 100644 --- a/kafka/protocol/abstract.py +++ b/kafka/protocol/abstract.py @@ -2,10 +2,12 @@ class AbstractType(object, metaclass=abc.ABCMeta): + @classmethod @abc.abstractmethod def encode(cls, value): # pylint: disable=no-self-argument pass + @classmethod @abc.abstractmethod def decode(cls, data): # pylint: disable=no-self-argument pass diff --git a/kafka/protocol/api.py b/kafka/protocol/api.py index c7a477cac..b6b8a2996 100644 --- a/kafka/protocol/api.py +++ b/kafka/protocol/api.py @@ -60,11 +60,6 @@ def API_VERSION(self): """Integer of api request version""" pass - @abc.abstractproperty - def SCHEMA(self): - """An instance of Schema() representing the request structure""" - pass - @abc.abstractproperty def RESPONSE_TYPE(self): """The Response class associated with the api request""" @@ -96,11 +91,6 @@ def API_VERSION(self): """Integer of api request/response version""" pass - @abc.abstractproperty - def SCHEMA(self): - """An instance of Schema() representing the response structure""" - pass - def to_object(self): return _to_object(self.SCHEMA, self) diff --git a/kafka/protocol/group.py b/kafka/protocol/group.py index 5d35ab219..a56bd48dc 100644 --- a/kafka/protocol/group.py +++ b/kafka/protocol/group.py @@ -158,6 +158,7 @@ class JoinGroupRequest_v5(Request): ] +# Currently unused -- see kafka.coordinator.protocol class ProtocolMetadata(Struct): SCHEMA = Schema( ('version', Int16), @@ -250,6 +251,7 @@ class SyncGroupRequest_v3(Request): ] +# Currently unused -- see kafka.coordinator.protocol class MemberAssignment(Struct): SCHEMA = Schema( ('version', Int16), diff --git a/kafka/protocol/struct.py b/kafka/protocol/struct.py index b482326fa..f66170c60 100644 --- a/kafka/protocol/struct.py +++ b/kafka/protocol/struct.py @@ -1,3 +1,4 @@ +import abc from io import BytesIO from kafka.protocol.abstract import AbstractType @@ -6,11 +7,15 @@ from kafka.util import WeakMethod -class Struct(AbstractType): - SCHEMA = Schema() +class Struct(metaclass=abc.ABCMeta): + + @abc.abstractproperty + def SCHEMA(self): + """An instance of Schema() representing the structure""" + pass def __init__(self, *args, **kwargs): - if len(args) == len(self.SCHEMA.fields): + if len(args) == len(self.SCHEMA): for i, name in enumerate(self.SCHEMA.names): setattr(self, name, args[i]) elif len(args) > 0: @@ -23,19 +28,7 @@ def __init__(self, *args, **kwargs): % (list(self.SCHEMA.names), ', '.join(kwargs.keys()))) - # overloading encode() to support both class and instance - # Without WeakMethod() this creates circular ref, which - # causes instances to "leak" to garbage - self.encode = WeakMethod(self._encode_self) - - @classmethod - def encode(cls, item): # pylint: disable=E0202 - bits = [] - for i, field in enumerate(cls.SCHEMA.fields): - bits.append(field.encode(item[i])) - return b''.join(bits) - - def _encode_self(self): + def encode(self): return self.SCHEMA.encode( [getattr(self, name) for name in self.SCHEMA.names] ) @@ -44,7 +37,7 @@ def _encode_self(self): def decode(cls, data): if isinstance(data, bytes): data = BytesIO(data) - return cls(*[field.decode(data) for field in cls.SCHEMA.fields]) + return cls(*cls.SCHEMA.decode(data)) def get_item(self, name): if name not in self.SCHEMA.names: