Skip to content

Commit 9bf83b0

Browse files
authored
ref(span-buffer): Introduce multiprocessed flusher (#93824)
1 parent c2c41b0 commit 9bf83b0

File tree

6 files changed

+199
-50
lines changed

6 files changed

+199
-50
lines changed

CLAUDE.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -445,6 +445,17 @@ for org in organizations:
445445

446446
# RIGHT: Use prefetch_related
447447
organizations.prefetch_related('projects')
448+
449+
# WRONG: Use hasattr() for unions
450+
x: str | None = "hello"
451+
if hasattr(x, "replace"):
452+
x = x.replace("e", "a")
453+
454+
# RIGHT: Use isinstance()
455+
x: str | None = "hello"
456+
if isinstance(x, str):
457+
x = x.replace("e", "a")
458+
448459
```
449460

450461
### Frontend

src/sentry/consumers/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -427,7 +427,15 @@ def ingest_transactions_options() -> list[click.Option]:
427427
"topic": Topic.INGEST_SPANS,
428428
"dlq_topic": Topic.INGEST_SPANS_DLQ,
429429
"strategy_factory": "sentry.spans.consumers.process.factory.ProcessSpansStrategyFactory",
430-
"click_options": multiprocessing_options(default_max_batch_size=100),
430+
"click_options": [
431+
*multiprocessing_options(default_max_batch_size=100),
432+
click.Option(
433+
["--flusher-processes", "flusher_processes"],
434+
default=1,
435+
type=int,
436+
help="Maximum number of processes for the span flusher. Defaults to 1.",
437+
),
438+
],
431439
},
432440
"process-segments": {
433441
"topic": Topic.BUFFERED_SEGMENTS,

src/sentry/spans/consumers/process/factory.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def __init__(
3838
num_processes: int,
3939
input_block_size: int | None,
4040
output_block_size: int | None,
41+
flusher_processes: int | None = None,
4142
produce_to_pipe: Callable[[KafkaPayload], None] | None = None,
4243
):
4344
super().__init__()
@@ -48,6 +49,7 @@ def __init__(
4849
self.input_block_size = input_block_size
4950
self.output_block_size = output_block_size
5051
self.num_processes = num_processes
52+
self.flusher_processes = flusher_processes
5153
self.produce_to_pipe = produce_to_pipe
5254

5355
if self.num_processes != 1:
@@ -69,6 +71,7 @@ def create_with_partitions(
6971
flusher = self._flusher = SpanFlusher(
7072
buffer,
7173
next_step=committer,
74+
max_processes=self.flusher_processes,
7275
produce_to_pipe=self.produce_to_pipe,
7376
)
7477

src/sentry/spans/consumers/process/flusher.py

Lines changed: 127 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from sentry import options
1717
from sentry.conf.types.kafka_definition import Topic
18+
from sentry.processing.backpressure.memory import ServiceMemory
1819
from sentry.spans.buffer import SpansBuffer
1920
from sentry.utils import metrics
2021
from sentry.utils.arroyo import run_with_initialized_sentry
@@ -27,7 +28,8 @@
2728

2829
class SpanFlusher(ProcessingStrategy[FilteredPayload | int]):
2930
"""
30-
A background thread that polls Redis for new segments to flush and to produce to Kafka.
31+
A background multiprocessing manager that polls Redis for new segments to flush and to produce to Kafka.
32+
Creates up to `max_processes` processes (one per shard by default) and assigns the shards to them round-robin.
3133
3234
This is a processing step to be embedded into the consumer that writes to
3335
Redis. It takes and forwards integer messages that represent recently
@@ -42,27 +44,53 @@ def __init__(
4244
self,
4345
buffer: SpansBuffer,
4446
next_step: ProcessingStrategy[FilteredPayload | int],
47+
max_processes: int | None = None,
4548
produce_to_pipe: Callable[[KafkaPayload], None] | None = None,
4649
):
47-
self.buffer = buffer
4850
self.next_step = next_step
51+
self.max_processes = max_processes or len(buffer.assigned_shards)
4952

5053
self.mp_context = mp_context = multiprocessing.get_context("spawn")
5154
self.stopped = mp_context.Value("i", 0)
5255
self.redis_was_full = False
5356
self.current_drift = mp_context.Value("i", 0)
54-
self.backpressure_since = mp_context.Value("i", 0)
55-
self.healthy_since = mp_context.Value("i", 0)
56-
self.process_restarts = 0
5757
self.produce_to_pipe = produce_to_pipe
5858

59-
self._create_process()
60-
61-
def _create_process(self):
59+
# Distribute the assigned shards round-robin across at most max_processes processes
60+
self.num_processes = min(self.max_processes, len(buffer.assigned_shards))
61+
self.process_to_shards_map: dict[int, list[int]] = {
62+
i: [] for i in range(self.num_processes)
63+
}
64+
for i, shard in enumerate(buffer.assigned_shards):
65+
process_index = i % self.num_processes
66+
self.process_to_shards_map[process_index].append(shard)
67+
68+
self.processes: dict[int, multiprocessing.context.SpawnProcess | threading.Thread] = {}
69+
self.process_healthy_since = {
70+
process_index: mp_context.Value("i", int(time.time()))
71+
for process_index in range(self.num_processes)
72+
}
73+
self.process_backpressure_since = {
74+
process_index: mp_context.Value("i", 0) for process_index in range(self.num_processes)
75+
}
76+
self.process_restarts = {process_index: 0 for process_index in range(self.num_processes)}
77+
self.buffers: dict[int, SpansBuffer] = {}
78+
79+
self._create_processes()
80+
81+
def _create_processes(self):
82+
# Create processes based on shard mapping
83+
for process_index, shards in self.process_to_shards_map.items():
84+
self._create_process_for_shards(process_index, shards)
85+
86+
def _create_process_for_shards(self, process_index: int, shards: list[int]):
6287
# Optimistically reset healthy_since to avoid a race between the
6388
# starting process and the next flush cycle. Keep back pressure across
6489
# the restart, however.
65-
self.healthy_since.value = int(time.time())
90+
self.process_healthy_since[process_index].value = int(time.time())
91+
92+
# Create a buffer for these specific shards
93+
shard_buffer = SpansBuffer(shards)
6694

6795
make_process: Callable[..., multiprocessing.context.SpawnProcess | threading.Thread]
6896
if self.produce_to_pipe is None:
@@ -72,37 +100,50 @@ def _create_process(self):
72100
# pickled separately. at the same time, pickling
73101
# synchronization primitives like multiprocessing.Value can
74102
# only be done by the Process
75-
self.buffer,
103+
shard_buffer,
76104
)
77105
make_process = self.mp_context.Process
78106
else:
79-
target = partial(SpanFlusher.main, self.buffer)
107+
target = partial(SpanFlusher.main, shard_buffer)
80108
make_process = threading.Thread
81109

82-
self.process = make_process(
110+
process = make_process(
83111
target=target,
84112
args=(
113+
shards,
85114
self.stopped,
86115
self.current_drift,
87-
self.backpressure_since,
88-
self.healthy_since,
116+
self.process_backpressure_since[process_index],
117+
self.process_healthy_since[process_index],
89118
self.produce_to_pipe,
90119
),
91120
daemon=True,
92121
)
93122

94-
self.process.start()
123+
process.start()
124+
self.processes[process_index] = process
125+
self.buffers[process_index] = shard_buffer
126+
127+
def _create_process_for_shard(self, shard: int):
128+
# Find which process this shard belongs to and restart that process
129+
for process_index, shards in self.process_to_shards_map.items():
130+
if shard in shards:
131+
self._create_process_for_shards(process_index, shards)
132+
break
95133

96134
@staticmethod
97135
def main(
98136
buffer: SpansBuffer,
137+
shards: list[int],
99138
stopped,
100139
current_drift,
101140
backpressure_since,
102141
healthy_since,
103142
produce_to_pipe: Callable[[KafkaPayload], None] | None,
104143
) -> None:
144+
shard_tag = ",".join(map(str, shards))
105145
sentry_sdk.set_tag("sentry_spans_buffer_component", "flusher")
146+
sentry_sdk.set_tag("sentry_spans_buffer_shards", shard_tag)
106147

107148
try:
108149
producer_futures = []
@@ -134,23 +175,28 @@ def produce(payload: KafkaPayload) -> None:
134175
else:
135176
backpressure_since.value = 0
136177

178+
# Update healthy_since for all shards handled by this process
137179
healthy_since.value = system_now
138180

139181
if not flushed_segments:
140182
time.sleep(1)
141183
continue
142184

143-
with metrics.timer("spans.buffer.flusher.produce"):
144-
for _, flushed_segment in flushed_segments.items():
185+
with metrics.timer("spans.buffer.flusher.produce", tags={"shard": shard_tag}):
186+
for flushed_segment in flushed_segments.values():
145187
if not flushed_segment.spans:
146188
continue
147189

148190
spans = [span.payload for span in flushed_segment.spans]
149191
kafka_payload = KafkaPayload(None, orjson.dumps({"spans": spans}), [])
150-
metrics.timing("spans.buffer.segment_size_bytes", len(kafka_payload.value))
192+
metrics.timing(
193+
"spans.buffer.segment_size_bytes",
194+
len(kafka_payload.value),
195+
tags={"shard": shard_tag},
196+
)
151197
produce(kafka_payload)
152198

153-
with metrics.timer("spans.buffer.flusher.wait_produce"):
199+
with metrics.timer("spans.buffer.flusher.wait_produce", tags={"shards": shard_tag}):
154200
for future in producer_futures:
155201
future.result()
156202

@@ -169,46 +215,71 @@ def produce(payload: KafkaPayload) -> None:
169215
def poll(self) -> None:
170216
self.next_step.poll()
171217

172-
def _ensure_process_alive(self) -> None:
218+
def _ensure_processes_alive(self) -> None:
173219
max_unhealthy_seconds = options.get("spans.buffer.flusher.max-unhealthy-seconds")
174-
if not self.process.is_alive():
175-
exitcode = getattr(self.process, "exitcode", "unknown")
176-
cause = f"no_process_{exitcode}"
177-
elif int(time.time()) - self.healthy_since.value > max_unhealthy_seconds:
178-
cause = "hang"
179-
else:
180-
return # healthy
181220

182-
metrics.incr("spans.buffer.flusher_unhealthy", tags={"cause": cause})
183-
if self.process_restarts > MAX_PROCESS_RESTARTS:
184-
raise RuntimeError(f"flusher process crashed repeatedly ({cause}), restarting consumer")
221+
for process_index, process in self.processes.items():
222+
if not process:
223+
continue
224+
225+
shards = self.process_to_shards_map[process_index]
226+
227+
cause = None
228+
if not process.is_alive():
229+
exitcode = getattr(process, "exitcode", "unknown")
230+
cause = f"no_process_{exitcode}"
231+
elif (
232+
int(time.time()) - self.process_healthy_since[process_index].value
233+
> max_unhealthy_seconds
234+
):
235+
# This process has not refreshed its healthy_since timestamp within max_unhealthy_seconds
236+
cause = "hang"
237+
238+
if cause is None:
239+
continue # healthy
240+
241+
# Report unhealthy for all shards handled by this process
242+
for shard in shards:
243+
metrics.incr(
244+
"spans.buffer.flusher_unhealthy", tags={"cause": cause, "shard": shard}
245+
)
185246

186-
try:
187-
self.process.kill()
188-
except ValueError:
189-
pass # Process already closed, ignore
247+
if self.process_restarts[process_index] > MAX_PROCESS_RESTARTS:
248+
raise RuntimeError(
249+
f"flusher process for shards {shards} crashed repeatedly ({cause}), restarting consumer"
250+
)
251+
self.process_restarts[process_index] += 1
190252

191-
self.process_restarts += 1
192-
self._create_process()
253+
try:
254+
if isinstance(process, multiprocessing.Process):
255+
process.kill()
256+
except (ValueError, AttributeError):
257+
pass # Process already closed, ignore
258+
259+
self._create_process_for_shards(process_index, shards)
193260

194261
def submit(self, message: Message[FilteredPayload | int]) -> None:
195262
# Note that submit is not actually a hot path. Their message payloads
196263
# are mapped from *batches* of spans, and there are a handful of spans
197264
# per second at most. If anything, self.poll() might even be called
198265
# more often than submit()
199266

200-
self._ensure_process_alive()
267+
self._ensure_processes_alive()
201268

202-
self.buffer.record_stored_segments()
269+
for buffer in self.buffers.values():
270+
buffer.record_stored_segments()
203271

204272
# We pause insertion into Redis if the flusher is not making progress
205273
# fast enough. We could backlog into Redis, but we assume, despite best
206274
# efforts, it is still always going to be less durable than Kafka.
207275
# Minimizing our Redis memory usage also makes COGS easier to reason
208276
# about.
209-
if self.backpressure_since.value > 0:
210-
backpressure_secs = options.get("spans.buffer.flusher.backpressure-seconds")
211-
if int(time.time()) - self.backpressure_since.value > backpressure_secs:
277+
backpressure_secs = options.get("spans.buffer.flusher.backpressure-seconds")
278+
for backpressure_since in self.process_backpressure_since.values():
279+
if (
280+
backpressure_since.value > 0
281+
and int(time.time()) - backpressure_since.value > backpressure_secs
282+
):
212283
metrics.incr("spans.buffer.flusher.backpressure")
213284
raise MessageRejected()
214285

@@ -225,7 +296,9 @@ def submit(self, message: Message[FilteredPayload | int]) -> None:
225296
# wait until the situation is improved manually.
226297
max_memory_percentage = options.get("spans.buffer.max-memory-percentage")
227298
if max_memory_percentage < 1.0:
228-
memory_infos = list(self.buffer.get_memory_info())
299+
memory_infos: list[ServiceMemory] = []
300+
for buffer in self.buffers.values():
301+
memory_infos.extend(buffer.get_memory_info())
229302
used = sum(x.used for x in memory_infos)
230303
available = sum(x.available for x in memory_infos)
231304
if available > 0 and used / available > max_memory_percentage:
@@ -253,15 +326,22 @@ def close(self) -> None:
253326
self.next_step.close()
254327

255328
def join(self, timeout: float | None = None):
256-
# set stopped flag first so we can "flush" the background thread while
329+
# set stopped flag first so we can "flush" the background threads while
257330
# next_step is also shutting down. we can do two things at once!
258331
self.stopped.value = True
259332
deadline = time.time() + timeout if timeout else None
260333

261334
self.next_step.join(timeout)
262335

263-
while self.process.is_alive() and (deadline is None or deadline > time.time()):
264-
time.sleep(0.1)
336+
# Wait for all processes to finish
337+
for process_index, process in self.processes.items():
338+
if deadline is not None:
339+
remaining_time = deadline - time.time()
340+
if remaining_time <= 0:
341+
break
342+
343+
while process.is_alive() and (deadline is None or deadline > time.time()):
344+
time.sleep(0.1)
265345

266-
if isinstance(self.process, multiprocessing.Process):
267-
self.process.terminate()
346+
if isinstance(process, multiprocessing.Process):
347+
process.terminate()

0 commit comments

Comments
 (0)