|
| 1 | +import ssl |
| 2 | +import time |
| 3 | +from ssl import SSLSocket |
| 4 | +from typing import Any, Dict, Tuple |
| 5 | + |
| 6 | +import prometheus_client as prom |
| 7 | + |
| 8 | +from generic_connection_pool.contrib.socket import SslSocketConnectionManager |
| 9 | +from generic_connection_pool.threading import ConnectionPool |
| 10 | + |
| 11 | +Hostname = str |
| 12 | +Port = int |
| 13 | +Endpoint = Tuple[Hostname, Port] |
| 14 | + |
| 15 | +acquire_latency_hist = prom.Histogram('acquire_latency', 'Connections acquire latency', labelnames=['hostname']) |
| 16 | +acquire_total = prom.Counter('acquire_total', 'Connections acquire count', labelnames=['hostname']) |
| 17 | +dead_conn_total = prom.Counter('dead_conn_total', 'Dead connections count', labelnames=['hostname']) |
| 18 | + |
| 19 | + |
| 20 | +class ObservableConnectionManager(SslSocketConnectionManager): |
| 21 | + |
| 22 | + def __init__(self, *args: Any, **kwargs: Any): |
| 23 | + super().__init__(*args, **kwargs) |
| 24 | + self._acquires: Dict[SSLSocket, float] = {} |
| 25 | + |
| 26 | + def on_acquire(self, endpoint: Endpoint, conn: SSLSocket) -> None: |
| 27 | + hostname, port = endpoint |
| 28 | + |
| 29 | + acquire_total.labels(hostname).inc() |
| 30 | + self._acquires[conn] = time.time() |
| 31 | + |
| 32 | + def on_release(self, endpoint: Endpoint, conn: SSLSocket) -> None: |
| 33 | + hostname, port = endpoint |
| 34 | + |
| 35 | + acquired_at = self._acquires.pop(conn) |
| 36 | + acquire_latency_hist.labels(hostname).observe(time.time() - acquired_at) |
| 37 | + |
| 38 | + def on_connection_dead(self, endpoint: Endpoint, conn: SSLSocket) -> None: |
| 39 | + hostname, port = endpoint |
| 40 | + |
| 41 | + dead_conn_total.labels(hostname).inc() |
| 42 | + |
| 43 | + |
| 44 | +http_pool = ConnectionPool[Endpoint, SSLSocket]( |
| 45 | + ObservableConnectionManager(ssl.create_default_context()), |
| 46 | +) |
0 commit comments