Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,8 @@ gotatun = { version = "0.2.0" }
hickory-proto = "0.25.2"
hickory-resolver = "0.25.2"
hickory-server = { version = "0.25.2", features = ["resolver"] }
hyper-util = { version = "0.1.8", features = [
hyper-util = { version = "0.1.20", features = [
"client",
"client-legacy",
"http1",
"http2",
] }
Expand Down
2 changes: 1 addition & 1 deletion mullvad-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ futures = { workspace = true }
http = "1.1.0"
http-body-util = "0.1.2"
hyper = { version = "1.8.1", features = ["client", "http1"] }
hyper-util = { workspace = true }
hyper-util = { workspace = true, features = ["client-pool"] }
ipnetwork = { workspace = true }
libc = "0.2"
log = { workspace = true }
Expand Down
117 changes: 49 additions & 68 deletions mullvad-api/src/https_client_with_sni.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ use futures::{StreamExt, channel::mpsc, future, pin_mut};
use futures::{channel::oneshot, sink::SinkExt};
use http::uri::Scheme;
use hyper::Uri;
use hyper_util::rt::TokioIo;
use mullvad_encrypted_dns_proxy::{
Forwarder as EncryptedDNSForwarder, config::ProxyConfig as EncryptedDNSConfig,
};
Expand All @@ -27,10 +26,8 @@ use std::{
future::Future,
io,
net::{IpAddr, SocketAddr},
pin::Pin,
str::{self, FromStr},
sync::{Arc, Mutex},
task::{Context, Poll},
time::Duration,
};
use talpid_types::{ErrorExt, net::proxy};
Expand All @@ -39,7 +36,6 @@ use tokio::{
net::{TcpSocket, TcpStream},
time::timeout,
};
use tower::Service;

#[cfg(any(feature = "api-override", test))]
use crate::proxy::ConnectionDecorator;
Expand Down Expand Up @@ -424,27 +420,11 @@ impl HttpsConnectorWithSni {
};
Ok(SocketAddr::new(addr.ip(), port))
}
}

impl fmt::Debug for HttpsConnectorWithSni {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HttpsConnectorWithSni").finish()
}
}

impl Service<Uri> for HttpsConnectorWithSni {
type Response = TokioIo<AbortableStream<ApiConnection>>;
type Error = io::Error;
type Future =
Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + 'static>>;

fn poll_ready(&mut self, _: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
let mut inner = self.inner.lock().unwrap();
inner.stream_handles.retain(|handle| !handle.is_closed());
Poll::Ready(Ok(()))
}

fn call(&mut self, uri: Uri) -> Self::Future {
pub async fn get_stream(
&mut self,
uri: Uri,
) -> Result<AbortableStream<ApiConnection>, io::Error> {
let inner = self.inner.clone();
let abort_notify = self.abort_notify.clone();
#[cfg(target_os = "android")]
Expand All @@ -454,55 +434,56 @@ impl Service<Uri> for HttpsConnectorWithSni {
#[cfg(any(feature = "api-override", test))]
let disable_tls = self.disable_tls;

let fut = async move {
if uri.scheme() != Some(&Scheme::HTTPS) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid url, not https",
));
}
let Some(hostname) = uri.host().map(str::to_owned) else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid url, missing host",
));
};
let addr = Self::resolve_address(&*dns_resolver, uri).await?;

// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
let stream = loop {
let notify = abort_notify.notified();
let proxy_config = { inner.lock().unwrap().proxy_config.clone() };
let stream_fut = proxy_config.connect(
&hostname,
&addr,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
#[cfg(any(feature = "api-override", test))]
disable_tls,
);
if uri.scheme() != Some(&Scheme::HTTPS) {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid url, not https",
));
}
let Some(hostname) = uri.host().map(str::to_owned) else {
return Err(io::Error::new(
io::ErrorKind::InvalidInput,
"invalid url, missing host",
));
};
let addr = Self::resolve_address(&*dns_resolver, uri).await?;

