Skip to content

Commit 71e7331

Browse files
committed
refactor: add metrics middleware
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent efa04a2 commit 71e7331

File tree

3 files changed

+255
-15
lines changed

3 files changed

+255
-15
lines changed

Cargo.lock

Lines changed: 14 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/service/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,13 @@ axum-extra = { version = "0.9.3", features = [
5353
tokio-util = "0.7.10"
5454
cost-model = { git = "https://github.com/graphprotocol/agora", rev = "3ed34ca" }
5555
bip39.workspace = true
56+
tower = "0.5.1"
57+
pin-project = "1.1.7"
5658

5759
[dev-dependencies]
5860
hex-literal = "0.4.1"
5961
test-assets = { path = "../test-assets" }
6062
tower-test = "0.4.0"
61-
tower = "0.5.1"
6263
tokio-test = "0.4.4"
6364

6465
[build-dependencies]
Lines changed: 239 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,246 @@
1-
use std::sync::Arc;
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
23

4+
//! Tower middleware that updates Prometheus metrics when a request succeeds or fails.
5+
6+
use axum::http::Request;
7+
use pin_project::pin_project;
8+
use prometheus::HistogramTimer;
9+
use std::{
10+
future::Future,
11+
pin::Pin,
12+
sync::Arc,
13+
task::{Context, Poll},
14+
};
15+
use tower::{Layer, Service};
316

417
/// Shared, thread-safe handle to a [`MetricLabelProvider`] trait object.
///
/// The metrics middleware in this module looks a value of this type up in the
/// request extensions to decide which label values to record against.
pub type MetricLabels = Arc<dyn MetricLabelProvider + Send + Sync + 'static>;

/// Source of the label values attached to a metric observation.
pub trait MetricLabelProvider {
    /// Returns the label values, borrowed from `self`.
    fn get_labels(&self) -> Vec<&str>;
}
22+
23+
#[derive(Clone)]
24+
pub struct MetricsMiddleware<S> {
25+
inner: S,
26+
histogram: prometheus::HistogramVec,
27+
failure: prometheus::CounterVec,
28+
}
29+
30+
#[derive(Clone)]
31+
pub struct MetricsMiddlewareLayer {
32+
histogram: prometheus::HistogramVec,
33+
failure: prometheus::CounterVec,
34+
}
35+
36+
impl MetricsMiddlewareLayer {
37+
pub fn new(histogram: prometheus::HistogramVec, failure: prometheus::CounterVec) -> Self {
38+
Self { histogram, failure }
39+
}
40+
}
41+
42+
impl<S> Layer<S> for MetricsMiddlewareLayer {
43+
type Service = MetricsMiddleware<S>;
44+
45+
fn layer(&self, inner: S) -> Self::Service {
46+
MetricsMiddleware {
47+
inner,
48+
histogram: self.histogram.clone(),
49+
failure: self.failure.clone(),
50+
}
51+
}
52+
}
53+
54+
impl<S, ReqBody> Service<Request<ReqBody>> for MetricsMiddleware<S>
55+
where
56+
S: Service<Request<ReqBody>> + Clone + 'static,
57+
ReqBody: 'static,
58+
{
59+
type Response = S::Response;
60+
type Error = S::Error;
61+
type Future = MetricsFuture<S::Future>;
62+
63+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
64+
self.inner.poll_ready(cx)
65+
}
66+
67+
fn call(&mut self, request: Request<ReqBody>) -> MetricsFuture<S::Future> {
68+
let labels = request.extensions().get::<MetricLabels>().cloned();
69+
MetricsFuture {
70+
timer: None,
71+
histogram: self.histogram.clone(),
72+
failure: self.failure.clone(),
73+
labels,
74+
fut: self.inner.call(request),
75+
}
76+
}
77+
}
78+
79+
#[pin_project]
80+
pub struct MetricsFuture<F> {
81+
/// Instant at which we started the requst.
82+
timer: Option<HistogramTimer>,
83+
84+
histogram: prometheus::HistogramVec,
85+
failure: prometheus::CounterVec,
86+
87+
labels: Option<MetricLabels>,
88+
89+
#[pin]
90+
fut: F,
91+
}
92+
93+
impl<F, R, E> Future for MetricsFuture<F>
94+
where
95+
F: Future<Output = Result<R, E>>,
96+
{
97+
type Output = F::Output;
98+
99+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100+
let this = self.project();
101+
let Some(labels) = &this.labels else {
102+
return this.fut.poll(cx);
103+
};
104+
105+
if this.timer.is_none() {
106+
// Start timer so we can track duration of request.
107+
let duration_metric = this.histogram.with_label_values(&labels.get_labels());
108+
*this.timer = Some(duration_metric.start_timer());
109+
}
110+
111+
match this.fut.poll(cx) {
112+
Poll::Ready(result) => {
113+
if result.is_err() {
114+
let _ = this
115+
.failure
116+
.get_metric_with_label_values(&labels.get_labels());
117+
}
118+
// Record the duration of this request.
119+
if let Some(timer) = this.timer.take() {
120+
timer.observe_duration();
121+
}
122+
Poll::Ready(result)
123+
}
124+
Poll::Pending => Poll::Pending,
125+
}
126+
}
127+
}
128+
129+
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use axum::{
        body::Body,
        http::{Request, Response},
    };
    use prometheus::core::Collector;
    use tower::{Service, ServiceBuilder, ServiceExt};

    use crate::middleware::metrics::{MetricLabels, MetricsMiddlewareLayer};

    use super::MetricLabelProvider;

    /// Label provider returning three fixed values, one per metric label name.
    struct TestLabel;
    impl MetricLabelProvider for TestLabel {
        fn get_labels(&self) -> Vec<&str> {
            vec!["label1,", "label2", "label3"]
        }
    }

    /// Handler that always succeeds.
    async fn handle(_: Request<Body>) -> anyhow::Result<Response<Body>> {
        Ok(Response::new(Body::default()))
    }

    /// Handler that always fails.
    async fn handle_err(_: Request<Body>) -> anyhow::Result<Response<Body>> {
        Err(anyhow::anyhow!("Error"))
    }

    /// Number of label combinations the first metric family of `collector` holds.
    fn series_count(collector: &impl Collector) -> usize {
        collector.collect().first().unwrap().get_metric().len()
    }

    /// Builds an empty-bodied request carrying `labels` in its extensions.
    fn labelled_request(labels: MetricLabels) -> Request<Body> {
        let mut request = Request::new(Body::default());
        request.extensions_mut().insert(labels);
        request
    }

    #[tokio::test]
    async fn test_metrics_middleware() {
        let registry = prometheus::Registry::new();
        let histogram_metric = prometheus::register_histogram_vec_with_registry!(
            "histogram_metric",
            "Test",
            &["deployment", "sender", "allocation"],
            registry,
        )
        .unwrap();

        let failure_metric = prometheus::register_counter_vec_with_registry!(
            "failure_metric",
            "Test",
            &["deployment", "sender", "allocation"],
            registry,
        )
        .unwrap();

        // Nothing has been recorded yet.
        assert_eq!(series_count(&histogram_metric), 0);
        assert_eq!(series_count(&failure_metric), 0);

        let labels: MetricLabels = Arc::new(TestLabel);

        // A successful request through the middleware.
        let mut ok_service = ServiceBuilder::new()
            .layer(MetricsMiddlewareLayer::new(
                histogram_metric.clone(),
                failure_metric.clone(),
            ))
            .service_fn(handle);
        let ready_service = ok_service.ready().await.unwrap();
        let _ = ready_service.call(labelled_request(labels.clone())).await;

        // The duration series exists, and no failure series was created.
        assert_eq!(series_count(&histogram_metric), 1);
        assert_eq!(series_count(&failure_metric), 0);

        // A failing request through a second middleware sharing the same metrics.
        let mut err_service = ServiceBuilder::new()
            .layer(MetricsMiddlewareLayer::new(
                histogram_metric.clone(),
                failure_metric.clone(),
            ))
            .service_fn(handle_err);
        let ready_service = err_service.ready().await.unwrap();
        let _ = ready_service.call(labelled_request(labels)).await;

        // Same labels as before, so there is still only one duration series.
        assert_eq!(series_count(&histogram_metric), 1);

        // The failure created a new failure series.
        assert_eq!(series_count(&failure_metric), 1);
    }
}

0 commit comments

Comments
 (0)