@@ -15,6 +15,7 @@

 from sentry import options
 from sentry.conf.types.kafka_definition import Topic
+from sentry.processing.backpressure.memory import ServiceMemory
 from sentry.spans.buffer import SpansBuffer
 from sentry.utils import metrics
 from sentry.utils.arroyo import run_with_initialized_sentry
@@ -27,7 +28,8 @@

 class SpanFlusher(ProcessingStrategy[FilteredPayload | int]):
     """
-    A background thread that polls Redis for new segments to flush and to produce to Kafka.
+    A background multiprocessing manager that polls Redis for new segments to flush and produces them to Kafka.
+    Creates one flusher process per shard (capped at max_processes) for parallel processing.

     This is a processing step to be embedded into the consumer that writes to
     Redis. It takes and forwards integer messages that represent recently
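Reviewer note: the diff below builds on one coordination pattern throughout: the parent strategy and each flusher process share plain C integers (`multiprocessing.Value`) created from a spawn context, which the child updates as a heartbeat. A minimal, self-contained sketch of that pattern (hypothetical worker, not the real flusher loop):

# Sketch of the shared-flag heartbeat pattern used by SpanFlusher.
import multiprocessing
import time

def worker(stopped, healthy_since):
    while not stopped.value:
        # ... flush work would happen here ...
        healthy_since.value = int(time.time())  # heartbeat for the parent
        time.sleep(0.1)

if __name__ == "__main__":
    mp_context = multiprocessing.get_context("spawn")
    stopped = mp_context.Value("i", 0)
    healthy_since = mp_context.Value("i", int(time.time()))
    proc = mp_context.Process(target=worker, args=(stopped, healthy_since), daemon=True)
    proc.start()
    time.sleep(0.5)
    # parent-side health check, analogous to _ensure_processes_alive below
    assert int(time.time()) - healthy_since.value < 5
    stopped.value = 1
    proc.join()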
@@ -42,27 +44,53 @@ def __init__(
         self,
         buffer: SpansBuffer,
         next_step: ProcessingStrategy[FilteredPayload | int],
+        max_processes: int | None = None,
         produce_to_pipe: Callable[[KafkaPayload], None] | None = None,
     ):
-        self.buffer = buffer
         self.next_step = next_step
+        self.max_processes = max_processes or len(buffer.assigned_shards)

         self.mp_context = mp_context = multiprocessing.get_context("spawn")
         self.stopped = mp_context.Value("i", 0)
         self.redis_was_full = False
         self.current_drift = mp_context.Value("i", 0)
-        self.backpressure_since = mp_context.Value("i", 0)
-        self.healthy_since = mp_context.Value("i", 0)
-        self.process_restarts = 0
         self.produce_to_pipe = produce_to_pipe

-        self._create_process()
-
-    def _create_process(self):
+        # Determine which shards get their own processes vs shared processes
+        self.num_processes = min(self.max_processes, len(buffer.assigned_shards))
+        self.process_to_shards_map: dict[int, list[int]] = {
+            i: [] for i in range(self.num_processes)
+        }
+        for i, shard in enumerate(buffer.assigned_shards):
+            process_index = i % self.num_processes
+            self.process_to_shards_map[process_index].append(shard)
+
+        self.processes: dict[int, multiprocessing.context.SpawnProcess | threading.Thread] = {}
+        self.process_healthy_since = {
+            process_index: mp_context.Value("i", int(time.time()))
+            for process_index in range(self.num_processes)
+        }
+        self.process_backpressure_since = {
+            process_index: mp_context.Value("i", 0) for process_index in range(self.num_processes)
+        }
+        self.process_restarts = {process_index: 0 for process_index in range(self.num_processes)}
+        self.buffers: dict[int, SpansBuffer] = {}
+
+        self._create_processes()
+
+    def _create_processes(self):
+        # Create processes based on the shard mapping
+        for process_index, shards in self.process_to_shards_map.items():
+            self._create_process_for_shards(process_index, shards)
+
+    def _create_process_for_shards(self, process_index: int, shards: list[int]):
         # Optimistically reset healthy_since to avoid a race between the
         # starting process and the next flush cycle. Keep back pressure across
         # the restart, however.
-        self.healthy_since.value = int(time.time())
+        self.process_healthy_since[process_index].value = int(time.time())
+
+        # Create a buffer for these specific shards
+        shard_buffer = SpansBuffer(shards)

         make_process: Callable[..., multiprocessing.context.SpawnProcess | threading.Thread]
         if self.produce_to_pipe is None:
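Reviewer note: the `__init__` above distributes shards round-robin, so when `max_processes` is smaller than the shard count, some processes own several shards; with the default (`max_processes == len(assigned_shards)`), each shard gets its own process. A standalone re-implementation of just that mapping, for illustration:

# Round-robin shard assignment, mirroring the logic in __init__.
def map_shards_to_processes(assigned_shards: list[int], max_processes: int) -> dict[int, list[int]]:
    num_processes = min(max_processes, len(assigned_shards))
    mapping: dict[int, list[int]] = {i: [] for i in range(num_processes)}
    for i, shard in enumerate(assigned_shards):
        mapping[i % num_processes].append(shard)
    return mapping

# Five shards onto two processes: process 0 owns [0, 2, 4], process 1 owns [1, 3].
assert map_shards_to_processes([0, 1, 2, 3, 4], 2) == {0: [0, 2, 4], 1: [1, 3]}
# One process per shard when max_processes covers all shards.
assert map_shards_to_processes([0, 1], 2) == {0: [0], 1: [1]}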
@@ -72,37 +100,50 @@ def _create_process(self):
                 # pickled separately. at the same time, pickling
                 # synchronization primitives like multiprocessing.Value can
                 # only be done by the Process
-                self.buffer,
+                shard_buffer,
             )
             make_process = self.mp_context.Process
         else:
-            target = partial(SpanFlusher.main, self.buffer)
+            target = partial(SpanFlusher.main, shard_buffer)
             make_process = threading.Thread

-        self.process = make_process(
+        process = make_process(
             target=target,
             args=(
+                shards,
                 self.stopped,
                 self.current_drift,
-                self.backpressure_since,
-                self.healthy_since,
+                self.process_backpressure_since[process_index],
+                self.process_healthy_since[process_index],
                 self.produce_to_pipe,
             ),
             daemon=True,
         )

-        self.process.start()
+        process.start()
+        self.processes[process_index] = process
+        self.buffers[process_index] = shard_buffer
+
+    def _create_process_for_shard(self, shard: int):
+        # Find which process this shard belongs to and restart that process
+        for process_index, shards in self.process_to_shards_map.items():
+            if shard in shards:
+                self._create_process_for_shards(process_index, shards)
+                break

     @staticmethod
     def main(
         buffer: SpansBuffer,
+        shards: list[int],
         stopped,
         current_drift,
         backpressure_since,
         healthy_since,
         produce_to_pipe: Callable[[KafkaPayload], None] | None,
     ) -> None:
+        shard_tag = ",".join(map(str, shards))
         sentry_sdk.set_tag("sentry_spans_buffer_component", "flusher")
+        sentry_sdk.set_tag("sentry_spans_buffer_shards", shard_tag)

         try:
             producer_futures = []
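Reviewer note: each process now identifies itself by the comma-joined list of shards it owns; that string becomes both a Sentry scope tag and a metrics tag, so hangs and slow produces can be attributed to a shard group. Illustration of the tag construction:

# A process owning shards [1, 3, 5] reports as "1,3,5".
shards = [1, 3, 5]
shard_tag = ",".join(map(str, shards))
assert shard_tag == "1,3,5"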
@@ -134,23 +175,28 @@ def produce(payload: KafkaPayload) -> None:
             else:
                 backpressure_since.value = 0

+            # Update healthy_since for all shards handled by this process
             healthy_since.value = system_now

             if not flushed_segments:
                 time.sleep(1)
                 continue

-            with metrics.timer("spans.buffer.flusher.produce"):
-                for _, flushed_segment in flushed_segments.items():
+            with metrics.timer("spans.buffer.flusher.produce", tags={"shard": shard_tag}):
+                for flushed_segment in flushed_segments.values():
                     if not flushed_segment.spans:
                         continue

                     spans = [span.payload for span in flushed_segment.spans]
                     kafka_payload = KafkaPayload(None, orjson.dumps({"spans": spans}), [])
-                    metrics.timing("spans.buffer.segment_size_bytes", len(kafka_payload.value))
+                    metrics.timing(
+                        "spans.buffer.segment_size_bytes",
+                        len(kafka_payload.value),
+                        tags={"shard": shard_tag},
+                    )
                     produce(kafka_payload)

-            with metrics.timer("spans.buffer.flusher.wait_produce"):
+            with metrics.timer("spans.buffer.flusher.wait_produce", tags={"shard": shard_tag}):
                 for future in producer_futures:
                     future.result()

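Reviewer note: the produce step is fire-all-then-wait: every segment is handed to the producer inside the first timer block, and only the second block drains the accumulated futures, so broker round-trips overlap. The same pattern shown with `concurrent.futures` (a sketch; `send` is a hypothetical stand-in for the Kafka produce call):

from concurrent.futures import ThreadPoolExecutor

def send(payload: bytes) -> int:
    return len(payload)  # pretend this is an acked Kafka produce

with ThreadPoolExecutor() as pool:
    futures = [pool.submit(send, p) for p in (b"a", b"bb", b"ccc")]  # fire all first
    results = [f.result() for f in futures]  # then block on every ack
assert results == [1, 2, 3]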
@@ -169,46 +215,71 @@ def produce(payload: KafkaPayload) -> None:
     def poll(self) -> None:
         self.next_step.poll()

-    def _ensure_process_alive(self) -> None:
+    def _ensure_processes_alive(self) -> None:
         max_unhealthy_seconds = options.get("spans.buffer.flusher.max-unhealthy-seconds")
-        if not self.process.is_alive():
-            exitcode = getattr(self.process, "exitcode", "unknown")
-            cause = f"no_process_{exitcode}"
-        elif int(time.time()) - self.healthy_since.value > max_unhealthy_seconds:
-            cause = "hang"
-        else:
-            return  # healthy

-        metrics.incr("spans.buffer.flusher_unhealthy", tags={"cause": cause})
-        if self.process_restarts > MAX_PROCESS_RESTARTS:
-            raise RuntimeError(f"flusher process crashed repeatedly ({cause}), restarting consumer")
+        for process_index, process in self.processes.items():
+            if not process:
+                continue
+
+            shards = self.process_to_shards_map[process_index]
+
+            cause = None
+            if not process.is_alive():
+                exitcode = getattr(process, "exitcode", "unknown")
+                cause = f"no_process_{exitcode}"
+            elif (
+                int(time.time()) - self.process_healthy_since[process_index].value
+                > max_unhealthy_seconds
+            ):
+                # No recent heartbeat from this process: treat it as hung
+                cause = "hang"
+
+            if cause is None:
+                continue  # healthy
+
+            # Report unhealthy for all shards handled by this process
+            for shard in shards:
+                metrics.incr(
+                    "spans.buffer.flusher_unhealthy", tags={"cause": cause, "shard": shard}
+                )

-        try:
-            self.process.kill()
-        except ValueError:
-            pass  # Process already closed, ignore
+            if self.process_restarts[process_index] > MAX_PROCESS_RESTARTS:
+                raise RuntimeError(
+                    f"flusher process for shards {shards} crashed repeatedly ({cause}), restarting consumer"
+                )
+            self.process_restarts[process_index] += 1

-        self.process_restarts += 1
-        self._create_process()
+            try:
+                if isinstance(process, multiprocessing.Process):
+                    process.kill()
+            except (ValueError, AttributeError):
+                pass  # Process already closed, ignore
+
+            self._create_process_for_shards(process_index, shards)

     def submit(self, message: Message[FilteredPayload | int]) -> None:
         # Note that submit is not actually a hot path. The message payloads
         # are mapped from *batches* of spans, and there are a handful of spans
         # per second at most. If anything, self.poll() might even be called
         # more often than submit()

-        self._ensure_process_alive()
+        self._ensure_processes_alive()

-        self.buffer.record_stored_segments()
+        for buffer in self.buffers.values():
+            buffer.record_stored_segments()

         # We pause insertion into Redis if the flusher is not making progress
         # fast enough. We could backlog into Redis, but we assume, despite best
         # efforts, it is still always going to be less durable than Kafka.
         # Minimizing our Redis memory usage also makes COGS easier to reason
         # about.
-        if self.backpressure_since.value > 0:
-            backpressure_secs = options.get("spans.buffer.flusher.backpressure-seconds")
-            if int(time.time()) - self.backpressure_since.value > backpressure_secs:
+        backpressure_secs = options.get("spans.buffer.flusher.backpressure-seconds")
+        for backpressure_since in self.process_backpressure_since.values():
+            if (
+                backpressure_since.value > 0
+                and int(time.time()) - backpressure_since.value > backpressure_secs
+            ):
                 metrics.incr("spans.buffer.flusher.backpressure")
                 raise MessageRejected()

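Reviewer note: with one flusher per shard group, the backpressure rule becomes "reject new messages as soon as any process has been backlogged past the threshold". A self-contained sketch of that predicate (the `60` is a stand-in for the `spans.buffer.flusher.backpressure-seconds` option):

import time
from multiprocessing import get_context

mp = get_context("spawn")
backpressure_secs = 60  # stand-in for the option value
process_backpressure_since = {
    0: mp.Value("i", 0),                       # process 0: no backpressure
    1: mp.Value("i", int(time.time()) - 120),  # process 1: backlogged for ~120s
}

def should_reject() -> bool:
    now = int(time.time())
    return any(
        since.value > 0 and now - since.value > backpressure_secs
        for since in process_backpressure_since.values()
    )

assert should_reject()  # one backlogged process is enough to reject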
@@ -225,7 +296,9 @@ def submit(self, message: Message[FilteredPayload | int]) -> None:
         # wait until the situation is improved manually.
         max_memory_percentage = options.get("spans.buffer.max-memory-percentage")
         if max_memory_percentage < 1.0:
-            memory_infos = list(self.buffer.get_memory_info())
+            memory_infos: list[ServiceMemory] = []
+            for buffer in self.buffers.values():
+                memory_infos.extend(buffer.get_memory_info())
             used = sum(x.used for x in memory_infos)
             available = sum(x.available for x in memory_infos)
             if available > 0 and used / available > max_memory_percentage:
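Reviewer note: memory accounting is now summed across all per-process buffers before being compared against the cap, so one shard's Redis cannot hide behind another's headroom. A sketch with a simplified stand-in for `ServiceMemory` (the real type comes from `sentry.processing.backpressure.memory`):

from dataclasses import dataclass

@dataclass
class ServiceMemory:  # simplified stand-in for illustration
    used: int
    available: int

memory_infos = [ServiceMemory(used=900, available=1000), ServiceMemory(used=850, available=1000)]
used = sum(x.used for x in memory_infos)
available = sum(x.available for x in memory_infos)
max_memory_percentage = 0.8
# 1750 / 2000 = 0.875 > 0.8, so the consumer would start dropping spans
assert available > 0 and used / available > max_memory_percentage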
@@ -253,15 +326,22 @@ def close(self) -> None:
         self.next_step.close()

     def join(self, timeout: float | None = None):
-        # set stopped flag first so we can "flush" the background thread while
+        # set stopped flag first so we can "flush" the background threads while
         # next_step is also shutting down. we can do two things at once!
         self.stopped.value = True
         deadline = time.time() + timeout if timeout else None

         self.next_step.join(timeout)

-        while self.process.is_alive() and (deadline is None or deadline > time.time()):
-            time.sleep(0.1)
+        # Wait for all processes to finish
+        for process_index, process in self.processes.items():
+            if deadline is not None:
+                remaining_time = deadline - time.time()
+                if remaining_time <= 0:
+                    break
+
+            while process.is_alive() and (deadline is None or deadline > time.time()):
+                time.sleep(0.1)

-        if isinstance(self.process, multiprocessing.Process):
-            self.process.terminate()
+            if isinstance(process, multiprocessing.Process):
+                process.terminate()
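Reviewer note: `join()` now amortizes a single deadline across all flusher processes: each one gets whatever budget remains, and anything still alive once the budget is spent gets terminated. The same deadline-walk pattern, demonstrated with threads so the sketch runs anywhere (threads cannot be terminated, so that last step is omitted):

import threading
import time

def join_all(workers: list[threading.Thread], timeout: float | None) -> None:
    deadline = time.time() + timeout if timeout else None
    for worker in workers:
        if deadline is not None and deadline - time.time() <= 0:
            break  # budget exhausted; stop polling the rest
        while worker.is_alive() and (deadline is None or deadline > time.time()):
            time.sleep(0.1)

workers = [threading.Thread(target=time.sleep, args=(0.2,)) for _ in range(3)]
for w in workers:
    w.start()
join_all(workers, timeout=2.0)
assert not any(w.is_alive() for w in workers)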