Skip to content

Commit 7dd6229

Browse files
authored
chore: add a test to describe upload behaviour when there are errors (#403)
* chore: add a test to describe upload behaviour when there are errors * refactor the test file * add typehints to the test file
1 parent 4e4cd18 commit 7dd6229

File tree

1 file changed

+86
-83
lines changed

1 file changed

+86
-83
lines changed

posthog/test/test_consumer.py

Lines changed: 86 additions & 83 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import json
22
import time
33
import unittest
4+
from typing import Any
45

56
import mock
7+
from parameterized import parameterized
68

79
try:
810
from queue import Queue
@@ -14,15 +16,19 @@
1416
from posthog.test.test_utils import TEST_API_KEY
1517

1618

19+
def _track_event(event_name: str = "python event") -> dict[str, str]:
20+
return {"type": "track", "event": event_name, "distinct_id": "distinct_id"}
21+
22+
1723
class TestConsumer(unittest.TestCase):
18-
def test_next(self):
24+
def test_next(self) -> None:
1925
q = Queue()
2026
consumer = Consumer(q, "")
2127
q.put(1)
2228
next = consumer.next()
2329
self.assertEqual(next, [1])
2430

25-
def test_next_limit(self):
31+
def test_next_limit(self) -> None:
2632
q = Queue()
2733
flush_at = 50
2834
consumer = Consumer(q, "", flush_at)
@@ -31,7 +37,7 @@ def test_next_limit(self):
3137
next = consumer.next()
3238
self.assertEqual(next, list(range(flush_at)))
3339

34-
def test_dropping_oversize_msg(self):
40+
def test_dropping_oversize_msg(self) -> None:
3541
q = Queue()
3642
consumer = Consumer(q, "")
3743
oversize_msg = {"m": "x" * MAX_MSG_SIZE}
@@ -40,15 +46,14 @@ def test_dropping_oversize_msg(self):
4046
self.assertEqual(next, [])
4147
self.assertTrue(q.empty())
4248

43-
def test_upload(self):
49+
def test_upload(self) -> None:
4450
q = Queue()
4551
consumer = Consumer(q, TEST_API_KEY)
46-
track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"}
47-
q.put(track)
52+
q.put(_track_event())
4853
success = consumer.upload()
4954
self.assertTrue(success)
5055

51-
def test_flush_interval(self):
56+
def test_flush_interval(self) -> None:
5257
# Put _n_ items in the queue, pausing a little bit more than
5358
# _flush_interval_ after each one.
5459
# The consumer should upload _n_ times.
@@ -57,17 +62,12 @@ def test_flush_interval(self):
5762
consumer = Consumer(q, TEST_API_KEY, flush_at=10, flush_interval=flush_interval)
5863
with mock.patch("posthog.consumer.batch_post") as mock_post:
5964
consumer.start()
60-
for i in range(0, 3):
61-
track = {
62-
"type": "track",
63-
"event": "python event %d" % i,
64-
"distinct_id": "distinct_id",
65-
}
66-
q.put(track)
65+
for i in range(3):
66+
q.put(_track_event("python event %d" % i))
6767
time.sleep(flush_interval * 1.1)
6868
self.assertEqual(mock_post.call_count, 3)
6969

70-
def test_multiple_uploads_per_interval(self):
70+
def test_multiple_uploads_per_interval(self) -> None:
7171
# Put _flush_at*2_ items in the queue at once, then pause for
7272
# _flush_interval_. The consumer should upload 2 times.
7373
q = Queue()
@@ -78,88 +78,60 @@ def test_multiple_uploads_per_interval(self):
7878
)
7979
with mock.patch("posthog.consumer.batch_post") as mock_post:
8080
consumer.start()
81-
for i in range(0, flush_at * 2):
82-
track = {
83-
"type": "track",
84-
"event": "python event %d" % i,
85-
"distinct_id": "distinct_id",
86-
}
87-
q.put(track)
81+
for i in range(flush_at * 2):
82+
q.put(_track_event("python event %d" % i))
8883
time.sleep(flush_interval * 1.1)
8984
self.assertEqual(mock_post.call_count, 2)
9085

91-
def test_request(self):
86+
def test_request(self) -> None:
9287
consumer = Consumer(None, TEST_API_KEY)
93-
track = {"type": "track", "event": "python event", "distinct_id": "distinct_id"}
94-
consumer.request([track])
88+
consumer.request([_track_event()])
9589

96-
def _test_request_retry(self, consumer, expected_exception, exception_count):
97-
def mock_post(*args, **kwargs):
98-
mock_post.call_count += 1
99-
if mock_post.call_count <= exception_count:
100-
raise expected_exception
90+
def _run_retry_test(
91+
self, exception: Exception, exception_count: int, retries: int = 10
92+
) -> None:
93+
call_count = [0]
10194

102-
mock_post.call_count = 0
95+
def mock_post(*args: Any, **kwargs: Any) -> None:
96+
call_count[0] += 1
97+
if call_count[0] <= exception_count:
98+
raise exception
10399

