Skip to content

Commit f204fc9

Browse files
Merge pull request #98 from Doist/brandon/response
fix: improved handling of api call failure
2 parents f1b78a1 + f8965d4 commit f204fc9

File tree

4 files changed

+144
-59
lines changed

4 files changed

+144
-59
lines changed

sqs_workers/memory_sqs.py

Lines changed: 39 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,7 @@ class MemoryQueue:
107107
name: str = attr.ib()
108108
attributes: Dict[str, Dict[str, str]] = attr.ib()
109109
messages: List["MemoryMessage"] = attr.ib(factory=list)
110+
in_flight: Dict[str, "MemoryMessage"] = attr.ib(factory=dict)
110111

111112
def __attrs_post_init__(self):
112113
self.attributes["QueueArn"] = self.name
@@ -146,6 +147,8 @@ def receive_messages(self, WaitTimeSeconds="0", MaxNumberOfMessages="10", **kwar
146147
else:
147148
ready_messages.append(message)
148149
self.messages[:] = push_back_messages
150+
for m in ready_messages:
151+
self.in_flight[m.message_id] = m
149152
return ready_messages
150153

151154
def delete_messages(self, Entries):
@@ -155,22 +158,47 @@ def delete_messages(self, Entries):
155158
See: https://boto3.amazonaws.com/v1/documentation/api/latest/reference/
156159
services/sqs.html#SQS.Queue.delete_messages
157160
"""
158-
message_ids = {entry["Id"] for entry in Entries}
161+
found_entries = []
162+
not_found_entries = []
159163

160-
successfully_deleted = set()
161-
push_back_messages = []
164+
for e in Entries:
165+
if e["Id"] in self.in_flight:
166+
found_entries.append(e)
167+
self.in_flight.pop(e["Id"])
168+
else:
169+
not_found_entries.append(e)
162170

163-
for message in self.messages:
164-
if message.message_id in message_ids:
165-
successfully_deleted.add(message.message_id)
171+
return {
172+
"Successful": [{"Id": e["Id"]} for e in found_entries],
173+
"Failed": [{"Id": e["Id"]} for e in not_found_entries],
174+
}
175+
176+
def change_message_visibility_batch(self, Entries):
177+
"""
178+
Changes message visibility by looking at in-flight messages, setting
179+
a new execute_at, and returning it to the pool of processable messages
180+
"""
181+
found_entries = []
182+
not_found_entries = []
183+
184+
now = datetime.datetime.utcnow()
185+
186+
for e in Entries:
187+
if e["Id"] in self.in_flight:
188+
found_entries.append(e)
189+
in_flight_message = self.in_flight[e["Id"]]
190+
sec = int(e["VisibilityTimeout"])
191+
execute_at = now + datetime.timedelta(seconds=sec)
192+
updated_message = attr.evolve(in_flight_message, execute_at=execute_at)
193+
updated_message.attributes["ApproximateReceiveCount"] += 1
194+
self.messages.append(updated_message)
195+
self.in_flight.pop(e["Id"])
166196
else:
167-
push_back_messages.append(message)
168-
self.messages[:] = push_back_messages
197+
not_found_entries.append(e)
169198

170-
didnt_deleted = message_ids.difference(successfully_deleted)
171199
return {
172-
"Successful": [{"Id": _id} for _id in successfully_deleted],
173-
"Failed": [{"Id": _id} for _id in didnt_deleted],
200+
"Successful": [{"Id": e["Id"]} for e in found_entries],
201+
"Failed": [{"Id": e["Id"]} for e in not_found_entries],
174202
}
175203

176204
def delete(self):
@@ -241,10 +269,3 @@ def from_kwargs(cls, queue_impl, kwargs):
241269
return MemoryMessage(
242270
queue_impl, body, message_atttributes, attributes, execute_at
243271
)
244-
245-
def change_visibility(self, VisibilityTimeout="0", **kwargs):
246-
timeout = int(VisibilityTimeout)
247-
execute_at = datetime.datetime.utcnow() + datetime.timedelta(seconds=timeout)
248-
message = attr.evolve(self, execute_at=execute_at)
249-
message.attributes["ApproximateReceiveCount"] += 1
250-
self.queue_impl.messages.append(message)

sqs_workers/queue.py

Lines changed: 86 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@
1010
Callable,
1111
Dict,
1212
Generator,
13+
Iterable,
1314
List,
1415
Literal,
1516
Optional,
17+
Tuple,
1618
TypeVar,
1719
)
1820

@@ -27,6 +29,7 @@
2729
from sqs_workers.exceptions import SQSError
2830
from sqs_workers.processors import DEFAULT_CONTEXT_VAR, Processor
2931
from sqs_workers.shutdown_policies import NEVER_SHUTDOWN
32+
from sqs_workers.utils import batcher
3033

3134
DEFAULT_MESSAGE_GROUP_ID = "default"
3235
SEND_BATCH_SIZE = 10
@@ -85,53 +88,84 @@ def process_queue(self, shutdown_policy=NEVER_SHUTDOWN, wait_second=10):
8588
)
8689
break
8790

88-
def process_batch(self, wait_seconds=0) -> BatchProcessingResult:
91+
def process_batch(self, wait_seconds: int = 0) -> BatchProcessingResult:
8992
"""
9093
Process a batch of messages from the queue (10 messages at most), return
9194
the number of successfully processed messages, and exit
9295
"""
93-
queue = self.get_queue()
94-
9596
if self.batching_policy.batching_enabled:
96-
return self._process_messages_in_batch(queue, wait_seconds)
97+
messages = self.get_raw_messages(
98+
wait_seconds, self.batching_policy.batch_size
99+
)
100+
success = self.process_messages(messages)
101+
messages_with_success = ((m, success) for m in messages)
102+
else:
103+
messages = self.get_raw_messages(wait_seconds)
104+
success = [self.process_message(message) for message in messages]
105+
messages_with_success = zip(messages, success)
97106

98-
return self._process_messages_individually(queue, wait_seconds)
107+
return self._handle_processed(messages_with_success)
99108

100-
def _process_messages_in_batch(self, queue, wait_seconds):
101-
messages = self.get_raw_messages(wait_seconds, self.batching_policy.batch_size)
102-
result = BatchProcessingResult(self.name)
109+
def _handle_processed(self, messages_with_success: Iterable[Tuple[Any, bool]]):
110+
"""
111+
Handles the results of processing messages.
103112
104-
success = self.process_messages(messages)
113+
For successful messages, we delete the message ID from the queue, which is
114+
equivalent to acknowledging it.
105115
106-
for message in messages:
107-
result.update_with_message(message, success)
108-
if success:
109-
entry = {
110-
"Id": message.message_id,
111-
"ReceiptHandle": message.receipt_handle,
112-
}
113-
queue.delete_messages(Entries=[entry])
114-
else:
115-
timeout = self.backoff_policy.get_visibility_timeout(message)
116-
message.change_visibility(VisibilityTimeout=timeout)
117-
return result
116+
For failed messages, we change the visibility of the message, in order to
117+
keep it un-consumable for a little while (a form of backoff).
118+
119+
In each case (delete or change-viz), we batch the API calls to AWS in order
120+
to try to avoid getting throttled, with batches of size 10 (the limit). The
121+
config (see sqs_env.py) should also retry in the event of exceptions.
122+
"""
123+
queue = self.get_queue()
118124

119-
def _process_messages_individually(self, queue, wait_seconds):
120-
messages = self.get_raw_messages(wait_seconds)
121125
result = BatchProcessingResult(self.name)
122126

123-
for message in messages:
124-
success = self.process_message(message)
125-
result.update_with_message(message, success)
126-
if success:
127-
entry = {
128-
"Id": message.message_id,
129-
"ReceiptHandle": message.receipt_handle,
130-
}
131-
queue.delete_messages(Entries=[entry])
132-
else:
133-
timeout = self.backoff_policy.get_visibility_timeout(message)
134-
message.change_visibility(VisibilityTimeout=timeout)
127+
for subgroup in batcher(messages_with_success, batch_size=10):
128+
entries_to_ack = []
129+
entries_to_change_viz = []
130+
131+
for m, success in subgroup:
132+
result.update_with_message(m, success)
133+
if success:
134+
entries_to_ack.append(
135+
{
136+
"Id": m.message_id,
137+
"ReceiptHandle": m.receipt_handle,
138+
}
139+
)
140+
else:
141+
entries_to_change_viz.append(
142+
{
143+
"Id": m.message_id,
144+
"ReceiptHandle": m.receipt_handle,
145+
"VisibilityTimeout": self.backoff_policy.get_visibility_timeout(
146+
m
147+
),
148+
}
149+
)
150+
151+
ack_response = queue.delete_messages(Entries=entries_to_ack)
152+
153+
if ack_response.get("Failed"):
154+
logger.warning(
155+
"Failed to delete processed messages from queue",
156+
extra={"queue": self.name, "failures": ack_response["Failed"]},
157+
)
158+
159+
viz_response = queue.change_message_visibility_batch(
160+
Entries=entries_to_change_viz,
161+
)
162+
163+
if viz_response.get("Failed"):
164+
logger.warning(
165+
"Failed to change visibility of messages which failed to process",
166+
extra={"queue": self.name, "failures": viz_response["Failed"]},
167+
)
168+
135169
return result
136170

137171
def process_message(self, message: Any) -> bool:
@@ -151,15 +185,16 @@ def process_messages(self, messages: List[Any]) -> bool:
151185
"""
152186
raise NotImplementedError()
153187

154-
def get_raw_messages(self, wait_seconds, max_messages=10):
188+
def get_raw_messages(self, wait_seconds: int, max_messages: int = 10) -> List[Any]:
155189
"""Return raw messages from the queue, addressed by its name"""
190+
queue = self.get_queue()
191+
156192
kwargs = {
157193
"WaitTimeSeconds": wait_seconds,
158194
"MaxNumberOfMessages": max_messages if max_messages <= 10 else 10,
159195
"MessageAttributeNames": ["All"],
160196
"AttributeNames": ["All"],
161197
}
162-
queue = self.get_queue()
163198

164199
if max_messages <= 10:
165200
return queue.receive_messages(**kwargs)
@@ -180,16 +215,29 @@ def get_raw_messages(self, wait_seconds, max_messages=10):
180215
def drain_queue(self, wait_seconds=0):
181216
"""Delete all messages from the queue without calling purge()."""
182217
queue = self.get_queue()
218+
183219
deleted_count = 0
184220
while True:
185221
messages = self.get_raw_messages(wait_seconds)
186222
if not messages:
187223
break
224+
188225
entries = [
189226
{"Id": msg.message_id, "ReceiptHandle": msg.receipt_handle}
190227
for msg in messages
191228
]
192-
queue.delete_messages(Entries=entries)
229+
230+
ack_response = queue.delete_messages(Entries=entries)
231+
232+
if ack_response.get("Failed"):
233+
logger.warning(
234+
"Failed to delete processed messages from queue",
235+
extra={
236+
"queue": self.name,
237+
"failures": ack_response["Failed"],
238+
},
239+
)
240+
193241
deleted_count += len(messages)
194242
return deleted_count
195243

sqs_workers/sqs_env.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import attr
1717
import boto3
18+
from botocore.config import Config
1819
from typing_extensions import ParamSpec
1920

2021
from sqs_workers import DEFAULT_BACKOFF, RawQueue, codecs, context, processors
@@ -45,6 +46,10 @@ class SQSEnv:
4546
queue_prefix = attr.ib(default="")
4647
codec: str = attr.ib(default=codecs.DEFAULT_CONTENT_TYPE)
4748

49+
# retry settings for internal boto
50+
retry_max_attempts: int = attr.ib(default=3)
51+
retry_mode: str = attr.ib(default="standard")
52+
4853
# queue-specific settings
4954
backoff_policy = attr.ib(default=DEFAULT_BACKOFF)
5055

@@ -60,10 +65,13 @@ class SQSEnv:
6065

6166
def __attrs_post_init__(self):
6267
self.context = self.context_maker()
68+
# https://boto3.amazonaws.com/v1/documentation/api/latest/guide/retries.html
69+
retry_dict = {"max_attempts": self.retry_max_attempts, "mode": self.retry_mode}
70+
retry_config = Config(retries=retry_dict)
6371
if not self.sqs_client:
64-
self.sqs_client = self.session.client("sqs")
72+
self.sqs_client = self.session.client("sqs", config=retry_config)
6573
if not self.sqs_resource:
66-
self.sqs_resource = self.session.resource("sqs")
74+
self.sqs_resource = self.session.resource("sqs", config=retry_config)
6775

6876
@overload
6977
def queue(

sqs_workers/utils.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import importlib
22
import logging
33
from inspect import Signature
4-
from typing import Any
4+
from itertools import islice
5+
from typing import Any, Iterable
56

67
logger = logging.getLogger(__name__)
78

@@ -121,3 +122,10 @@ def ensure_string(obj: Any, encoding="utf-8", errors="strict") -> str:
121122
return obj.decode(encoding, errors)
122123
else:
123124
return str(obj)
125+
126+
127+
def batcher(iterable, batch_size) -> Iterable[Iterable[Any]]:
128+
"""Cuts an iterable up into sub-iterables of size batch_size."""
129+
iterator = iter(iterable)
130+
while batch := list(islice(iterator, batch_size)):
131+
yield batch

0 commit comments

Comments
 (0)