@@ -14,6 +14,13 @@ use std::{
1414} ;
1515use tower:: { Layer , Service } ;
1616
17+ use crate :: error:: StatusCodeExt ;
18+
19+ #[ derive( Clone ) ]
20+ pub struct HandlerLabels ( pub MetricLabels ) ;
21+ #[ derive( Clone ) ]
22+ pub struct FailureLabels ( pub MetricLabels ) ;
23+
1724pub type MetricLabels = Arc < dyn MetricLabelProvider + ' static + Send + Sync > ;
1825
1926pub trait MetricLabelProvider {
@@ -61,6 +68,7 @@ impl<S, ReqBody> Service<Request<ReqBody>> for PrometheusMetricsMiddleware<S>
6168where
6269 S : Service < Request < ReqBody > > + Clone + ' static ,
6370 ReqBody : ' static ,
71+ Result < S :: Response , S :: Error > : StatusCodeExt ,
6472{
6573 type Response = S :: Response ;
6674 type Error = S :: Error ;
@@ -71,12 +79,14 @@ where
7179 }
7280
7381 fn call ( & mut self , request : Request < ReqBody > ) -> PrometheusMetricsFuture < S :: Future > {
74- let labels = request. extensions ( ) . get :: < MetricLabels > ( ) . cloned ( ) ;
82+ let handler_labels = request. extensions ( ) . get :: < HandlerLabels > ( ) . cloned ( ) ;
83+ let failure_labels = request. extensions ( ) . get :: < FailureLabels > ( ) . cloned ( ) ;
7584 PrometheusMetricsFuture {
7685 timer : None ,
7786 histogram : self . histogram . clone ( ) ,
7887 failure : self . failure . clone ( ) ,
79- labels,
88+ handler_labels,
89+ failure_labels,
8090 fut : self . inner . call ( request) ,
8191 }
8292 }
@@ -90,7 +100,8 @@ pub struct PrometheusMetricsFuture<F> {
90100 histogram : prometheus:: HistogramVec ,
91101 failure : prometheus:: CounterVec ,
92102
93- labels : Option < MetricLabels > ,
103+ handler_labels : Option < HandlerLabels > ,
104+ failure_labels : Option < FailureLabels > ,
94105
95106 #[ pin]
96107 fut : F ,
@@ -99,32 +110,41 @@ pub struct PrometheusMetricsFuture<F> {
99110impl < F , R , E > Future for PrometheusMetricsFuture < F >
100111where
101112 F : Future < Output = Result < R , E > > ,
113+ Result < R , E > : StatusCodeExt ,
102114{
103115 type Output = F :: Output ;
104116
105117 fn poll ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Self :: Output > {
106118 let this = self . project ( ) ;
107- let Some ( labels ) = & this. labels else {
119+ if this . handler_labels . is_none ( ) && this. failure_labels . is_none ( ) {
108120 return this. fut . poll ( cx) ;
109- } ;
121+ }
110122
111123 if this. timer . is_none ( ) {
112- // Start timer so we can track duration of request.
113- let duration_metric = this. histogram . with_label_values ( & labels. get_labels ( ) ) ;
114- * this. timer = Some ( duration_metric. start_timer ( ) ) ;
124+ if let Some ( HandlerLabels ( labels) ) = & this. handler_labels {
125+ // Start timer so we can track duration of request.
126+ let duration_metric = this. histogram . with_label_values ( & labels. get_labels ( ) ) ;
127+ * this. timer = Some ( duration_metric. start_timer ( ) ) ;
128+ }
115129 }
116130
117131 match this. fut . poll ( cx) {
118132 Poll :: Ready ( result) => {
119- if result. is_err ( ) {
120- let _ = this
121- . failure
122- . get_metric_with_label_values ( & labels. get_labels ( ) ) ;
133+ let status_code = result. status_code ( ) ;
134+ if !status_code. is_success ( ) {
135+ if let Some ( FailureLabels ( labels) ) = & this. failure_labels {
136+ let mut labels = labels. get_labels ( ) ;
137+ labels. push ( status_code. as_str ( ) ) ;
138+ this. failure
139+ . get_metric_with_label_values ( & labels)
140+ . expect ( "Couldn't register metric" ) ;
141+ }
123142 }
124143 // Record the duration of this request.
125144 if let Some ( timer) = this. timer . take ( ) {
126145 timer. observe_duration ( ) ;
127146 }
147+
128148 Poll :: Ready ( result)
129149 }
130150 Poll :: Pending => Poll :: Pending ,
@@ -141,9 +161,13 @@ mod tests {
141161 http:: { Request , Response } ,
142162 } ;
143163 use prometheus:: core:: Collector ;
164+ use reqwest:: StatusCode ;
144165 use tower:: { Service , ServiceBuilder , ServiceExt } ;
145166
146- use crate :: middleware:: prometheus_metrics:: { MetricLabels , PrometheusMetricsMiddlewareLayer } ;
167+ use crate :: {
168+ error:: StatusCodeExt ,
169+ middleware:: prometheus_metrics:: { MetricLabels , PrometheusMetricsMiddlewareLayer } ,
170+ } ;
147171
148172 use super :: MetricLabelProvider ;
149173
@@ -153,12 +177,22 @@ mod tests {
153177 vec ! [ "label1," , "label2" , "label3" ]
154178 }
155179 }
156- async fn handle ( _: Request < Body > ) -> anyhow:: Result < Response < Body > > {
180+
181+ #[ derive( Debug ) ]
182+ struct ErrorResponse ;
183+
184+ impl StatusCodeExt for ErrorResponse {
185+ fn status_code ( & self ) -> StatusCode {
186+ StatusCode :: INTERNAL_SERVER_ERROR
187+ }
188+ }
189+
190+ async fn handle ( _: Request < Body > ) -> Result < Response < Body > , ErrorResponse > {
157191 Ok ( Response :: new ( Body :: default ( ) ) )
158192 }
159193
160- async fn handle_err ( _: Request < Body > ) -> anyhow :: Result < Response < Body > > {
161- Err ( anyhow :: anyhow! ( "Error" ) )
194+ async fn handle_err ( _: Request < Body > ) -> Result < Response < Body > , ErrorResponse > {
195+ Err ( ErrorResponse )
162196 }
163197
164198 #[ tokio:: test]
@@ -175,7 +209,7 @@ mod tests {
175209 let failure_metric = prometheus:: register_counter_vec_with_registry!(
176210 "failure_metric" ,
177211 "Test" ,
178- & [ "deployment" , "sender" , "allocation" ] ,
212+ & [ "deployment" , "sender" , "allocation" , "status_code" ] ,
179213 registry,
180214 )
181215 . unwrap ( ) ;
0 commit comments