88import re
99import ssl
1010import string
11+ import time
1112from itertools import count
1213from itertools import product
1314from threading import Event
1415from threading import Lock
1516from typing import Any
1617from typing import Dict
1718from typing import List
18- from typing import no_type_check
1919from typing import Tuple
20+ from typing import no_type_check
2021
2122import paho .mqtt .client as mqtt
2223import pexpect
@@ -61,7 +62,7 @@ def on_connect(self, mqttc: Any, obj: Any, flags: Any, rc: int) -> None:
6162 def on_connect_fail (self , mqttc : Any , obj : Any ) -> None :
6263 logging .error ('Connect failed' )
6364
64- def on_message (self , mqttc : Any , userdata : Any , msg : mqtt .MQTTMessage ) -> None :
65+ def on_message (self , mqttc : mqtt . Client , obj : Any , msg : mqtt .MQTTMessage ) -> None :
6566 payload = msg .payload .decode ('utf-8' )
6667 if payload == self .expected_data :
6768 self .received += 1
@@ -70,8 +71,9 @@ def on_message(self, mqttc: Any, userdata: Any, msg: mqtt.MQTTMessage) -> None:
7071 else :
7172 differences = len (list (filter (lambda data : data [0 ] != data [1 ], zip (payload , self .expected_data ))))
7273 logging .error (
73- f'Payload differ in { differences } positions from expected data. received size: { len (payload )} expected size:'
74- f'{ len (self .expected_data )} '
74+ f'Payload on topic "{ msg .topic } " (QoS { msg .qos } ) differs in { differences } positions '
75+ 'from expected data. '
76+ f'Received size: { len (payload )} , expected size: { len (self .expected_data )} .'
7577 )
7678 logging .info (f'Repetitions: { payload .count (self .config ["pattern" ])} ' )
7779 logging .info (f'Pattern: { self .config ["pattern" ]} ' )
@@ -85,6 +87,7 @@ def __enter__(self) -> Any:
8587 qos = self .config ['qos' ]
8688 broker_host = self .config ['broker_host_' + self .config ['transport' ]]
8789 broker_port = self .config ['broker_port_' + self .config ['transport' ]]
90+ connect_timeout_seconds = self .config .get ('client_connect_timeout' , 30 )
8891
8992 try :
9093 self .print_details ('Connecting...' )
@@ -93,14 +96,17 @@ def __enter__(self) -> Any:
9396 self .tls_insecure_set (True )
9497 self .event_client_connected .clear ()
9598 self .loop_start ()
96- self .connect (broker_host , broker_port , 60 )
99+ self .connect (broker_host , broker_port , 60 ) # paho's keepalive
97100 except Exception :
98101 self .print_details (f'ENV_TEST_FAILURE: Unexpected error while connecting to broker { broker_host } ' )
99102 raise
100103 self .print_details (f'Connecting py-client to broker { broker_host } :{ broker_port } ...' )
101104
102- if not self .event_client_connected .wait (timeout = 30 ):
103- raise ValueError (f'ENV_TEST_FAILURE: Test script cannot connect to broker: { broker_host } ' )
105+ if not self .event_client_connected .wait (timeout = connect_timeout_seconds ):
106+ raise ValueError (
107+ f'ENV_TEST_FAILURE: Test script cannot connect to broker: { broker_host } '
108+ f'within { connect_timeout_seconds } s'
109+ )
104110 self .event_client_got_all .clear ()
105111 result , self .subscribe_mid = self .subscribe (self .config ['subscribe_topic' ], qos )
106112 assert result == 0
@@ -148,7 +154,11 @@ def get_config_from_dut(dut, config_option):
148154 publish_cfg ['pattern' ] = '' .join (
149155 random .choice (string .ascii_uppercase + string .ascii_lowercase + string .digits ) for _ in range (DEFAULT_MSG_SIZE )
150156 )
157+ publish_cfg ['client_connect_timeout' ] = 30
158+ publish_cfg ['dut_subscribe_timeout' ] = 60
159+ publish_cfg ['publish_ack_timeout' ] = 60
151160 publish_cfg ['test_timeout' ] = get_timeout (test_case )
161+
152162 unique_topic = '' .join (
153163 random .choice (string .ascii_uppercase + string .ascii_lowercase ) for _ in range (DEFAULT_MSG_SIZE )
154164 )
@@ -159,9 +169,10 @@ def get_config_from_dut(dut, config_option):
159169
160170
161171@contextlib .contextmanager
162- def connected_and_subscribed (dut : Dut ) -> Any :
172+ def connected_and_subscribed (dut : Dut , config : Dict [ str , Any ] ) -> Any :
163173 dut .write ('start' )
164- dut .expect (re .compile (rb'MQTT_EVENT_SUBSCRIBED' ), timeout = 60 )
174+ dut_subscribe_timeout = config .get ('dut_subscribe_timeout' , 60 )
175+ dut .expect (re .compile (rb'MQTT_EVENT_SUBSCRIBED' ), timeout = dut_subscribe_timeout )
165176 yield
166177 dut .write ('stop' )
167178
@@ -177,6 +188,7 @@ def get_scenarios() -> List[Dict[str, int]]:
177188 continue
178189 break
179190 if not scenarios : # No message sizes present in the env - set defaults
191+ logging .info ('Using predefined cases' )
180192 scenarios = [
181193 {'msg_len' : 0 , 'nr_of_msgs' : 5 }, # zero-sized messages
182194 {'msg_len' : 2 , 'nr_of_msgs' : 5 }, # short messages
@@ -201,13 +213,15 @@ def run_publish_test_case(dut: Dut, config: Any) -> None:
201213 logging .info (
202214 f'Starting Publish test: transport:{ config ["transport" ]} , qos:{ config ["qos" ]} ,'
203215 f'nr_of_msgs:{ config ["scenario" ]["nr_of_msgs" ]} ,'
204- f' msg_size:{ config ["scenario" ]["msg_len" ] * DEFAULT_MSG_SIZE } , enqueue:{ config ["enqueue" ]} '
216+ f' msg_size:{ config ["scenario" ]["msg_len" ]} , enqueue:{ config ["enqueue" ]} '
205217 )
206218 dut .write (
207- f'publish_setup { config ["transport" ]} { config ["publish_topic" ]} { config ["subscribe_topic" ]} { config ["pattern" ]} { config ["scenario" ]["msg_len" ]} '
219+ f'publish_setup { config ["transport" ]} { config ["publish_topic" ]} '
220+ f' { config ["subscribe_topic" ]} { config ["pattern" ]} { config ["scenario" ]["msg_len" ]} '
208221 )
209- with MqttPublisher (config ) as publisher , connected_and_subscribed (dut ):
210- assert publisher .event_client_subscribed .wait (timeout = config ['test_timeout' ]), 'Runner failed to subscribe'
222+ with MqttPublisher (config ) as publisher , connected_and_subscribed (dut , config ):
223+ py_client_subscribe_timeout = config .get ('py_client_subscribe_timeout' , config ['test_timeout' ])
224+ assert publisher .event_client_subscribed .wait (timeout = py_client_subscribe_timeout ), 'Runner failed to subscribe'
211225 msgs_published : List [mqtt .MQTTMessageInfo ] = []
212226 dut .write (f'publish { config ["scenario" ]["nr_of_msgs" ]} { config ["qos" ]} { config ["enqueue" ]} ' )
213227 assert publisher .event_client_got_all .wait (timeout = config ['test_timeout' ]), (
@@ -222,11 +236,33 @@ def run_publish_test_case(dut: Dut, config: Any) -> None:
222236 msg = publisher .publish (topic = config ['publish_topic' ], payload = payload , qos = config ['qos' ])
223237 if config ['qos' ] > 0 :
224238 msgs_published .append (msg )
225- logging .info (f'Published: { len (msgs_published )} ' )
226- while msgs_published :
227- msgs_published = [msg for msg in msgs_published if not msg .is_published ()]
228-
229- logging .info ('All messages from runner published' )
239+ logging .info (f'Published: { len (msgs_published )} messages from script with QoS > 0 needing ACK.' )
240+
241+ if msgs_published :
242+ publish_ack_timeout_seconds = config .get ('publish_ack_timeout' , 60 ) # Default 60s, make configurable
243+ ack_wait_start_time = time .time ()
244+ initial_unacked_count = len (msgs_published )
245+ logging .info (f'Waiting { initial_unacked_count } publish ack with timeout { publish_ack_timeout_seconds } s...' )
246+
247+ while msgs_published :
248+ if time .time () - ack_wait_start_time > publish_ack_timeout_seconds :
249+ unacked_mids = [msg .mid for msg in msgs_published if msg .mid is not None and not msg .is_published ()]
250+ logging .error (
251+ f'Timeout waiting for publish acknowledgements. '
252+ f'{ len (unacked_mids )} of { initial_unacked_count } messages remain unacknowledged. '
253+ f'Unacked MIDs: { unacked_mids } '
254+ )
255+ # This will likely cause the test to fail at a later assertion,
256+ # or you could raise an explicit error here.
257+ # e.g. raise Exception('Timeout waiting for publish acknowledgements')
258+ break
259+ msgs_published = [msg for msg in msgs_published if not msg .is_published ()]
260+ if msgs_published : # Avoid busy-looping if list is not empty
261+ time .sleep (0.1 ) # Brief pause
262+ if not msgs_published :
263+ logging .info ('All script-published QoS > 0 messages acknowledged by broker.' )
264+
265+ logging .info ('All messages from runner published (or timed out waiting for ACK).' )
230266
231267 try :
232268 dut .expect (re .compile (rb'Correct pattern received exactly x times' ), timeout = config ['test_timeout' ])
@@ -262,6 +298,7 @@ def make_cases(transport: Any, scenarios: List[Dict[str, int]]) -> List[Tuple[st
262298@pytest .mark .parametrize ('test_case' , test_cases )
263299@pytest .mark .parametrize ('config' , ['default' ], indirect = True )
264300@idf_parametrize ('target' , ['esp32' ], indirect = ['target' ])
301+ @pytest .mark .flaky (reruns = 1 , reruns_delay = 1 )
265302def test_mqtt_publish (dut : Dut , test_case : Any ) -> None :
266303 publish_cfg = get_configurations (dut , test_case )
267304 dut .expect (re .compile (rb'mqtt>' ), timeout = 30 )
@@ -273,6 +310,7 @@ def test_mqtt_publish(dut: Dut, test_case: Any) -> None:
273310@pytest .mark .nightly_run
274311@pytest .mark .parametrize ('test_case' , stress_test_cases )
275312@pytest .mark .parametrize ('config' , ['default' ], indirect = True )
313+ @pytest .mark .flaky (reruns = 1 , reruns_delay = 1 )
276314@idf_parametrize ('target' , ['esp32' ], indirect = ['target' ])
277315def test_mqtt_publish_stress (dut : Dut , test_case : Any ) -> None :
278316 publish_cfg = get_configurations (dut , test_case )
0 commit comments