Skip to content

Commit d89b5b0

Browse files
committed
refactor: add metrics middleware
Signed-off-by: Gustavo Inacio <[email protected]>
1 parent d7a847b commit d89b5b0

File tree

3 files changed

+258
-15
lines changed

3 files changed

+258
-15
lines changed

Cargo.lock

Lines changed: 14 additions & 13 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

crates/service/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,12 +53,13 @@ axum-extra = { version = "0.9.3", features = [
5353
tokio-util = "0.7.10"
5454
cost-model = { git = "https://github.com/graphprotocol/agora", rev = "3ed34ca" }
5555
bip39.workspace = true
56+
tower = "0.5.1"
57+
pin-project = "1.1.7"
5658

5759
[dev-dependencies]
5860
hex-literal = "0.4.1"
5961
test-assets = { path = "../test-assets" }
6062
tower-test = "0.4.0"
61-
tower = "0.5.1"
6263
tokio-test = "0.4.4"
6364

6465
[build-dependencies]
Lines changed: 242 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,249 @@
1-
use std::sync::Arc;
1+
// Copyright 2023-, Edge & Node, GraphOps, and Semiotic Labs.
2+
// SPDX-License-Identifier: Apache-2.0
23

4+
//! update metrics in case it succeeds or fails
5+
6+
use axum::http::Request;
7+
use pin_project::pin_project;
8+
use prometheus::HistogramTimer;
9+
use std::{
10+
future::Future,
11+
pin::Pin,
12+
sync::Arc,
13+
task::{Context, Poll},
14+
};
15+
use tower::{Layer, Service};
316

417
pub type MetricLabels = Arc<dyn MetricLabelProvider + 'static + Send + Sync>;
518

619
pub trait MetricLabelProvider {
720
fn get_labels(&self) -> Vec<&str>;
821
}
22+
23+
#[derive(Clone)]
24+
pub struct MetricsMiddleware<S> {
25+
inner: S,
26+
histogram: prometheus::HistogramVec,
27+
failure: prometheus::CounterVec,
28+
}
29+
30+
#[derive(Clone)]
31+
pub struct MetricsMiddlewareLayer {
32+
histogram: prometheus::HistogramVec,
33+
failure: prometheus::CounterVec,
34+
}
35+
36+
impl MetricsMiddlewareLayer {
37+
pub fn new(histogram: prometheus::HistogramVec, failure: prometheus::CounterVec) -> Self {
38+
Self { histogram, failure }
39+
}
40+
}
41+
42+
impl<S> Layer<S> for MetricsMiddlewareLayer {
43+
type Service = MetricsMiddleware<S>;
44+
45+
fn layer(&self, inner: S) -> Self::Service {
46+
MetricsMiddleware {
47+
inner,
48+
histogram: self.histogram.clone(),
49+
failure: self.failure.clone(),
50+
}
51+
}
52+
}
53+
54+
impl<S, ReqBody> Service<Request<ReqBody>> for MetricsMiddleware<S>
55+
where
56+
S: Service<Request<ReqBody>> + Clone + 'static,
57+
ReqBody: 'static,
58+
{
59+
type Response = S::Response;
60+
type Error = S::Error;
61+
type Future = MetricsFuture<S::Future>;
62+
63+
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
64+
self.inner.poll_ready(cx)
65+
}
66+
67+
fn call(&mut self, request: Request<ReqBody>) -> MetricsFuture<S::Future> {
68+
let labels = request.extensions().get::<MetricLabels>().cloned();
69+
MetricsFuture {
70+
timer: None,
71+
histogram: self.histogram.clone(),
72+
failure: self.failure.clone(),
73+
labels,
74+
fut: self.inner.call(request),
75+
}
76+
}
77+
}
78+
79+
#[pin_project]
80+
pub struct MetricsFuture<F> {
81+
/// Instant at which we started the requst.
82+
timer: Option<HistogramTimer>,
83+
84+
histogram: prometheus::HistogramVec,
85+
failure: prometheus::CounterVec,
86+
87+
labels: Option<MetricLabels>,
88+
89+
#[pin]
90+
fut: F,
91+
}
92+
93+
impl<F, R, E> Future for MetricsFuture<F>
94+
where
95+
F: Future<Output = Result<R, E>>,
96+
{
97+
type Output = F::Output;
98+
99+
fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
100+
let this = self.project();
101+
let Some(labels) = &this.labels else {
102+
return this.fut.poll(cx);
103+
};
104+
105+
if this.timer.is_none() {
106+
// Start timer so we can track duration of request.
107+
let duration_metric = this.histogram.with_label_values(&labels.get_labels());
108+
*this.timer = Some(duration_metric.start_timer());
109+
}
110+
111+
match this.fut.poll(cx) {
112+
Poll::Ready(result) => {
113+
if result.is_err() {
114+
let _ = this
115+
.failure
116+
.get_metric_with_label_values(&labels.get_labels());
117+
}
118+
// Record the duration of this request.
119+
if let Some(timer) = this.timer.take() {
120+
timer.observe_duration();
121+
}
122+
Poll::Ready(result)
123+
}
124+
Poll::Pending => Poll::Pending,
125+
}
126+
}
127+
}
128+
129+
#[cfg(test)]
130+
mod tests {
131+
use std::sync::Arc;
132+
133+
use axum::{
134+
body::Body,
135+
http::{Request, Response},
136+
};
137+
use prometheus::core::Collector;
138+
use tower::{Service, ServiceBuilder, ServiceExt};
139+
140+
use crate::middleware::metrics::{MetricLabels, MetricsMiddlewareLayer};
141+
142+
use super::MetricLabelProvider;
143+
144+
struct TestLabel;
145+
impl MetricLabelProvider for TestLabel {
146+
fn get_labels(&self) -> Vec<&str> {
147+
vec!["label1,", "label2", "label3"]
148+
}
149+
}
150+
async fn handle(_: Request<Body>) -> anyhow::Result<Response<Body>> {
151+
Ok(Response::new(Body::default()))
152+
}
153+
154+
async fn handle_err(_: Request<Body>) -> anyhow::Result<Response<Body>> {
155+
Err(anyhow::anyhow!("Error"))
156+
}
157+
158+
#[tokio::test]
159+
async fn test_metrics_middleware() {
160+
let registry = prometheus::Registry::new();
161+
let histogram_metric = prometheus::register_histogram_vec_with_registry!(
162+
"histogram_metric",
163+
"Test",
164+
&["deployment", "sender", "allocation"],
165+
registry,
166+
)
167+
.unwrap();
168+
169+
let failure_metric = prometheus::register_counter_vec_with_registry!(
170+
"failure_metric",
171+
"Test",
172+
&["deployment", "sender", "allocation"],
173+
registry,
174+
)
175+
.unwrap();
176+
177+
// check if everything is clean
178+
assert_eq!(
179+
histogram_metric
180+
.collect()
181+
.first()
182+
.unwrap()
183+
.get_metric()
184+
.len(),
185+
0
186+
);
187+
assert_eq!(
188+
failure_metric.collect().first().unwrap().get_metric().len(),
189+
0
190+
);
191+
192+
let metrics_layer =
193+
MetricsMiddlewareLayer::new(histogram_metric.clone(), failure_metric.clone());
194+
let mut service = ServiceBuilder::new()
195+
.layer(metrics_layer)
196+
.service_fn(handle);
197+
let handle = service.ready().await.unwrap();
198+
199+
// default labels, all empty
200+
let labels: MetricLabels = Arc::new(TestLabel);
201+
202+
let mut req = Request::new(Default::default());
203+
req.extensions_mut().insert(labels.clone());
204+
let _ = handle.call(req).await;
205+
206+
assert_eq!(
207+
histogram_metric
208+
.collect()
209+
.first()
210+
.unwrap()
211+
.get_metric()
212+
.len(),
213+
1
214+
);
215+
216+
assert_eq!(
217+
failure_metric.collect().first().unwrap().get_metric().len(),
218+
0
219+
);
220+
221+
let metrics_layer =
222+
MetricsMiddlewareLayer::new(histogram_metric.clone(), failure_metric.clone());
223+
let mut service = ServiceBuilder::new()
224+
.layer(metrics_layer)
225+
.service_fn(handle_err);
226+
let handle = service.ready().await.unwrap();
227+
228+
let mut req = Request::new(Default::default());
229+
req.extensions_mut().insert(labels);
230+
let _ = handle.call(req).await;
231+
232+
// it's using the same labels, should have only one metric
233+
assert_eq!(
234+
histogram_metric
235+
.collect()
236+
.first()
237+
.unwrap()
238+
.get_metric()
239+
.len(),
240+
1
241+
);
242+
243+
// new failture
244+
assert_eq!(
245+
failure_metric.collect().first().unwrap().get_metric().len(),
246+
1
247+
);
248+
}
249+
}

0 commit comments

Comments
 (0)