// Loop until we have established a connection. This starts over if a new endpoint
// is selected while connecting.
let stream = loop {
let notify = abort_notify.notified();
let proxy_config = { inner.lock().unwrap().proxy_config.clone() };
let stream_fut = proxy_config.connect(
&hostname,
&addr,
#[cfg(target_os = "android")]
socket_bypass_tx.clone(),
#[cfg(any(feature = "api-override", test))]
disable_tls,
);

pin_mut!(stream_fut);
pin_mut!(notify);
pin_mut!(stream_fut);
pin_mut!(notify);

// Wait for connection. Abort and retry if we switched to a different server.
if let future::Either::Left((stream, _)) = future::select(stream_fut, notify).await
{
break stream?;
}
};
// Wait for connection. Abort and retry if we switched to a different server.
if let future::Either::Left((stream, _)) = future::select(stream_fut, notify).await {
break stream?;
}
};

let (stream, socket_handle) = AbortableStream::new(stream);
let (stream, socket_handle) = AbortableStream::new(stream);

{
let mut inner = inner.lock().unwrap();
inner.stream_handles.push(socket_handle);
}
{
let mut inner = inner.lock().unwrap();
inner.stream_handles.push(socket_handle);
}

Ok(TokioIo::new(stream))
};
Ok(stream)
}
}

