@@ -67,6 +67,10 @@ def __init__(
6767 self ._msg_desc_by_type = {}
6868 self ._msg_desc_by_name = {}
6969 self ._msg_type_by_desc = {}
70+ self ._field_by_msg_desc = {}
71+
72+ self ._envelope_class = None
73+ self ._envelope_desc = None
7074
7175 self ._address = "tcp://{0}:{1}" .format (server_host , server_port )
7276 # init zeromq
@@ -130,6 +134,10 @@ def registerEnvelope(self, envelope_module, envelope_class_name="Envelope"):
130134 envelope_class = getattr (envelope_module , envelope_class_name )
131135 envelope_desc = envelope_class .DESCRIPTOR
132136
137+ # Store for envelope wrapping/unwrapping
138+ self ._envelope_class = envelope_class
139+ self ._envelope_desc = envelope_desc
140+
133141 for field in envelope_desc .fields :
134142 if field .message_type is None :
135143 # Skip non-message fields (e.g. scalars) if any exist
@@ -141,6 +149,7 @@ def registerEnvelope(self, envelope_module, envelope_class_name="Envelope"):
141149 self ._msg_desc_by_type [msg_type ] = desc
142150 self ._msg_desc_by_name [desc .name ] = desc
143151 self ._msg_type_by_desc [desc ] = msg_type
152+ self ._field_by_msg_desc [desc ] = field
144153
145154 self ._logger .debug (
146155 "Registered %d message types from %s" ,
@@ -177,15 +186,15 @@ def registerProtocol(self, msg_module):
177186 ##
178187 # @brief Receive a message
179188 #
180- # Receive a protobuf message with timeout. This method automatically
181- # parses and creates a new protobuf message class based on received
182- # framing. The new message object, the message name (defined in the
183- # associated proto file) , and re-association context are returned as
184- # a tuple. On timeout, (None,None,None) is returned.
189+ # Receive a protobuf message with timeout. The wire payload is an
190+ # Envelope message; this method deserializes the Envelope and extracts
191+ # the inner message via the oneof payload field. The inner message
192+ # object, its name , and re-association context are returned as a tuple.
193+ # On timeout, (None, None, None) is returned.
185194 #
186195 # @param timeout - Timeout in milliseconds
187- # @return Tuple of message, message type , and re-association context
188- # @retval (object,str,int) or (None,None,None) on timeout
196+ # @return Tuple of message, message name , and re-association context
197+ # @retval (object, str, int) or (None, None, None) on timeout
189198 # @exception Exception: if unregistered message type is received.
190199 #
191200 def recv (self , a_timeout = 1000 ):
@@ -219,48 +228,62 @@ def recv(self, a_timeout=1000):
219228 # client
220229 self ._socket .recv_string (0 )
221230
222- # receive custom frame header and unpack
231+ # Receive frame: 8 bytes = uint32 size + uint16 msg_type + uint16 context
223232 frame_data = self ._socket .recv (0 )
224- frame_values = struct .unpack (">LBBH" , frame_data )
225- msg_type = (frame_values [1 ] << 8 ) | frame_values [2 ]
233+ frame_values = struct .unpack (">LHH" , frame_data )
234+ body_size = frame_values [0 ]
235+ msg_type = frame_values [1 ]
236+ ctxt = frame_values [2 ]
226237
227- # find message descriptor based on type (descriptor index)
228-
229- if not (msg_type in self ._msg_desc_by_type ):
238+ if msg_type not in self ._msg_desc_by_type :
230239 raise Exception (
231240 "received unregistered message type: {}" .format (msg_type )
232241 )
233242
234- desc = self ._msg_desc_by_type [ msg_type ]
243+ data = self ._socket . recv ( 0 )
235244
236- if frame_values [0 ] > 0 :
237- # Create message by parsing content
238- data = self ._socket .recv (0 )
239- reply = GetMessageClass (desc )()
240- reply .ParseFromString (data )
245+ if body_size > 0 :
246+ # Deserialize as Envelope
247+ envelope = self ._envelope_class ()
248+ envelope .ParseFromString (data )
249+
250+ # Extract inner message from the oneof
251+ payload_field = envelope .WhichOneof ("payload" )
252+ if payload_field is None :
253+ raise Exception ("Received Envelope with no payload set" )
254+ reply = getattr (envelope , payload_field )
241255 else :
242- # No content, just create message instance
243- data = self ._socket . recv ( 0 )
256+ # Zero-size body: create empty message instance from type
257+ desc = self ._msg_desc_by_type [ msg_type ]
244258 reply = GetMessageClass (desc )()
245259
246- return reply , desc . name , frame_values [ 3 ]
260+ return reply , reply . DESCRIPTOR . name , ctxt
247261 else :
248262 return None , None , None
249263
250264 ##
251265 # @brief Send a message
252266 #
253- # Serializes and sends framing and message payload over connection.
267+ # Wraps the inner message in an Envelope, serializes it, and sends
268+ # framing and payload over the connection. The frame header carries the
269+ # message type (Envelope field number) for efficient routing on the
270+ # server side.
254271 #
255272 # @param message - The protobuf message object to be sent
256273 # @param ctxt - Reply re-association value (int)
257274 # @exception Exception: if unregistered message type is sent.
258275 #
259276 def send (self , message , ctxt ):
260277 # Find msg type by descriptor look-up
261- if not ( message .DESCRIPTOR in self ._msg_type_by_desc ) :
278+ if message .DESCRIPTOR not in self ._msg_type_by_desc :
262279 raise Exception ("Attempt to send unregistered message type." )
280+
263281 msg_type = self ._msg_type_by_desc [message .DESCRIPTOR ]
282+ field = self ._field_by_msg_desc [message .DESCRIPTOR ]
283+
284+ # Wrap inner message in Envelope
285+ envelope = self ._envelope_class ()
286+ getattr (envelope , field .name ).CopyFrom (message )
264287
265288 # Initial Null frame
266289 self ._socket .send_string ("BEGIN_DATAFED" , zmq .SNDMORE )
@@ -274,12 +297,12 @@ def send(self, message, ctxt):
274297 self ._socket .send_string (self ._pub_key , zmq .SNDMORE )
275298 self ._socket .send_string ("no_user" , zmq .SNDMORE )
276299
277- # Serialize
278- data = message .SerializeToString ()
300+ # Serialize the Envelope (not the inner message)
301+ data = envelope .SerializeToString ()
279302 data_sz = len (data )
280303
281- # Build the message frame, to match C-struct MessageFrame
282- frame = struct .pack (">LBBH " , data_sz , msg_type >> 8 , msg_type & 0xFF , ctxt )
304+ # Build the message frame: uint32 size + uint16 msg_type + uint16 context
305+ frame = struct .pack (">LHH " , data_sz , msg_type , ctxt )
283306
284307 if data_sz > 0 :
285308 # Send frame and payload
0 commit comments