55
66use axum:: http:: Request ;
77use pin_project:: pin_project;
8- use prometheus:: HistogramTimer ;
98use std:: {
109 future:: Future ,
1110 pin:: Pin ,
1211 sync:: Arc ,
1312 task:: { Context , Poll } ,
13+ time:: Instant ,
1414} ;
1515use tower:: { Layer , Service } ;
1616
17+ use crate :: error:: StatusCodeExt ;
18+
1719pub type MetricLabels = Arc < dyn MetricLabelProvider + ' static + Send + Sync > ;
1820
1921pub trait MetricLabelProvider {
@@ -25,7 +27,6 @@ pub trait MetricLabelProvider {
2527pub struct PrometheusMetricsMiddleware < S > {
2628 inner : S ,
2729 histogram : prometheus:: HistogramVec ,
28- failure : prometheus:: CounterVec ,
2930}
3031
3132/// MetricsMiddleware used in tower components
@@ -35,13 +36,11 @@ pub struct PrometheusMetricsMiddleware<S> {
3536pub struct PrometheusMetricsMiddlewareLayer {
3637 /// Histogram used to register the processing timer
3738 histogram : prometheus:: HistogramVec ,
38- /// Counter metric in case of failure
39- failure : prometheus:: CounterVec ,
4039}
4140
4241impl PrometheusMetricsMiddlewareLayer {
43- pub fn new ( histogram : prometheus:: HistogramVec , failure : prometheus :: CounterVec ) -> Self {
44- Self { histogram, failure }
42+ pub fn new ( histogram : prometheus:: HistogramVec ) -> Self {
43+ Self { histogram }
4544 }
4645}
4746
@@ -52,7 +51,6 @@ impl<S> Layer<S> for PrometheusMetricsMiddlewareLayer {
5251 PrometheusMetricsMiddleware {
5352 inner,
5453 histogram : self . histogram . clone ( ) ,
55- failure : self . failure . clone ( ) ,
5654 }
5755 }
5856}
@@ -61,6 +59,7 @@ impl<S, ReqBody> Service<Request<ReqBody>> for PrometheusMetricsMiddleware<S>
6159where
6260 S : Service < Request < ReqBody > > + Clone + ' static ,
6361 ReqBody : ' static ,
62+ Result < S :: Response , S :: Error > : StatusCodeExt ,
6463{
6564 type Response = S :: Response ;
6665 type Error = S :: Error ;
7574 PrometheusMetricsFuture {
7675 timer : None ,
7776 histogram : self . histogram . clone ( ) ,
78- failure : self . failure . clone ( ) ,
7977 labels,
8078 fut : self . inner . call ( request) ,
8179 }
@@ -85,46 +83,45 @@ where
8583#[ pin_project]
8684pub struct PrometheusMetricsFuture < F > {
8785 /// Instant at which we started the requst.
88- timer : Option < HistogramTimer > ,
86+ timer : Option < Instant > ,
8987
9088 histogram : prometheus:: HistogramVec ,
91- failure : prometheus:: CounterVec ,
92-
9389 labels : Option < MetricLabels > ,
9490
9591 #[ pin]
9692 fut : F ,
9793}
9894
99- impl < F , R , E > Future for PrometheusMetricsFuture < F >
95+ impl < F , T > Future for PrometheusMetricsFuture < F >
10096where
101- F : Future < Output = Result < R , E > > ,
97+ F : Future < Output = T > ,
98+ T : StatusCodeExt ,
10299{
103100 type Output = F :: Output ;
104101
105102 fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
106103 let this = self . project ( ) ;
107- let Some ( labels) = & this. labels else {
104+ let Some ( labels) = this. labels else {
108105 return this. fut . poll ( cx) ;
109106 } ;
110107
111108 if this. timer . is_none ( ) {
112109 // Start timer so we can track duration of request.
113- let duration_metric = this. histogram . with_label_values ( & labels. get_labels ( ) ) ;
114- * this. timer = Some ( duration_metric. start_timer ( ) ) ;
110+ * this. timer = Some ( Instant :: now ( ) ) ;
115111 }
116112
117113 match this. fut . poll ( cx) {
118114 Poll :: Ready ( result) => {
119- if result. is_err ( ) {
120- let _ = this
121- . failure
122- . get_metric_with_label_values ( & labels. get_labels ( ) ) ;
123- }
115+ let status_code = result. status_code ( ) ;
116+ // add status code
117+ let mut labels = labels. get_labels ( ) ;
118+ labels. push ( status_code. as_str ( ) ) ;
119+ let duration_metric = this. histogram . with_label_values ( & labels) ;
120+
124121 // Record the duration of this request.
125- if let Some ( timer) = this. timer . take ( ) {
126- timer. observe_duration ( ) ;
127- }
122+ let timer = this. timer . take ( ) . expect ( "timer should exist" ) ;
123+ duration_metric . observe ( timer. elapsed ( ) . as_secs_f64 ( ) ) ;
124+
128125 Poll :: Ready ( result)
129126 }
130127 Poll :: Pending => Poll :: Pending ,
@@ -141,9 +138,13 @@ mod tests {
141138 http:: { Request , Response } ,
142139 } ;
143140 use prometheus:: core:: Collector ;
141+ use reqwest:: StatusCode ;
144142 use tower:: { Service , ServiceBuilder , ServiceExt } ;
145143
146- use crate :: middleware:: prometheus_metrics:: { MetricLabels , PrometheusMetricsMiddlewareLayer } ;
144+ use crate :: {
145+ error:: StatusCodeExt ,
146+ middleware:: prometheus_metrics:: { MetricLabels , PrometheusMetricsMiddlewareLayer } ,
147+ } ;
147148
148149 use super :: MetricLabelProvider ;
149150
@@ -153,12 +154,22 @@ mod tests {
153154 vec ! [ "label1," , "label2" , "label3" ]
154155 }
155156 }
156- async fn handle ( _: Request < Body > ) -> anyhow:: Result < Response < Body > > {
157+
158+ #[ derive( Debug ) ]
159+ struct ErrorResponse ;
160+
161+ impl StatusCodeExt for ErrorResponse {
162+ fn status_code ( & self ) -> StatusCode {
163+ StatusCode :: INTERNAL_SERVER_ERROR
164+ }
165+ }
166+
167+ async fn handle ( _: Request < Body > ) -> Result < Response < Body > , ErrorResponse > {
157168 Ok ( Response :: new ( Body :: default ( ) ) )
158169 }
159170
160- async fn handle_err ( _: Request < Body > ) -> anyhow :: Result < Response < Body > > {
161- Err ( anyhow :: anyhow! ( "Error" ) )
171+ async fn handle_err ( _: Request < Body > ) -> Result < Response < Body > , ErrorResponse > {
172+ Err ( ErrorResponse )
162173 }
163174
164175 #[ tokio:: test]
@@ -167,36 +178,20 @@ mod tests {
167178 let histogram_metric = prometheus:: register_histogram_vec_with_registry!(
168179 "histogram_metric" ,
169180 "Test" ,
170- & [ "deployment" , "sender" , "allocation" ] ,
171- registry,
172- )
173- . unwrap ( ) ;
174-
175- let failure_metric = prometheus:: register_counter_vec_with_registry!(
176- "failure_metric" ,
177- "Test" ,
178- & [ "deployment" , "sender" , "allocation" ] ,
181+ & [ "deployment" , "sender" , "allocation" , "status_code" ] ,
179182 registry,
180183 )
181184 . unwrap ( ) ;
182185
183186 // check if everything is clean
184- assert_eq ! (
185- histogram_metric
186- . collect( )
187- . first( )
188- . unwrap( )
189- . get_metric( )
190- . len( ) ,
191- 0
192- ) ;
193- assert_eq ! (
194- failure_metric. collect( ) . first( ) . unwrap( ) . get_metric( ) . len( ) ,
195- 0
196- ) ;
197-
198- let metrics_layer =
199- PrometheusMetricsMiddlewareLayer :: new ( histogram_metric. clone ( ) , failure_metric. clone ( ) ) ;
187+ assert ! ( histogram_metric
188+ . collect( )
189+ . first( )
190+ . unwrap( )
191+ . get_metric( )
192+ . is_empty( ) ) ;
193+
194+ let metrics_layer = PrometheusMetricsMiddlewareLayer :: new ( histogram_metric. clone ( ) ) ;
200195 let mut service = ServiceBuilder :: new ( )
201196 . layer ( metrics_layer)
202197 . service_fn ( handle) ;
@@ -209,23 +204,21 @@ mod tests {
209204 req. extensions_mut ( ) . insert ( labels. clone ( ) ) ;
210205 let _ = handle. call ( req) . await ;
211206
212- assert_eq ! (
207+ let how_many_metrics = | status : u32 | {
213208 histogram_metric
214209 . collect ( )
215210 . first ( )
216211 . unwrap ( )
217212 . get_metric ( )
218- . len( ) ,
219- 1
220- ) ;
213+ . iter ( )
214+ . filter ( |a| a. get_label ( ) [ 3 ] . get_value ( ) == status. to_string ( ) )
215+ . count ( )
216+ } ;
221217
222- assert_eq ! (
223- failure_metric. collect( ) . first( ) . unwrap( ) . get_metric( ) . len( ) ,
224- 0
225- ) ;
218+ assert_eq ! ( how_many_metrics( 200 ) , 1 ) ;
219+ assert_eq ! ( how_many_metrics( 500 ) , 0 ) ;
226220
227- let metrics_layer =
228- PrometheusMetricsMiddlewareLayer :: new ( histogram_metric. clone ( ) , failure_metric. clone ( ) ) ;
221+ let metrics_layer = PrometheusMetricsMiddlewareLayer :: new ( histogram_metric. clone ( ) ) ;
229222 let mut service = ServiceBuilder :: new ( )
230223 . layer ( metrics_layer)
231224 . service_fn ( handle_err) ;
@@ -236,20 +229,7 @@ mod tests {
236229 let _ = handle. call ( req) . await ;
237230
238231 // it's using the same labels, should have only one metric
239- assert_eq ! (
240- histogram_metric
241- . collect( )
242- . first( )
243- . unwrap( )
244- . get_metric( )
245- . len( ) ,
246- 1
247- ) ;
248-
249- // new failture
250- assert_eq ! (
251- failure_metric. collect( ) . first( ) . unwrap( ) . get_metric( ) . len( ) ,
252- 1
253- ) ;
232+ assert_eq ! ( how_many_metrics( 200 ) , 1 ) ;
233+ assert_eq ! ( how_many_metrics( 500 ) , 1 ) ;
254234 }
255235}
0 commit comments