Box::pin(fut)
impl fmt::Debug for HttpsConnectorWithSni {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("HttpsConnectorWithSni").finish()
}
}
102 changes: 52 additions & 50 deletions mullvad-api/src/rest.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ use hyper::{
body::{Body, Buf, Bytes, Incoming},
header::{self, HeaderValue},
};
use hyper_util::client::legacy::connect::Connect;
use mullvad_types::account::AccountNumber;
use std::{
borrow::Cow,
Expand Down Expand Up @@ -45,12 +44,13 @@ pub enum Error {
#[error("Request cancelled")]
Aborted,

#[error("Legacy hyper error")]
LegacyHyperError(#[from] Arc<hyper_util::client::legacy::Error>),

#[error("Hyper error")]
HyperError(#[from] Arc<hyper::Error>),

/// Connection dropped
#[error("Connection dropped unexpectedly")]
ConnectionDropped,

#[error("Invalid header value")]
InvalidHeaderError,

Expand Down Expand Up @@ -91,30 +91,27 @@ impl From<Infallible> for Error {

impl Error {
pub fn is_network_error(&self) -> bool {
matches!(
self,
Error::HyperError(_) | Error::LegacyHyperError(_) | Error::TimeoutError
)
}

/// Return true if there was no route to the destination
pub fn is_offline(&self) -> bool {
match self {
Error::LegacyHyperError(error) if error.is_connect() => {
if let Some(cause) = error.source()
&& let Some(err) = cause.downcast_ref::<std::io::Error>()
{
return err.raw_os_error() == Some(libc::ENETUNREACH);
}

false
}
// TODO: Currently, we use the legacy hyper client for all REST requests. If this
// changes in the future, we likely need to match on `Error::HyperError` here and
// determine how to achieve the equivalent behavior. See DES-1288.
_ => false,
}
}
matches!(self, Error::HyperError(_) | Error::TimeoutError)
}

// Return true if there was no route to the destination
// pub fn is_offline(&self) -> bool {
// match self {
// Error::LegacyHyperError(error) if error.is_connect() => {
// if let Some(cause) = error.source()
// && let Some(err) = cause.downcast_ref::<std::io::Error>()
// {
// return err.raw_os_error() == Some(libc::ENETUNREACH);
// }

// false
// }
// // TODO: Currently, we use the legacy hyper client for all REST requests. If this
// // changes in the future, we likely need to match on `Error::HyperError` here and
// // determine how to achieve the equivalent behavior. See DES-1288.
// _ => false,
// }
// }

pub fn is_aborted(&self) -> bool {
matches!(self, Error::Aborted)
Expand All @@ -141,16 +138,14 @@ impl Error {
}

// TODO: Look into an alternative to using the legacy hyper client `DES-1288`
type RequestClient =
hyper_util::client::legacy::Client<HttpsConnectorWithSni, BoxBody<Bytes, Error>>;

/// A service that executes HTTP requests, allowing for on-demand termination of all in-flight
/// requests
pub(crate) struct RequestService<T: ConnectionModeProvider> {
command_tx: Weak<mpsc::UnboundedSender<RequestCommand>>,
command_rx: mpsc::UnboundedReceiver<RequestCommand>,
connector_handle: HttpsConnectorWithSniHandle,
client: RequestClient,
connector: HttpsConnectorWithSni,
connection_mode_provider: T,
connection_mode_generation: usize,
api_availability: ApiAvailability,
Expand All @@ -176,17 +171,14 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
connector_handle.set_connection_mode(connection_mode_provider.initial());

let (command_tx, command_rx) = mpsc::unbounded();
let client =
hyper_util::client::legacy::Client::builder(hyper_util::rt::TokioExecutor::new())
.build(connector);

let command_tx = Arc::new(command_tx);

let service = Self {
command_tx: Arc::downgrade(&command_tx),
command_rx,
connector_handle,
client,
connector,
connection_mode_provider,
connection_mode_generation: 0,
api_availability,
Expand Down Expand Up @@ -245,7 +237,7 @@ impl<T: ConnectionModeProvider + 'static> RequestService<T> {
let api_availability = self.api_availability.clone();
let request_future = request
.map(|r| http::Request::map(r, BodyExt::boxed))
.into_future(self.client.clone(), api_availability.clone());
.into_future(self.connector.clone(), api_availability.clone());

let connection_mode_generation = self.connection_mode_generation;

Expand Down Expand Up @@ -429,26 +421,23 @@ where
B::Data: Send,
B::Error: Into<Box<dyn StdError + Send + Sync>>,
{
async fn into_future<C: Connect + Clone + Send + Sync + 'static>(
async fn into_future(
self,
hyper_client: hyper_util::client::legacy::Client<C, B>,
connection: HttpsConnectorWithSni,
api_availability: ApiAvailability,
) -> Result<Response<Incoming>> {
let timeout = self.timeout;
let inner_fut = self.into_future_without_timeout(hyper_client, api_availability);
let inner_fut = self.into_future_without_timeout(connection, api_availability);
tokio::time::timeout(timeout, inner_fut)
.await
.map_err(|_| Error::TimeoutError)?
}

async fn into_future_without_timeout<C>(
async fn into_future_without_timeout(
mut self,
hyper_client: hyper_util::client::legacy::Client<C, B>,
mut connection: HttpsConnectorWithSni,
api_availability: ApiAvailability,
) -> Result<Response<Incoming>>
where
C: Connect + Clone + Send + Sync + 'static,
{
) -> Result<Response<Incoming>> {
let _ = api_availability.wait_for_unsuspend().await;

// Obtain access token first
Expand All @@ -461,11 +450,25 @@ where
.insert(header::AUTHORIZATION, auth);
}

let stream = connection.get_stream(self.uri().clone()).await.unwrap();
let tokio_io = hyper_util::rt::TokioIo::new(stream);

hyper_util::client::pool::map::Map::builder::<http::Uri>()
.keys(|dst| (dst.scheme().cloned(), dst.authority().cloned()))
.values(move |dst| {});
let (mut sender, conn) = hyper::client::conn::http1::handshake(tokio_io).await?;

// Make request to hyper client
let response = hyper_client
.request(self.request)
.await
.map_err(Error::from);
let response = tokio::select! {
res = sender.send_request(self.request) => res.map_err(Error::from),
conn_res = conn => {
log::error!("API request connection failed");
match conn_res {
Ok(()) => Err(Error::ConnectionDropped),
Err(err) => Err(Error::HyperError(Arc::new(err))),
}
}
};

// Notify access token store of expired tokens
if let (Some(account), Some(store)) = (&self.account, &self.access_token_store) {
Expand Down Expand Up @@ -769,7 +772,6 @@ macro_rules! impl_into_arc_err {
}

impl_into_arc_err!(hyper::Error);
impl_into_arc_err!(hyper_util::client::legacy::Error);
impl_into_arc_err!(serde_json::Error);
impl_into_arc_err!(http::Error);
impl_into_arc_err!(http::uri::InvalidUri);
Loading
Loading