11import json
22import time
33import unittest
4+ from typing import Any
45
56import mock
7+ from parameterized import parameterized
68
79try :
810 from queue import Queue
1416from posthog .test .test_utils import TEST_API_KEY
1517
1618
19+ def _track_event (event_name : str = "python event" ) -> dict [str , str ]:
20+ return {"type" : "track" , "event" : event_name , "distinct_id" : "distinct_id" }
21+
22+
1723class TestConsumer (unittest .TestCase ):
18- def test_next (self ):
24+ def test_next (self ) -> None :
1925 q = Queue ()
2026 consumer = Consumer (q , "" )
2127 q .put (1 )
2228 next = consumer .next ()
2329 self .assertEqual (next , [1 ])
2430
25- def test_next_limit (self ):
31+ def test_next_limit (self ) -> None :
2632 q = Queue ()
2733 flush_at = 50
2834 consumer = Consumer (q , "" , flush_at )
@@ -31,7 +37,7 @@ def test_next_limit(self):
3137 next = consumer .next ()
3238 self .assertEqual (next , list (range (flush_at )))
3339
34- def test_dropping_oversize_msg (self ):
40+ def test_dropping_oversize_msg (self ) -> None :
3541 q = Queue ()
3642 consumer = Consumer (q , "" )
3743 oversize_msg = {"m" : "x" * MAX_MSG_SIZE }
@@ -40,15 +46,14 @@ def test_dropping_oversize_msg(self):
4046 self .assertEqual (next , [])
4147 self .assertTrue (q .empty ())
4248
43- def test_upload (self ):
49+ def test_upload (self ) -> None :
4450 q = Queue ()
4551 consumer = Consumer (q , TEST_API_KEY )
46- track = {"type" : "track" , "event" : "python event" , "distinct_id" : "distinct_id" }
47- q .put (track )
52+ q .put (_track_event ())
4853 success = consumer .upload ()
4954 self .assertTrue (success )
5055
51- def test_flush_interval (self ):
56+ def test_flush_interval (self ) -> None :
5257 # Put _n_ items in the queue, pausing a little bit more than
5358 # _flush_interval_ after each one.
5459 # The consumer should upload _n_ times.
@@ -57,17 +62,12 @@ def test_flush_interval(self):
5762 consumer = Consumer (q , TEST_API_KEY , flush_at = 10 , flush_interval = flush_interval )
5863 with mock .patch ("posthog.consumer.batch_post" ) as mock_post :
5964 consumer .start ()
60- for i in range (0 , 3 ):
61- track = {
62- "type" : "track" ,
63- "event" : "python event %d" % i ,
64- "distinct_id" : "distinct_id" ,
65- }
66- q .put (track )
65+ for i in range (3 ):
66+ q .put (_track_event ("python event %d" % i ))
6767 time .sleep (flush_interval * 1.1 )
6868 self .assertEqual (mock_post .call_count , 3 )
6969
70- def test_multiple_uploads_per_interval (self ):
70+ def test_multiple_uploads_per_interval (self ) -> None :
7171 # Put _flush_at*2_ items in the queue at once, then pause for
7272 # _flush_interval_. The consumer should upload 2 times.
7373 q = Queue ()
@@ -78,88 +78,60 @@ def test_multiple_uploads_per_interval(self):
7878 )
7979 with mock .patch ("posthog.consumer.batch_post" ) as mock_post :
8080 consumer .start ()
81- for i in range (0 , flush_at * 2 ):
82- track = {
83- "type" : "track" ,
84- "event" : "python event %d" % i ,
85- "distinct_id" : "distinct_id" ,
86- }
87- q .put (track )
81+ for i in range (flush_at * 2 ):
82+ q .put (_track_event ("python event %d" % i ))
8883 time .sleep (flush_interval * 1.1 )
8984 self .assertEqual (mock_post .call_count , 2 )
9085
91- def test_request (self ):
86+ def test_request (self ) -> None :
9287 consumer = Consumer (None , TEST_API_KEY )
93- track = {"type" : "track" , "event" : "python event" , "distinct_id" : "distinct_id" }
94- consumer .request ([track ])
88+ consumer .request ([_track_event ()])
9589
96- def _test_request_retry (self , consumer , expected_exception , exception_count ):
97- def mock_post (* args , ** kwargs ):
98- mock_post .call_count += 1
99- if mock_post .call_count <= exception_count :
100- raise expected_exception
90+ def _run_retry_test (
91+ self , exception : Exception , exception_count : int , retries : int = 10
92+ ) -> None :
93+ call_count = [0 ]
10194
102- mock_post .call_count = 0
95+ def mock_post (* args : Any , ** kwargs : Any ) -> None :
96+ call_count [0 ] += 1
97+ if call_count [0 ] <= exception_count :
98+ raise exception
10399
100+ consumer = Consumer (None , TEST_API_KEY , retries = retries )
104101 with mock .patch (
105102 "posthog.consumer.batch_post" , mock .Mock (side_effect = mock_post )
106103 ):
107- track = {
108- "type" : "track" ,
109- "event" : "python event" ,
110- "distinct_id" : "distinct_id" ,
111- }
112- # request() should succeed if the number of exceptions raised is
113- # less than the retries paramater.
114- if exception_count <= consumer .retries :
115- consumer .request ([track ])
104+ if exception_count <= retries :
105+ consumer .request ([_track_event ()])
116106 else :
117- # if exceptions are raised more times than the retries
118- # parameter, we expect the exception to be returned to
119- # the caller.
120- try :
121- consumer .request ([track ])
122- except type (expected_exception ) as exc :
123- self .assertEqual (exc , expected_exception )
124- else :
125- self .fail (
126- "request() should raise an exception if still failing after %d retries"
127- % consumer .retries
128- )
129-
130- def test_request_retry (self ):
131- # we should retry on general errors
132- consumer = Consumer (None , TEST_API_KEY )
133- self ._test_request_retry (consumer , Exception ("generic exception" ), 2 )
134-
135- # we should retry on server errors
136- consumer = Consumer (None , TEST_API_KEY )
137- self ._test_request_retry (consumer , APIError (500 , "Internal Server Error" ), 2 )
138-
139- # we should retry on HTTP 429 errors
140- consumer = Consumer (None , TEST_API_KEY )
141- self ._test_request_retry (consumer , APIError (429 , "Too Many Requests" ), 2 )
142-
143- # we should NOT retry on other client errors
144- consumer = Consumer (None , TEST_API_KEY )
145- api_error = APIError (400 , "Client Errors" )
146- try :
147- self ._test_request_retry (consumer , api_error , 1 )
148- except APIError :
149- pass
150- else :
151- self .fail ("request() should not retry on client errors" )
152-
153- # test for number of exceptions raise > retries value
154- consumer = Consumer (None , TEST_API_KEY , retries = 3 )
155- self ._test_request_retry (consumer , APIError (500 , "Internal Server Error" ), 3 )
156-
157- def test_pause (self ):
107+ with self .assertRaises (type (exception )):
108+ consumer .request ([_track_event ()])
109+
110+ @parameterized .expand (
111+ [
112+ ("general_errors" , Exception ("generic exception" ), 2 ),
113+ ("server_errors" , APIError (500 , "Internal Server Error" ), 2 ),
114+ ("rate_limit_errors" , APIError (429 , "Too Many Requests" ), 2 ),
115+ ]
116+ )
117+ def test_request_retries_on_retriable_errors (
118+ self , _name : str , exception : Exception , exception_count : int
119+ ) -> None :
120+ self ._run_retry_test (exception , exception_count )
121+
122+ def test_request_does_not_retry_client_errors (self ) -> None :
123+ with self .assertRaises (APIError ):
124+ self ._run_retry_test (APIError (400 , "Client Errors" ), 1 )
125+
126+ def test_request_fails_when_exceptions_exceed_retries (self ) -> None :
127+ self ._run_retry_test (APIError (500 , "Internal Server Error" ), 4 , retries = 3 )
128+
129+ def test_pause (self ) -> None :
158130 consumer = Consumer (None , TEST_API_KEY )
159131 consumer .pause ()
160132 self .assertFalse (consumer .running )
161133
162- def test_max_batch_size (self ):
134+ def test_max_batch_size (self ) -> None :
163135 q = Queue ()
164136 consumer = Consumer (q , TEST_API_KEY , flush_at = 100000 , flush_interval = 3 )
165137 properties = {}
@@ -175,7 +147,7 @@ def test_max_batch_size(self):
175147 # Let's capture 8MB of data to trigger two batches
176148 n_msgs = int (8_000_000 / msg_size )
177149
178- def mock_post_fn (_ , data , ** kwargs ) :
150+ def mock_post_fn (_ : str , data : str , ** kwargs : Any ) -> mock . Mock :
179151 res = mock .Mock ()
180152 res .status_code = 200
181153 request_size = len (data .encode ())
@@ -194,3 +166,34 @@ def mock_post_fn(_, data, **kwargs):
194166 q .put (track )
195167 q .join ()
196168 self .assertEqual (mock_post .call_count , 2 )
169+
170+ @parameterized .expand (
171+ [
172+ ("on_error_succeeds" , False ),
173+ ("on_error_raises" , True ),
174+ ]
175+ )
176+ def test_upload_exception_calls_on_error_and_does_not_raise (
177+ self , _name : str , on_error_raises : bool
178+ ) -> None :
179+ on_error_called : list [tuple [Exception , list [dict [str , str ]]]] = []
180+
181+ def on_error (e : Exception , batch : list [dict [str , str ]]) -> None :
182+ on_error_called .append ((e , batch ))
183+ if on_error_raises :
184+ raise Exception ("on_error failed" )
185+
186+ q = Queue ()
187+ consumer = Consumer (q , TEST_API_KEY , on_error = on_error )
188+ track = _track_event ()
189+ q .put (track )
190+
191+ with mock .patch .object (
192+ consumer , "request" , side_effect = Exception ("request failed" )
193+ ):
194+ result = consumer .upload ()
195+
196+ self .assertFalse (result )
197+ self .assertEqual (len (on_error_called ), 1 )
198+ self .assertEqual (str (on_error_called [0 ][0 ]), "request failed" )
199+ self .assertEqual (on_error_called [0 ][1 ], [track ])
0 commit comments