100+
consumer = Consumer(None, TEST_API_KEY, retries=retries)
104101
with mock.patch(
105102
"posthog.consumer.batch_post", mock.Mock(side_effect=mock_post)
106103
):
107-
track = {
108-
"type": "track",
109-
"event": "python event",
110-
"distinct_id": "distinct_id",
111-
}
112-
# request() should succeed if the number of exceptions raised is
113-
# less than the retries paramater.
114-
if exception_count <= consumer.retries:
115-
consumer.request([track])
104+
if exception_count <= retries:
105+
consumer.request([_track_event()])
116106
else:
117-
# if exceptions are raised more times than the retries
118-
# parameter, we expect the exception to be returned to
119-
# the caller.
120-
try:
121-
consumer.request([track])
122-
except type(expected_exception) as exc:
123-
self.assertEqual(exc, expected_exception)
124-
else:
125-
self.fail(
126-
"request() should raise an exception if still failing after %d retries"
127-
% consumer.retries
128-
)
129-
130-
def test_request_retry(self):
131-
# we should retry on general errors
132-
consumer = Consumer(None, TEST_API_KEY)
133-
self._test_request_retry(consumer, Exception("generic exception"), 2)
134-
135-
# we should retry on server errors
136-
consumer = Consumer(None, TEST_API_KEY)
137-
self._test_request_retry(consumer, APIError(500, "Internal Server Error"), 2)
138-
139-
# we should retry on HTTP 429 errors
140-
consumer = Consumer(None, TEST_API_KEY)
141-
self._test_request_retry(consumer, APIError(429, "Too Many Requests"), 2)
142-
143-
# we should NOT retry on other client errors
144-
consumer = Consumer(None, TEST_API_KEY)
145-
api_error = APIError(400, "Client Errors")
146-
try:
147-
self._test_request_retry(consumer, api_error, 1)
148-
except APIError:
149-
pass
150-
else:
151-
self.fail("request() should not retry on client errors")
152-
153-
# test for number of exceptions raise > retries value
154-
consumer = Consumer(None, TEST_API_KEY, retries=3)
155-
self._test_request_retry(consumer, APIError(500, "Internal Server Error"), 3)
156-
157-
def test_pause(self):
107+
with self.assertRaises(type(exception)):
108+
consumer.request([_track_event()])
109+
110+
@parameterized.expand(
111+
[
112+
("general_errors", Exception("generic exception"), 2),
113+
("server_errors", APIError(500, "Internal Server Error"), 2),
114+
("rate_limit_errors", APIError(429, "Too Many Requests"), 2),
115+
]
116+
)
117+
def test_request_retries_on_retriable_errors(
118+
self, _name: str, exception: Exception, exception_count: int
119+
) -> None:
120+
self._run_retry_test(exception, exception_count)
121+
122+
def test_request_does_not_retry_client_errors(self) -> None:
123+
with self.assertRaises(APIError):
124+
self._run_retry_test(APIError(400, "Client Errors"), 1)
125+
126+
def test_request_fails_when_exceptions_exceed_retries(self) -> None:
127+
self._run_retry_test(APIError(500, "Internal Server Error"), 4, retries=3)
128+
129+
def test_pause(self) -> None:
158130
consumer = Consumer(None, TEST_API_KEY)
159131
consumer.pause()
160132
self.assertFalse(consumer.running)
161133

162-
def test_max_batch_size(self):
134+
def test_max_batch_size(self) -> None:
163135
q = Queue()
164136
consumer = Consumer(q, TEST_API_KEY, flush_at=100000, flush_interval=3)
165137
properties = {}
@@ -175,7 +147,7 @@ def test_max_batch_size(self):
175147
# Let's capture 8MB of data to trigger two batches
176148
n_msgs = int(8_000_000 / msg_size)
177149

178-
def mock_post_fn(_, data, **kwargs):
150+
def mock_post_fn(_: str, data: str, **kwargs: Any) -> mock.Mock:
179151
res = mock.Mock()
180152
res.status_code = 200
181153
request_size = len(data.encode())
@@ -194,3 +166,34 @@ def mock_post_fn(_, data, **kwargs):
194166
q.put(track)
195167
q.join()
196168
self.assertEqual(mock_post.call_count, 2)
169+
170+
@parameterized.expand(
171+
[
172+
("on_error_succeeds", False),
173+
("on_error_raises", True),
174+
]
175+
)
176+
def test_upload_exception_calls_on_error_and_does_not_raise(
177+
self, _name: str, on_error_raises: bool
178+
) -> None:
179+
on_error_called: list[tuple[Exception, list[dict[str, str]]]] = []
180+
181+
def on_error(e: Exception, batch: list[dict[str, str]]) -> None:
182+
on_error_called.append((e, batch))
183+
if on_error_raises:
184+
raise Exception("on_error failed")
185+
186+
q = Queue()
187+
consumer = Consumer(q, TEST_API_KEY, on_error=on_error)
188+
track = _track_event()
189+
q.put(track)
190+
191+
with mock.patch.object(
192+
consumer, "request", side_effect=Exception("request failed")
193+
):
194+
result = consumer.upload()
195+
196+
self.assertFalse(result)
197+
self.assertEqual(len(on_error_called), 1)
198+
self.assertEqual(str(on_error_called[0][0]), "request failed")
199+
self.assertEqual(on_error_called[0][1], [track])

0 commit comments

Comments
 (0)