|
3 | 3 |
|
4 | 4 | use reqwest::header::HeaderMap;
|
5 | 5 | use reqwest::header::HeaderValue;
|
6 |
| -use reqwest::{Client, StatusCode}; |
| 6 | +use reqwest::Client; |
| 7 | +use tracing::instrument; |
7 | 8 |
|
8 | 9 | use std::time::Duration;
|
9 | 10 |
|
10 | 11 | use serde::Deserialize;
|
11 | 12 | use serde_xml_rs::from_str;
|
12 | 13 |
|
13 |
| -use tokio::time::timeout; |
14 |
| - |
15 | 14 | use crate::error::Error;
|
16 | 15 | use crate::http;
|
17 | 16 |
|
@@ -101,65 +100,52 @@ const DEFAULT_GOALSTATE_URL: &str =
|
101 | 100 | /// Some("http://127.0.0.1:8000/"),
|
102 | 101 | /// );
|
103 | 102 | /// ```
|
| 103 | +#[instrument(err, skip_all)] |
104 | 104 | pub async fn get_goalstate(
|
105 | 105 | client: &Client,
|
106 | 106 | retry_interval: Duration,
|
107 |
| - total_timeout: Duration, |
| 107 | + mut total_timeout: Duration, |
108 | 108 | url: Option<&str>,
|
109 | 109 | ) -> Result<Goalstate, Error> {
|
110 |
| - let url = url.unwrap_or(DEFAULT_GOALSTATE_URL); |
111 |
| - |
112 | 110 | let mut headers = HeaderMap::new();
|
113 | 111 | headers.insert("x-ms-agent-name", HeaderValue::from_static("azure-init"));
|
114 | 112 | headers.insert("x-ms-version", HeaderValue::from_static("2012-11-30"));
|
115 |
| - |
116 |
| - let response = timeout(total_timeout, async { |
117 |
| - let now = std::time::Instant::now(); |
118 |
| - loop { |
119 |
| - if let Ok(response) = client |
120 |
| - .get(url) |
121 |
| - .headers(headers.clone()) |
122 |
| - .timeout(Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC)) |
123 |
| - .send() |
124 |
| - .await |
125 |
| - { |
126 |
| - let statuscode = response.status(); |
127 |
| - |
128 |
| - if statuscode == StatusCode::OK { |
129 |
| - tracing::info!( |
130 |
| - "HTTP response succeeded with status {}", |
131 |
| - statuscode |
| 113 | + let url = url.unwrap_or(DEFAULT_GOALSTATE_URL); |
| 114 | + let request_timeout = |
| 115 | + Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC); |
| 116 | + |
| 117 | + while !total_timeout.is_zero() { |
| 118 | + let (response, remaining_timeout) = http::get( |
| 119 | + client, |
| 120 | + headers.clone(), |
| 121 | + request_timeout, |
| 122 | + retry_interval, |
| 123 | + total_timeout, |
| 124 | + url, |
| 125 | + ) |
| 126 | + .await?; |
| 127 | + match response.text().await { |
| 128 | + Ok(body) => { |
| 129 | + let goalstate = from_str(&body).map_err(|error| { |
| 130 | + tracing::warn!( |
| 131 | + ?error, |
| 132 | + "The response body was invalid and could not be deserialized" |
132 | 133 | );
|
133 |
| - return Ok(response); |
134 |
| - } |
135 |
| - |
136 |
| - if !http::RETRY_CODES.contains(&statuscode) { |
137 |
| - return response.error_for_status().map_err(|error| { |
138 |
| - tracing::error!( |
139 |
| - ?error, |
140 |
| - "{}", |
141 |
| - format!( |
142 |
| - "HTTP call failed due to status {}", |
143 |
| - statuscode |
144 |
| - ) |
145 |
| - ); |
146 |
| - error |
147 |
| - }); |
| 134 | + error.into() |
| 135 | + }); |
| 136 | + if goalstate.is_ok() { |
| 137 | + return goalstate; |
148 | 138 | }
|
149 | 139 | }
|
150 |
| - |
151 |
| - tracing::info!("Retrying to get HTTP response in {} sec, remaining timeout {} sec.", retry_interval.as_secs(), total_timeout.saturating_sub(now.elapsed()).as_secs()); |
152 |
| - |
153 |
| - tokio::time::sleep(retry_interval).await; |
| 140 | + Err(error) => { |
| 141 | + tracing::warn!(?error, "Failed to read the full response body") |
| 142 | + } |
154 | 143 | }
|
155 |
| - }) |
156 |
| - .await?; |
157 |
| - |
158 |
| - let goalstate_body = response?.text().await?; |
159 | 144 |
|
160 |
| - let goalstate: Goalstate = from_str(&goalstate_body)?; |
| 145 | + total_timeout = remaining_timeout; |
| 146 | + } |
161 | 147 |
|
162 |
| - Ok(goalstate) |
| 148 | + Err(Error::Timeout) |
163 | 149 | }
|
164 | 150 |
|
165 | 151 | const DEFAULT_HEALTH_URL: &str = "http://168.63.129.16/machine/?comp=health";
|
@@ -199,56 +185,36 @@ const DEFAULT_HEALTH_URL: &str = "http://168.63.129.16/machine/?comp=health";
|
199 | 185 | /// );
|
200 | 186 | /// }
|
201 | 187 | /// ```
|
| 188 | +#[instrument(err, skip_all)] |
202 | 189 | pub async fn report_health(
|
203 | 190 | client: &Client,
|
204 | 191 | goalstate: Goalstate,
|
205 | 192 | retry_interval: Duration,
|
206 | 193 | total_timeout: Duration,
|
207 | 194 | url: Option<&str>,
|
208 | 195 | ) -> Result<(), Error> {
|
209 |
| - let url = url.unwrap_or(DEFAULT_HEALTH_URL); |
210 |
| - |
211 | 196 | let mut headers = HeaderMap::new();
|
212 | 197 | headers.insert("x-ms-agent-name", HeaderValue::from_static("azure-init"));
|
213 | 198 | headers.insert("x-ms-version", HeaderValue::from_static("2012-11-30"));
|
214 | 199 | headers.insert(
|
215 | 200 | "Content-Type",
|
216 | 201 | HeaderValue::from_static("text/xml;charset=utf-8"),
|
217 | 202 | );
|
| 203 | + let request_timeout = |
| 204 | + Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC); |
| 205 | + let url = url.unwrap_or(DEFAULT_HEALTH_URL); |
218 | 206 |
|
219 | 207 | let post_request = build_report_health_file(goalstate);
|
220 | 208 |
|
221 |
| - _ = timeout(total_timeout, async { |
222 |
| - let now = std::time::Instant::now(); |
223 |
| - loop { |
224 |
| - if let Ok(response) = client |
225 |
| - .post(url) |
226 |
| - .headers(headers.clone()) |
227 |
| - .body(post_request.clone()) |
228 |
| - .timeout(Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC)) |
229 |
| - .send() |
230 |
| - .await |
231 |
| - { |
232 |
| - let statuscode = response.status(); |
233 |
| - |
234 |
| - if statuscode == StatusCode::OK { |
235 |
| - tracing::info!("HTTP response succeeded with status {}", statuscode); |
236 |
| - return Ok(response); |
237 |
| - } |
238 |
| - |
239 |
| - if !http::RETRY_CODES.contains(&statuscode) { |
240 |
| - return response.error_for_status().map_err(|error| { |
241 |
| - tracing::error!(?error, "{}", format!("HTTP call failed due to status {}", statuscode)); |
242 |
| - error |
243 |
| - }); |
244 |
| - } |
245 |
| - } |
246 |
| - |
247 |
| - tracing::info!("Retrying to get HTTP response in {} sec, remaining timeout {} sec.", retry_interval.as_secs(), total_timeout.saturating_sub(now.elapsed()).as_secs()); |
248 |
| - |
249 |
| - tokio::time::sleep(retry_interval).await; |
250 |
| - } |
251 |
| - }) |
| 209 | + _ = http::post( |
| 210 | + client, |
| 211 | + headers, |
| 212 | + post_request, |
| 213 | + request_timeout, |
| 214 | + retry_interval, |
| 215 | + total_timeout, |
| 216 | + url, |
| 217 | + ) |
252 | 218 | .await?;
|
253 | 219 |
|
254 | 220 | Ok(())
|
@@ -458,4 +424,54 @@ mod tests {
|
458 | 424 | assert!(!run_goalstate_retry(rc).await);
|
459 | 425 | }
|
460 | 426 | }
|
| 427 | + |
| 428 | + // Assert malformed responses are retried. |
| 429 | + // |
| 430 | + // In this case the server doesn't return XML at all. |
| 431 | + #[tokio::test] |
| 432 | + #[tracing_test::traced_test] |
| 433 | + async fn malformed_response() { |
| 434 | + let body = "You thought this was XML, but you were wrong"; |
| 435 | + let payload = format!( |
| 436 | + "HTTP/1.1 {} {}\r\nContent-Type: application/xml\r\nContent-Length: {}\r\n\r\n{}", |
| 437 | + StatusCode::OK.as_u16(), |
| 438 | + StatusCode::OK.to_string(), |
| 439 | + body.len(), |
| 440 | + body |
| 441 | + ); |
| 442 | + |
| 443 | + let serverlistener = TcpListener::bind("127.0.0.1:0").await.unwrap(); |
| 444 | + let addr = serverlistener.local_addr().unwrap(); |
| 445 | + let cancel_token = tokio_util::sync::CancellationToken::new(); |
| 446 | + let server = tokio::spawn(unittest::serve_requests( |
| 447 | + serverlistener, |
| 448 | + payload, |
| 449 | + cancel_token.clone(), |
| 450 | + )); |
| 451 | + |
| 452 | + let client = Client::builder() |
| 453 | + .timeout(std::time::Duration::from_secs(5)) |
| 454 | + .build() |
| 455 | + .unwrap(); |
| 456 | + |
| 457 | + let res = get_goalstate( |
| 458 | + &client, |
| 459 | + Duration::from_millis(10), |
| 460 | + Duration::from_millis(50), |
| 461 | + Some(format!("http://{:}:{:}/", addr.ip(), addr.port()).as_str()), |
| 462 | + ) |
| 463 | + .await; |
| 464 | + |
| 465 | + cancel_token.cancel(); |
| 466 | + |
| 467 | + let requests = server.await.unwrap(); |
| 468 | + assert!(requests >= 2); |
| 469 | + assert!(logs_contain( |
| 470 | + "The response body was invalid and could not be deserialized" |
| 471 | + )); |
| 472 | + match res { |
| 473 | + Err(crate::error::Error::Timeout) => {} |
| 474 | + _ => panic!("Response should have timed out"), |
| 475 | + }; |
| 476 | + } |
461 | 477 | }
|
0 commit comments