Skip to content

Commit 5f2ff88

Browse files
authored
Refactor the HTTP retry code to use a single implementation (#135)
* error: include the underlying error in the error text

  While these may not be the optimal error messages in all cases, this should make them more useful by including the underlying error message so that you get, for example, "HTTP client error occurred: HTTP status client error (405 Method Not Allowed) for url (http://127.0.0.1:35195/)" instead of "HTTP client error ocurred".

* http: provide a generic HTTP client interface that retries

  Add a generic `request()` function which implements the retry rules we have, and provide `get()` and `post()` wrappers for the function. Move the IMDS and Wireserver interfaces to this new interface so we don't have multiple implementations of the retry rules.

* Retry when the response body can't be deserialized

  In the event we receive something that can't be deserialized, retry the request. This seems like a rather unlikely scenario, since I would expect a partial response to fail with a lower-level reqwest error, but it can happen in theory, so let's handle it gracefully.

* imds: bump API version and add extended=true to the query parameters

  This causes the request to be treated as a request for data for provisioning purposes. Without it, the server may respond with an HTTP 200 and the response may not include all the data we need to provision. Using this query parameter should turn those requests into failures which we'll retry.

Fixes #134
1 parent 79e22b6 commit 5f2ff88

File tree

5 files changed

+500
-128
lines changed

5 files changed

+500
-128
lines changed

libazureinit/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ strum = { version = "0.26.3", features = ["derive"] }
2222
fstab = "0.4.0"
2323

2424
[dev-dependencies]
25+
tracing-test = { version = "0.2", features = ["no-env-filter"] }
2526
tempfile = "3"
2627
tokio = { version = "1", features = ["full"] }
2728
tokio-util = "0.7.11"

libazureinit/src/error.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,13 +24,13 @@
2424
/// ```
2525
#[derive(thiserror::Error, Debug)]
2626
pub enum Error {
27-
#[error("Unable to deserialize or serialize JSON data")]
27+
#[error("Unable to deserialize or serialize JSON data: {0}")]
2828
Json(#[from] serde_json::Error),
29-
#[error("Unable to deserialize or serialize XML data")]
29+
#[error("Unable to deserialize or serialize XML data: {0}")]
3030
Xml(#[from] serde_xml_rs::Error),
31-
#[error("HTTP client error ocurred")]
31+
#[error("HTTP client error occurred: {0}")]
3232
Http(#[from] reqwest::Error),
33-
#[error("An I/O error occurred")]
33+
#[error("An I/O error occurred: {0}")]
3434
Io(#[from] std::io::Error),
3535
#[error("HTTP request did not succeed (HTTP {status} from {endpoint})")]
3636
HttpStatus {
@@ -44,7 +44,7 @@ pub enum Error {
4444
},
4545
#[error("failed to construct a C-style string")]
4646
NulError(#[from] std::ffi::NulError),
47-
#[error("nix call failed")]
47+
#[error("nix call failed: {0}")]
4848
Nix(#[from] nix::Error),
4949
#[error("The user {user} does not exist")]
5050
UserMissing { user: String },
@@ -54,7 +54,7 @@ pub enum Error {
5454
InstanceMetadataFailure,
5555
#[error("Provisioning a user with a non-empty password is not supported")]
5656
NonEmptyPassword,
57-
#[error("Unable to get list of block devices")]
57+
#[error("Unable to get list of block devices: {0}")]
5858
BlockUtils(#[from] block_utils::BlockUtilsError),
5959
#[error(
6060
"Failed to set the hostname; none of the provided backends succeeded"
@@ -69,5 +69,11 @@ pub enum Error {
6969
)]
7070
NoPasswordProvisioner,
7171
#[error("A timeout error occurred")]
72-
Timeout(#[from] tokio::time::error::Elapsed),
72+
Timeout,
73+
}
74+
75+
impl From<tokio::time::error::Elapsed> for Error {
76+
fn from(_: tokio::time::error::Elapsed) -> Self {
77+
Self::Timeout
78+
}
7379
}

libazureinit/src/goalstate.rs

Lines changed: 97 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,14 @@
33

44
use reqwest::header::HeaderMap;
55
use reqwest::header::HeaderValue;
6-
use reqwest::{Client, StatusCode};
6+
use reqwest::Client;
7+
use tracing::instrument;
78

89
use std::time::Duration;
910

1011
use serde::Deserialize;
1112
use serde_xml_rs::from_str;
1213

13-
use tokio::time::timeout;
14-
1514
use crate::error::Error;
1615
use crate::http;
1716

@@ -101,65 +100,52 @@ const DEFAULT_GOALSTATE_URL: &str =
101100
/// Some("http://127.0.0.1:8000/"),
102101
/// );
103102
/// ```
103+
#[instrument(err, skip_all)]
104104
pub async fn get_goalstate(
105105
client: &Client,
106106
retry_interval: Duration,
107-
total_timeout: Duration,
107+
mut total_timeout: Duration,
108108
url: Option<&str>,
109109
) -> Result<Goalstate, Error> {
110-
let url = url.unwrap_or(DEFAULT_GOALSTATE_URL);
111-
112110
let mut headers = HeaderMap::new();
113111
headers.insert("x-ms-agent-name", HeaderValue::from_static("azure-init"));
114112
headers.insert("x-ms-version", HeaderValue::from_static("2012-11-30"));
115-
116-
let response = timeout(total_timeout, async {
117-
let now = std::time::Instant::now();
118-
loop {
119-
if let Ok(response) = client
120-
.get(url)
121-
.headers(headers.clone())
122-
.timeout(Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC))
123-
.send()
124-
.await
125-
{
126-
let statuscode = response.status();
127-
128-
if statuscode == StatusCode::OK {
129-
tracing::info!(
130-
"HTTP response succeeded with status {}",
131-
statuscode
113+
let url = url.unwrap_or(DEFAULT_GOALSTATE_URL);
114+
let request_timeout =
115+
Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC);
116+
117+
while !total_timeout.is_zero() {
118+
let (response, remaining_timeout) = http::get(
119+
client,
120+
headers.clone(),
121+
request_timeout,
122+
retry_interval,
123+
total_timeout,
124+
url,
125+
)
126+
.await?;
127+
match response.text().await {
128+
Ok(body) => {
129+
let goalstate = from_str(&body).map_err(|error| {
130+
tracing::warn!(
131+
?error,
132+
"The response body was invalid and could not be deserialized"
132133
);
133-
return Ok(response);
134-
}
135-
136-
if !http::RETRY_CODES.contains(&statuscode) {
137-
return response.error_for_status().map_err(|error| {
138-
tracing::error!(
139-
?error,
140-
"{}",
141-
format!(
142-
"HTTP call failed due to status {}",
143-
statuscode
144-
)
145-
);
146-
error
147-
});
134+
error.into()
135+
});
136+
if goalstate.is_ok() {
137+
return goalstate;
148138
}
149139
}
150-
151-
tracing::info!("Retrying to get HTTP response in {} sec, remaining timeout {} sec.", retry_interval.as_secs(), total_timeout.saturating_sub(now.elapsed()).as_secs());
152-
153-
tokio::time::sleep(retry_interval).await;
140+
Err(error) => {
141+
tracing::warn!(?error, "Failed to read the full response body")
142+
}
154143
}
155-
})
156-
.await?;
157-
158-
let goalstate_body = response?.text().await?;
159144

160-
let goalstate: Goalstate = from_str(&goalstate_body)?;
145+
total_timeout = remaining_timeout;
146+
}
161147

162-
Ok(goalstate)
148+
Err(Error::Timeout)
163149
}
164150

165151
const DEFAULT_HEALTH_URL: &str = "http://168.63.129.16/machine/?comp=health";
@@ -199,56 +185,36 @@ const DEFAULT_HEALTH_URL: &str = "http://168.63.129.16/machine/?comp=health";
199185
/// );
200186
/// }
201187
/// ```
188+
#[instrument(err, skip_all)]
202189
pub async fn report_health(
203190
client: &Client,
204191
goalstate: Goalstate,
205192
retry_interval: Duration,
206193
total_timeout: Duration,
207194
url: Option<&str>,
208195
) -> Result<(), Error> {
209-
let url = url.unwrap_or(DEFAULT_HEALTH_URL);
210-
211196
let mut headers = HeaderMap::new();
212197
headers.insert("x-ms-agent-name", HeaderValue::from_static("azure-init"));
213198
headers.insert("x-ms-version", HeaderValue::from_static("2012-11-30"));
214199
headers.insert(
215200
"Content-Type",
216201
HeaderValue::from_static("text/xml;charset=utf-8"),
217202
);
203+
let request_timeout =
204+
Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC);
205+
let url = url.unwrap_or(DEFAULT_HEALTH_URL);
218206

219207
let post_request = build_report_health_file(goalstate);
220208

221-
_ = timeout(total_timeout, async {
222-
let now = std::time::Instant::now();
223-
loop {
224-
if let Ok(response) = client
225-
.post(url)
226-
.headers(headers.clone())
227-
.body(post_request.clone())
228-
.timeout(Duration::from_secs(http::WIRESERVER_HTTP_TIMEOUT_SEC))
229-
.send()
230-
.await
231-
{
232-
let statuscode = response.status();
233-
234-
if statuscode == StatusCode::OK {
235-
tracing::info!("HTTP response succeeded with status {}", statuscode);
236-
return Ok(response);
237-
}
238-
239-
if !http::RETRY_CODES.contains(&statuscode) {
240-
return response.error_for_status().map_err(|error| {
241-
tracing::error!(?error, "{}", format!("HTTP call failed due to status {}", statuscode));
242-
error
243-
});
244-
}
245-
}
246-
247-
tracing::info!("Retrying to get HTTP response in {} sec, remaining timeout {} sec.", retry_interval.as_secs(), total_timeout.saturating_sub(now.elapsed()).as_secs());
248-
249-
tokio::time::sleep(retry_interval).await;
250-
}
251-
})
209+
_ = http::post(
210+
client,
211+
headers,
212+
post_request,
213+
request_timeout,
214+
retry_interval,
215+
total_timeout,
216+
url,
217+
)
252218
.await?;
253219

254220
Ok(())
@@ -458,4 +424,54 @@ mod tests {
458424
assert!(!run_goalstate_retry(rc).await);
459425
}
460426
}
427+
428+
// Assert malformed responses are retried.
429+
//
430+
// In this case the server doesn't return XML at all.
431+
#[tokio::test]
432+
#[tracing_test::traced_test]
433+
async fn malformed_response() {
434+
let body = "You thought this was XML, but you were wrong";
435+
let payload = format!(
436+
"HTTP/1.1 {} {}\r\nContent-Type: application/xml\r\nContent-Length: {}\r\n\r\n{}",
437+
StatusCode::OK.as_u16(),
438+
StatusCode::OK.to_string(),
439+
body.len(),
440+
body
441+
);
442+
443+
let serverlistener = TcpListener::bind("127.0.0.1:0").await.unwrap();
444+
let addr = serverlistener.local_addr().unwrap();
445+
let cancel_token = tokio_util::sync::CancellationToken::new();
446+
let server = tokio::spawn(unittest::serve_requests(
447+
serverlistener,
448+
payload,
449+
cancel_token.clone(),
450+
));
451+
452+
let client = Client::builder()
453+
.timeout(std::time::Duration::from_secs(5))
454+
.build()
455+
.unwrap();
456+
457+
let res = get_goalstate(
458+
&client,
459+
Duration::from_millis(10),
460+
Duration::from_millis(50),
461+
Some(format!("http://{:}:{:}/", addr.ip(), addr.port()).as_str()),
462+
)
463+
.await;
464+
465+
cancel_token.cancel();
466+
467+
let requests = server.await.unwrap();
468+
assert!(requests >= 2);
469+
assert!(logs_contain(
470+
"The response body was invalid and could not be deserialized"
471+
));
472+
match res {
473+
Err(crate::error::Error::Timeout) => {}
474+
_ => panic!("Response should have timed out"),
475+
};
476+
}
461477
}

0 commit comments

Comments
 (0)