Skip to content

Commit 72d11c7

Browse files
committed
feat: add multiprocessing support
add multiprocessing support by ensuring ControllerClient is spawned for each pid
1 parent 8f7bf7f commit 72d11c7

File tree

2 files changed

+69
-23
lines changed

2 files changed

+69
-23
lines changed

sentry_dynamic_sampling_lib/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def init_wrapper():
5252
app_key = build_app_key(client.options)
5353
controller_endpoint = urljoin(controller_host, controller_path)
5454
metric_endpoint = urljoin(controller_host, metric_path)
55-
print("Sentry Wrapper: Injecting TracesSampler")
55+
print(f"Sentry Wrapper: Injecting TracesSampler. App Key : {app_key}")
5656
client.options["traces_sampler"] = TraceSampler(
5757
poll_interval=poll_interval,
5858
metric_interval=metric_interval,

sentry_dynamic_sampling_lib/sampler.py

Lines changed: 68 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import logging
2+
import os
23
import signal
34
from threading import Event, Thread
45
from time import sleep
6+
from typing import Optional
57

68
import schedule
79
from requests.exceptions import RequestException
@@ -25,16 +27,28 @@ def on_exit(*args, **kwargs):
2527

2628

2729
class ControllerClient(Thread):
28-
def __init__(self, stop, config, metric, *args, **kwargs) -> None:
29-
self.poll_interval = kwargs.pop("poll_interval")
30-
self.metric_interval = kwargs.pop("metric_interval")
31-
self.controller_endpoint = kwargs.pop("controller_endpoint")
32-
self.metric_endpoint = kwargs.pop("metric_endpoint")
33-
self.app_key = kwargs.pop("app_key")
34-
self.stop: Event = stop
35-
self.config: Config = config
36-
self.metrics: Metric = metric
30+
def __init__(
31+
self,
32+
*args,
33+
poll_interval=None,
34+
metric_interval=None,
35+
controller_endpoint=None,
36+
metric_endpoint=None,
37+
app_key=None,
38+
**kwargs
39+
) -> None:
40+
self.poll_interval = poll_interval
41+
self.metric_interval = metric_interval
42+
self.controller_endpoint = controller_endpoint
43+
self.metric_endpoint = metric_endpoint
44+
self.app_key = app_key
45+
46+
self._stop = Event()
47+
self.config = Config()
48+
self.metrics = Metric()
3749
self.session = CachedSession(backend="memory", cache_control=True)
50+
51+
LOGGER.debug("ControllerClient Initialized")
3852
super().__init__(*args, name="SentryControllerClient", **kwargs)
3953

4054
def run(self):
@@ -44,10 +58,16 @@ def run(self):
4458
sleep(5)
4559
schedule.every(self.poll_interval).seconds.do(self.update_config)
4660
schedule.every(self.metric_interval).seconds.do(self.update_metrics)
47-
while not self.stop.is_set():
61+
LOGGER.debug("ControllerClient Started")
62+
while not self._stop.is_set():
4863
schedule.run_pending()
4964
sleep(1)
5065

66+
def kill(self):
67+
self._stop.set()
68+
if self.is_alive():
69+
self.join()
70+
5171
def update_config(self):
5272
try:
5373
resp = self.session.get(
@@ -59,8 +79,10 @@ def update_config(self):
5979
return
6080

6181
if resp.from_cache:
82+
LOGGER.debug("Config Polled from cache")
6283
return
6384

85+
LOGGER.debug("Config Polled")
6486
data = resp.json()
6587
self.config.update(data)
6688
self.metrics.set_mode(MetricType.CELERY, data["celery_collect_metrics"])
@@ -71,6 +93,7 @@ def update_metrics(self):
7193
# check if metric is enable
7294
mode = self.metrics.get_mode(metric_type)
7395
if not mode:
96+
LOGGER.debug("Metric %s disabled", metric_type.value)
7497
continue
7598

7699
counter = self.metrics.get_and_reset(metric_type)
@@ -86,36 +109,59 @@ def update_metrics(self):
86109
self.metric_endpoint.format(self.app_key, metric_type.value),
87110
json=data,
88111
)
112+
LOGGER.debug("Metric %s pushed", metric_type.value)
89113
except RequestException as err:
90114
LOGGER.warning("Metric Request Failed: %s", err)
91115
return
92116

93117

94118
class TraceSampler(metaclass=Singleton):
95119
def __init__(self, *args, **kwargs) -> None:
96-
self.stop = Event()
97-
self.config = Config()
98-
self.metrics = Metric()
99-
self.controller = ControllerClient(
100-
*args, self.stop, self.config, self.metrics, **kwargs
101-
)
102-
self.controller.start()
120+
self.params = (args, kwargs)
121+
self._controller: Optional[ControllerClient] = None
122+
self._tread_for_pid: Optional[int] = None
103123

104124
signal.signal(signal.SIGINT, on_exit)
105-
106125
# HACK: Celery has a built in signal mechanism
107126
# so we use it
108127
if worker_shutdown:
109128
worker_shutdown.connect(on_exit)
110129

111-
def __del__(self):
112-
on_exit(self.stop, self.controller)
130+
@property
131+
def has_running_controller(self):
132+
if self._tread_for_pid != os.getpid():
133+
return False
134+
if not self._controller:
135+
return None
136+
return self._controller.is_alive()
137+
138+
@property
139+
def config(self) -> Config:
140+
return self._controller.config
141+
142+
@property
143+
def metrics(self) -> Metric:
144+
return self._controller.metrics
113145

114146
def kill(self):
115-
self.stop.set()
116-
self.controller.join()
147+
if self._controller:
148+
self._controller.kill()
149+
150+
def _start_controller(self):
151+
args, kwargs = self.params
152+
self._controller = ControllerClient(*args, **kwargs)
153+
self._controller.start()
154+
self._tread_for_pid = os.getpid()
155+
156+
def _ensure_controller(self):
157+
if not self.has_running_controller:
158+
self._start_controller()
159+
160+
def __del__(self):
161+
self.kill()
117162

118163
def __call__(self, sampling_context):
164+
self._ensure_controller()
119165
if sampling_context:
120166
if "wsgi_environ" in sampling_context:
121167
path = sampling_context["wsgi_environ"].get("PATH_INFO", "")

0 commit comments

Comments
 (0)