From 9bf83b05423beff6dcc542f93e903c562e97469b Mon Sep 17 00:00:00 2001
From: Markus Unterwaditzer
Date: Mon, 23 Jun 2025 13:07:20 +0200
Subject: [PATCH] ref(span-buffer): Introduce multiprocessed flusher (#93824)

---
 CLAUDE.md                                     |  11 ++
 src/sentry/consumers/__init__.py              |  10 +-
 src/sentry/spans/consumers/process/factory.py |   3 +
 src/sentry/spans/consumers/process/flusher.py | 174 +++++++++++++-----
 .../spans/consumers/process/test_consumer.py  |  49 ++++-
 .../spans/consumers/process/test_flusher.py   |   2 +-
 6 files changed, 199 insertions(+), 50 deletions(-)

diff --git a/CLAUDE.md b/CLAUDE.md
index 97b7d0f6eea..cc56fcd90d8 100644
--- a/CLAUDE.md
+++ b/CLAUDE.md
@@ -445,6 +445,17 @@ for org in organizations:
 
 # RIGHT: Use prefetch_related
 organizations.prefetch_related('projects')
+
+# WRONG: Use hasattr() for unions
+x: str | None = "hello"
+if hasattr(x, "replace"):
+    x = x.replace("e", "a")
+
+# RIGHT: Use isinstance()
+x: str | None = "hello"
+if isinstance(x, str):
+    x = x.replace("e", "a")
+
 ```
 
 ### Frontend
diff --git a/src/sentry/consumers/__init__.py b/src/sentry/consumers/__init__.py
index 345912eebb8..0aa60a4d6e4 100644
--- a/src/sentry/consumers/__init__.py
+++ b/src/sentry/consumers/__init__.py
@@ -427,7 +427,15 @@ def ingest_transactions_options() -> list[click.Option]:
         "topic": Topic.INGEST_SPANS,
         "dlq_topic": Topic.INGEST_SPANS_DLQ,
         "strategy_factory": "sentry.spans.consumers.process.factory.ProcessSpansStrategyFactory",
-        "click_options": multiprocessing_options(default_max_batch_size=100),
+        "click_options": [
+            *multiprocessing_options(default_max_batch_size=100),
+            click.Option(
+                ["--flusher-processes", "flusher_processes"],
+                default=1,
+                type=int,
+                help="Maximum number of processes for the span flusher. Defaults to 1.",
+            ),
+        ],
     },
     "process-segments": {
         "topic": Topic.BUFFERED_SEGMENTS,
diff --git a/src/sentry/spans/consumers/process/factory.py b/src/sentry/spans/consumers/process/factory.py
index 8a4e8a6fbdf..90ef08f6fc2 100644
--- a/src/sentry/spans/consumers/process/factory.py
+++ b/src/sentry/spans/consumers/process/factory.py
@@ -38,6 +38,7 @@ def __init__(
         num_processes: int,
         input_block_size: int | None,
         output_block_size: int | None,
+        flusher_processes: int | None = None,
         produce_to_pipe: Callable[[KafkaPayload], None] | None = None,
     ):
         super().__init__()
@@ -48,6 +49,7 @@ def __init__(
         self.input_block_size = input_block_size
         self.output_block_size = output_block_size
         self.num_processes = num_processes
+        self.flusher_processes = flusher_processes
         self.produce_to_pipe = produce_to_pipe
 
         if self.num_processes != 1:
@@ -69,6 +71,7 @@ def create_with_partitions(
         flusher = self._flusher = SpanFlusher(
             buffer,
             next_step=committer,
+            max_processes=self.flusher_processes,
             produce_to_pipe=self.produce_to_pipe,
         )
 
diff --git a/src/sentry/spans/consumers/process/flusher.py b/src/sentry/spans/consumers/process/flusher.py
index 359c0bf44fc..ae9ec081e88 100644
--- a/src/sentry/spans/consumers/process/flusher.py
+++ b/src/sentry/spans/consumers/process/flusher.py
@@ -15,6 +15,7 @@
 
 from sentry import options
 from sentry.conf.types.kafka_definition import Topic
+from sentry.processing.backpressure.memory import ServiceMemory
 from sentry.spans.buffer import SpansBuffer
 from sentry.utils import metrics
 from sentry.utils.arroyo import run_with_initialized_sentry
@@ -27,7 +28,8 @@ class SpanFlusher(ProcessingStrategy[FilteredPayload | int]):
     """
-    A background thread that polls Redis for new segments to flush and to produce to Kafka.
+    A background multiprocessing manager that polls Redis for new segments to flush and produces them to Kafka.
+    Distributes the assigned shards across up to max_processes flusher processes.
 
     This is a processing step to be embedded into the consumer that writes to
     Redis. It takes and forwards integer messages that represent recently
@@ -42,27 +44,53 @@ def __init__(
         self,
         buffer: SpansBuffer,
         next_step: ProcessingStrategy[FilteredPayload | int],
+        max_processes: int | None = None,
         produce_to_pipe: Callable[[KafkaPayload], None] | None = None,
     ):
-        self.buffer = buffer
         self.next_step = next_step
+        self.max_processes = max_processes or len(buffer.assigned_shards)
 
         self.mp_context = mp_context = multiprocessing.get_context("spawn")
         self.stopped = mp_context.Value("i", 0)
         self.redis_was_full = False
         self.current_drift = mp_context.Value("i", 0)
-        self.backpressure_since = mp_context.Value("i", 0)
-        self.healthy_since = mp_context.Value("i", 0)
-        self.process_restarts = 0
         self.produce_to_pipe = produce_to_pipe
 
-        self._create_process()
-
-    def _create_process(self):
+        # Determine which shards get their own processes vs shared processes
+        self.num_processes = min(self.max_processes, len(buffer.assigned_shards))
+        self.process_to_shards_map: dict[int, list[int]] = {
+            i: [] for i in range(self.num_processes)
+        }
+        for i, shard in enumerate(buffer.assigned_shards):
+            process_index = i % self.num_processes
+            self.process_to_shards_map[process_index].append(shard)
+
+        self.processes: dict[int, multiprocessing.context.SpawnProcess | threading.Thread] = {}
+        self.process_healthy_since = {
+            process_index: mp_context.Value("i", int(time.time()))
+            for process_index in range(self.num_processes)
+        }
+        self.process_backpressure_since = {
+            process_index: mp_context.Value("i", 0) for process_index in range(self.num_processes)
+        }
+        self.process_restarts = {process_index: 0 for process_index in range(self.num_processes)}
+        self.buffers: dict[int, SpansBuffer] = {}
+
+        self._create_processes()
+
+    def _create_processes(self):
+        # Create processes based on shard mapping
+        for process_index, shards in self.process_to_shards_map.items():
+            self._create_process_for_shards(process_index, shards)
+
+    def _create_process_for_shards(self, process_index: int, shards: list[int]):
         # Optimistically reset healthy_since to avoid a race between the
         # starting process and the next flush cycle. Keep back pressure across
         # the restart, however.
-        self.healthy_since.value = int(time.time())
+        self.process_healthy_since[process_index].value = int(time.time())
+
+        # Create a buffer for these specific shards
+        shard_buffer = SpansBuffer(shards)
 
         make_process: Callable[..., multiprocessing.context.SpawnProcess | threading.Thread]
         if self.produce_to_pipe is None:
@@ -72,37 +100,50 @@ def _create_process(self):
                 # pickled separately. at the same time, pickling
                # synchronization primitives like multiprocessing.Value can
                # only be done by the Process
-                self.buffer,
+                shard_buffer,
             )
             make_process = self.mp_context.Process
         else:
-            target = partial(SpanFlusher.main, self.buffer)
+            target = partial(SpanFlusher.main, shard_buffer)
             make_process = threading.Thread
 
-        self.process = make_process(
+        process = make_process(
             target=target,
             args=(
+                shards,
                 self.stopped,
                 self.current_drift,
-                self.backpressure_since,
-                self.healthy_since,
+                self.process_backpressure_since[process_index],
+                self.process_healthy_since[process_index],
                 self.produce_to_pipe,
             ),
             daemon=True,
         )
-        self.process.start()
+        process.start()
+        self.processes[process_index] = process
+        self.buffers[process_index] = shard_buffer
+
+    def _create_process_for_shard(self, shard: int):
+        # Find which process this shard belongs to and restart that process
+        for process_index, shards in self.process_to_shards_map.items():
+            if shard in shards:
+                self._create_process_for_shards(process_index, shards)
+                break
 
     @staticmethod
     def main(
         buffer: SpansBuffer,
+        shards: list[int],
        stopped,
         current_drift,
         backpressure_since,
         healthy_since,
         produce_to_pipe: Callable[[KafkaPayload], None] | None,
     ) -> None:
+        shard_tag = ",".join(map(str, shards))
         sentry_sdk.set_tag("sentry_spans_buffer_component", "flusher")
+        sentry_sdk.set_tag("sentry_spans_buffer_shards", shard_tag)
 
         try:
             producer_futures = []
@@ -134,23 +175,28 @@ def produce(payload: KafkaPayload) -> None:
             else:
                 backpressure_since.value = 0
 
+            # Update healthy_since for all shards handled by this process
             healthy_since.value = system_now
 
             if not flushed_segments:
                 time.sleep(1)
                 continue
 
-            with metrics.timer("spans.buffer.flusher.produce"):
-                for _, flushed_segment in flushed_segments.items():
+            with metrics.timer("spans.buffer.flusher.produce", tags={"shard": shard_tag}):
+                for flushed_segment in flushed_segments.values():
                     if not flushed_segment.spans:
                         continue
 
                     spans = [span.payload for span in flushed_segment.spans]
 
                     kafka_payload = KafkaPayload(None, orjson.dumps({"spans": spans}), [])
 
-                    metrics.timing("spans.buffer.segment_size_bytes", len(kafka_payload.value))
+                    metrics.timing(
+                        "spans.buffer.segment_size_bytes",
+                        len(kafka_payload.value),
+                        tags={"shard": shard_tag},
+                    )
 
                     produce(kafka_payload)
 
-            with metrics.timer("spans.buffer.flusher.wait_produce"):
+            with metrics.timer("spans.buffer.flusher.wait_produce", tags={"shard": shard_tag}):
                 for future in producer_futures:
                     future.result()
 
@@ -169,27 +215,48 @@ def produce(payload: KafkaPayload) -> None:
     def poll(self) -> None:
         self.next_step.poll()
 
-    def _ensure_process_alive(self) -> None:
+    def _ensure_processes_alive(self) -> None:
         max_unhealthy_seconds = options.get("spans.buffer.flusher.max-unhealthy-seconds")
-        if not self.process.is_alive():
-            exitcode = getattr(self.process, "exitcode", "unknown")
-            cause = f"no_process_{exitcode}"
-        elif int(time.time()) - self.healthy_since.value > max_unhealthy_seconds:
-            cause = "hang"
-        else:
-            return  # healthy
-
-        metrics.incr("spans.buffer.flusher_unhealthy", tags={"cause": cause})
-        if self.process_restarts > MAX_PROCESS_RESTARTS:
-            raise RuntimeError(f"flusher process crashed repeatedly ({cause}), restarting consumer")
+        for process_index, process in self.processes.items():
+            if not process:
+                continue
+
+            shards = self.process_to_shards_map[process_index]
+
+            cause = None
+            if not process.is_alive():
+                exitcode = getattr(process, "exitcode", "unknown")
+                cause = f"no_process_{exitcode}"
+            elif (
+                int(time.time()) - self.process_healthy_since[process_index].value
+                > max_unhealthy_seconds
+            ):
+                # The process is alive but has not reported progress within
+                # the allowed window
+                cause = "hang"
+
+            if cause is None:
+                continue  # healthy
+
+            # Report unhealthy for all shards handled by this process
+            for shard in shards:
+                metrics.incr(
+                    "spans.buffer.flusher_unhealthy", tags={"cause": cause, "shard": shard}
+                )
 
-        try:
-            self.process.kill()
-        except ValueError:
-            pass  # Process already closed, ignore
+            if self.process_restarts[process_index] > MAX_PROCESS_RESTARTS:
+                raise RuntimeError(
+                    f"flusher process for shards {shards} crashed repeatedly ({cause}), restarting consumer"
+                )
+            self.process_restarts[process_index] += 1
 
-        self.process_restarts += 1
-        self._create_process()
+            try:
+                if isinstance(process, multiprocessing.Process):
+                    process.kill()
+            except (ValueError, AttributeError):
+                pass  # Process already closed, ignore
+
+            self._create_process_for_shards(process_index, shards)
 
     def submit(self, message: Message[FilteredPayload | int]) -> None:
         # Note that submit is not actually a hot path. Their message payloads
@@ -197,18 +264,22 @@ def submit(self, message: Message[FilteredPayload | int]) -> None:
         # per second at most. If anything, self.poll() might even be called
         # more often than submit()
 
-        self._ensure_process_alive()
+        self._ensure_processes_alive()
 
-        self.buffer.record_stored_segments()
+        for buffer in self.buffers.values():
+            buffer.record_stored_segments()
 
         # We pause insertion into Redis if the flusher is not making progress
         # fast enough. We could backlog into Redis, but we assume, despite best
         # efforts, it is still always going to be less durable than Kafka.
         # Minimizing our Redis memory usage also makes COGS easier to reason
         # about.
-        if self.backpressure_since.value > 0:
-            backpressure_secs = options.get("spans.buffer.flusher.backpressure-seconds")
-            if int(time.time()) - self.backpressure_since.value > backpressure_secs:
+        backpressure_secs = options.get("spans.buffer.flusher.backpressure-seconds")
+        for backpressure_since in self.process_backpressure_since.values():
+            if (
+                backpressure_since.value > 0
+                and int(time.time()) - backpressure_since.value > backpressure_secs
+            ):
                 metrics.incr("spans.buffer.flusher.backpressure")
                 raise MessageRejected()
 
@@ -225,7 +296,9 @@ def submit(self, message: Message[FilteredPayload | int]) -> None:
         # wait until the situation is improved manually.
         max_memory_percentage = options.get("spans.buffer.max-memory-percentage")
         if max_memory_percentage < 1.0:
-            memory_infos = list(self.buffer.get_memory_info())
+            memory_infos: list[ServiceMemory] = []
+            for buffer in self.buffers.values():
+                memory_infos.extend(buffer.get_memory_info())
             used = sum(x.used for x in memory_infos)
             available = sum(x.available for x in memory_infos)
             if available > 0 and used / available > max_memory_percentage:
@@ -253,15 +326,22 @@ def close(self) -> None:
         self.next_step.close()
 
     def join(self, timeout: float | None = None):
-        # set stopped flag first so we can "flush" the background thread while
+        # set stopped flag first so we can "flush" the background threads while
         # next_step is also shutting down. we can do two things at once!
         self.stopped.value = True
         deadline = time.time() + timeout if timeout else None
         self.next_step.join(timeout)
 
-        while self.process.is_alive() and (deadline is None or deadline > time.time()):
-            time.sleep(0.1)
+        # Wait for all processes to finish
+        for process in self.processes.values():
+            if deadline is not None:
+                remaining_time = deadline - time.time()
+                if remaining_time <= 0:
+                    break
+
+            while process.is_alive() and (deadline is None or deadline > time.time()):
+                time.sleep(0.1)
 
-        if isinstance(self.process, multiprocessing.Process):
-            self.process.terminate()
+            if isinstance(process, multiprocessing.Process):
+                process.terminate()
diff --git a/tests/sentry/spans/consumers/process/test_consumer.py b/tests/sentry/spans/consumers/process/test_consumer.py
index f12a7ab386c..c216de2fe22 100644
--- a/tests/sentry/spans/consumers/process/test_consumer.py
+++ b/tests/sentry/spans/consumers/process/test_consumer.py
@@ -1,3 +1,4 @@
+import time
 from datetime import datetime
 
 import pytest
@@ -8,7 +9,7 @@
 from sentry.spans.consumers.process.factory import ProcessSpansStrategyFactory
 
 
-@pytest.mark.django_db
+@pytest.mark.django_db(transaction=True)
 def test_basic(monkeypatch):
     # Flush very aggressively to make test pass instantly
     monkeypatch.setattr("time.sleep", lambda _: None)
@@ -56,6 +57,10 @@ def add_commit(offsets, force=False):
     step.poll()
     fac._flusher.current_drift.value = 9000  # "advance" our "clock"
 
+    step.poll()
+    # Give flusher threads time to process after drift change
+    time.sleep(0.1)
+
     step.join()
 
     (msg,) = messages
@@ -74,3 +79,45 @@ def add_commit(offsets, force=False):
             },
         ],
     }
+
+
+@pytest.mark.django_db(transaction=True)
+def test_flusher_processes_limit(monkeypatch):
+    """Test that flusher respects the max_processes limit"""
+    # Flush very aggressively to make test pass instantly
+    monkeypatch.setattr("time.sleep", lambda _: None)
+
+    topic = Topic("test")
+    messages: list[KafkaPayload] = []
+
+    # Create factory with limited flusher processes
+    fac = ProcessSpansStrategyFactory(
+        max_batch_size=10,
+        max_batch_time=10,
+        num_processes=1,
+        input_block_size=None,
+        output_block_size=None,
+        flusher_processes=2,  # Limit to 2 processes even if there are more shards
+        produce_to_pipe=messages.append,
+    )
+
+    commits = []
+
+    def add_commit(offsets, force=False):
+        commits.append(offsets)
+
+    # Create with 4 partitions/shards to test process sharing
+    partitions = {Partition(topic, i): 0 for i in range(4)}
+    step = fac.create_with_partitions(add_commit, partitions)
+
+    # Verify that flusher uses at most 2 processes
+    flusher = fac._flusher
+    assert len(flusher.processes) == 2
+    assert flusher.max_processes == 2
+    assert flusher.num_processes == 2
+
+    # Verify shards are distributed across processes
+    total_shards = sum(len(shards) for shards in flusher.process_to_shards_map.values())
+    assert total_shards == 4  # All 4 shards should be assigned
+
+    step.join()
diff --git a/tests/sentry/spans/consumers/process/test_flusher.py b/tests/sentry/spans/consumers/process/test_flusher.py
index 6365653ae1f..0e886ffac28 100644
--- a/tests/sentry/spans/consumers/process/test_flusher.py
+++ b/tests/sentry/spans/consumers/process/test_flusher.py
@@ -80,4 +80,4 @@ def append(msg):
 
     assert messages
 
-    assert flusher.backpressure_since.value
+    assert any(x.value for x in flusher.process_backpressure_since.values())
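
Note on the shard-to-process mapping: SpanFlusher.__init__ above distributes the
assigned shards round-robin over min(max_processes, len(assigned_shards))
processes, so every process owns at least one shard and no process is created
without work. A minimal standalone sketch of that distribution (the helper name
distribute_shards is illustrative only, not part of this patch):

    def distribute_shards(assigned_shards: list[int], max_processes: int) -> dict[int, list[int]]:
        # One bucket per process; never more processes than shards.
        num_processes = min(max_processes, len(assigned_shards))
        mapping: dict[int, list[int]] = {i: [] for i in range(num_processes)}
        for i, shard in enumerate(assigned_shards):
            # Round-robin assignment, mirroring process_to_shards_map above.
            mapping[i % num_processes].append(shard)
        return mapping

    # Four shards flushed by two processes, two shards each:
    assert distribute_shards([0, 1, 2, 3], max_processes=2) == {0: [0, 2], 1: [1, 3]}
    # More allowed processes than shards: only one process per shard is created.
    assert distribute_shards([0, 1], max_processes=4) == {0: [0], 1: [1]}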
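
Note on the health check: _ensure_processes_alive above marks a flusher process
unhealthy either because it exited (cause "no_process_<exitcode>") or because
its healthy_since heartbeat went stale (cause "hang"), restarting it at most
MAX_PROCESS_RESTARTS times before crashing the consumer. A self-contained
sketch of just that per-process decision; unhealthy_cause is a hypothetical
stand-in for the inline logic, not a function in this patch:

    import time

    def unhealthy_cause(
        is_alive: bool, exitcode: int | None, healthy_since: int, max_unhealthy_seconds: int
    ) -> str | None:
        # Returns None when the process is considered healthy.
        if not is_alive:
            # The process died; the exit code ends up in the cause tag.
            return f"no_process_{exitcode}"
        if int(time.time()) - healthy_since > max_unhealthy_seconds:
            # Alive but no heartbeat within the window: treat as hung.
            return "hang"
        return None

    # A process that last reported healthy ten minutes ago, with a five-minute window:
    assert unhealthy_cause(True, None, int(time.time()) - 600, 300) == "hang"
    assert unhealthy_cause(False, 1, int(time.time()), 300) == "no_process_1"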