Skip to content

Commit b25329f

Browse files
authored
inbound: Decouple inbound stack from TCP connections (#823)
In order to make the inbound stack testable without initializing OS sockets, this change decouples the inbound stack from its TCP client, passing in an abstract implementation to `inbound::Config::build`. This change introduces a new `stack::Fail` module that always fails. This is used to ensure that connections targeting the inbound server without an opaque port are not forwarded, removing the need for endpoint-level loop prevention.
1 parent 3492262 commit b25329f

File tree

6 files changed

+168
-132
lines changed

6 files changed

+168
-132
lines changed

linkerd/app/core/src/svc.rs

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,22 @@ pub use crate::proxy::http;
44
use crate::{cache, Error};
55
pub use linkerd2_buffer as buffer;
66
pub use linkerd2_concurrency_limit::ConcurrencyLimit;
7-
pub use linkerd2_stack::{self as stack, layer, BoxNewService, NewRouter, NewService, NewUnwrapOr};
7+
pub use linkerd2_stack::{
8+
self as stack, layer, BoxNewService, Fail, NewRouter, NewService, NewUnwrapOr,
9+
};
810
pub use linkerd2_stack_tracing::{InstrumentMake, InstrumentMakeLayer};
911
pub use linkerd2_timeout::{self as timeout, FailFast};
1012
use std::{
1113
task::{Context, Poll},
1214
time::Duration,
1315
};
14-
use tower::layer::util::{Identity, Stack as Pair};
15-
pub use tower::layer::Layer;
16-
pub use tower::make::MakeService;
17-
pub use tower::spawn_ready::SpawnReady;
18-
pub use tower::util::Either;
19-
pub use tower::{service_fn as mk, Service, ServiceExt};
16+
use tower::{
17+
layer::util::{Identity, Stack as Pair},
18+
make::MakeService,
19+
};
20+
pub use tower::{
21+
layer::Layer, service_fn as mk, spawn_ready::SpawnReady, util::Either, Service, ServiceExt,
22+
};
2023

2124
#[derive(Clone, Debug)]
2225
pub struct Layers<L>(L);

linkerd/app/inbound/src/endpoint.rs

Lines changed: 9 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -88,18 +88,6 @@ impl Into<transport::labels::Key> for &'_ TcpAccept {
8888

8989
// === impl HttpEndpoint ===
9090

91-
impl Into<SocketAddr> for HttpEndpoint {
92-
fn into(self) -> SocketAddr {
93-
(&self).into()
94-
}
95-
}
96-
97-
impl Into<SocketAddr> for &'_ HttpEndpoint {
98-
fn into(self) -> SocketAddr {
99-
([127, 0, 0, 1], self.port).into()
100-
}
101-
}
102-
10391
impl Into<http::client::Settings> for &'_ HttpEndpoint {
10492
fn into(self) -> http::client::Settings {
10593
self.settings
@@ -115,23 +103,6 @@ impl From<Target> for HttpEndpoint {
115103
}
116104
}
117105

118-
impl tls::HasPeerIdentity for HttpEndpoint {
119-
fn peer_identity(&self) -> tls::PeerIdentity {
120-
Conditional::None(tls::ReasonForNoPeerName::Loopback)
121-
}
122-
}
123-
124-
impl Into<transport::labels::Key> for &'_ HttpEndpoint {
125-
fn into(self) -> transport::labels::Key {
126-
transport::labels::Key::Connect(transport::labels::EndpointLabels {
127-
direction: transport::labels::Direction::In,
128-
authority: None,
129-
labels: None,
130-
tls_id: tls::Conditional::None(tls::ReasonForNoPeerName::Loopback).into(),
131-
})
132-
}
133-
}
134-
135106
// === TcpEndpoint ===
136107

137108
impl From<TcpAccept> for TcpEndpoint {
@@ -142,30 +113,21 @@ impl From<TcpAccept> for TcpEndpoint {
142113
}
143114
}
144115

145-
impl From<(Option<Header>, TcpAccept)> for TcpEndpoint {
146-
fn from((hdr, tcp): (Option<Header>, TcpAccept)) -> Self {
147-
match hdr {
148-
Some(Header { port, .. }) => Self { port },
149-
None => tcp.into(),
150-
}
116+
impl From<Header> for TcpEndpoint {
117+
fn from(Header { port, .. }: Header) -> Self {
118+
Self { port }
151119
}
152120
}
153121

154-
impl Into<SocketAddr> for TcpEndpoint {
155-
fn into(self) -> SocketAddr {
156-
(&self).into()
122+
impl From<HttpEndpoint> for TcpEndpoint {
123+
fn from(HttpEndpoint { port, .. }: HttpEndpoint) -> Self {
124+
Self { port }
157125
}
158126
}
159127

160-
impl Into<SocketAddr> for &'_ TcpEndpoint {
161-
fn into(self) -> SocketAddr {
162-
([127, 0, 0, 1], self.port).into()
163-
}
164-
}
165-
166-
impl tls::HasPeerIdentity for TcpEndpoint {
167-
fn peer_identity(&self) -> tls::PeerIdentity {
168-
Conditional::None(tls::ReasonForNoPeerName::Loopback)
128+
impl Into<u16> for TcpEndpoint {
129+
fn into(self) -> u16 {
130+
self.port
169131
}
170132
}
171133

linkerd/app/inbound/src/lib.rs

Lines changed: 76 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,8 @@ use self::prevent_loop::PreventLoop;
1313
use self::require_identity_for_ports::RequireIdentityForPorts;
1414
use linkerd2_app_core::{
1515
classify,
16-
config::{ProxyConfig, ServerConfig},
17-
drain, dst, errors, metrics,
18-
opaque_transport::DetectHeader,
16+
config::{ConnectConfig, ProxyConfig, ServerConfig},
17+
drain, dst, errors, metrics, opaque_transport,
1918
opencensus::proto::trace::v1 as oc,
2019
profiles,
2120
proxy::{
@@ -28,8 +27,8 @@ use linkerd2_app_core::{
2827
transport::{self, io, listen, tls},
2928
Error, NameAddr, NameMatch, TraceContext, DST_OVERRIDE_HEADER,
3029
};
31-
use std::{collections::HashMap, time::Duration};
32-
use tokio::{net::TcpStream, sync::mpsc};
30+
use std::{collections::HashMap, fmt::Debug, net::SocketAddr, time::Duration};
31+
use tokio::sync::mpsc;
3332
use tracing::debug_span;
3433

3534
mod allow_discovery;
@@ -49,16 +48,37 @@ pub struct Config {
4948
#[derive(Clone, Debug)]
5049
pub struct SkipByPort(std::sync::Arc<indexmap::IndexSet<u16>>);
5150

51+
#[derive(Default)]
52+
struct NonOpaqueRefused(());
53+
5254
type SensorIo<T> = io::SensorIo<T, transport::metrics::Sensor>;
5355

5456
// === impl Config ===
5557

58+
pub fn tcp_connect<T: Into<u16>>(
59+
config: &ConnectConfig,
60+
) -> impl svc::Service<
61+
T,
62+
Response = impl io::AsyncRead + io::AsyncWrite + Send,
63+
Error = Error,
64+
Future = impl Send,
65+
> + Clone {
66+
// Establishes connections to remote peers (for both TCP
67+
// forwarding and HTTP proxying).
68+
svc::stack(transport::ConnectTcp::new(config.keepalive))
69+
.push_map_target(|t: T| ([127, 0, 0, 1], t.into()))
70+
// Limits the time we wait for a connection to be established.
71+
.push_timeout(config.timeout)
72+
.into_inner()
73+
}
74+
5675
#[allow(clippy::too_many_arguments)]
5776
impl Config {
58-
pub fn build<L, LSvc, P>(
77+
pub fn build<I, C, L, LSvc, P>(
5978
self,
60-
listen_addr: std::net::SocketAddr,
79+
listen_addr: SocketAddr,
6180
local_identity: tls::Conditional<identity::Local>,
81+
connect: C,
6282
http_loopback: L,
6383
profiles_client: P,
6484
tap: tap::Registry,
@@ -67,9 +87,21 @@ impl Config {
6787
drain: drain::Watch,
6888
) -> impl svc::NewService<
6989
listen::Addrs,
70-
Service = impl svc::Service<TcpStream, Response = (), Error = Error, Future = impl Send>,
90+
Service = impl svc::Service<I, Response = (), Error = Error, Future = impl Send>,
7191
> + Clone
7292
where
93+
I: tls::accept::Detectable
94+
+ io::AsyncRead
95+
+ io::AsyncWrite
96+
+ io::PeerAddr
97+
+ Debug
98+
+ Send
99+
+ Unpin
100+
+ 'static,
101+
C: svc::Service<TcpEndpoint> + Clone + Send + Sync + Unpin + 'static,
102+
C::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin + 'static,
103+
C::Error: Into<Error>,
104+
C::Future: Send + Unpin,
73105
L: svc::NewService<Target, Service = LSvc> + Clone + Send + Sync + 'static,
74106
LSvc: svc::Service<http::Request<http::BoxBody>, Response = http::Response<http::BoxBody>>
75107
+ Send
@@ -81,9 +113,8 @@ impl Config {
81113
P::Future: Send,
82114
{
83115
let prevent_loop = PreventLoop::from(listen_addr.port());
84-
let tcp_connect = self.build_tcp_connect(prevent_loop, &metrics);
85116
let http_router = self.build_http_router(
86-
tcp_connect.clone(),
117+
connect.clone(),
87118
prevent_loop,
88119
http_loopback,
89120
profiles_client,
@@ -93,7 +124,8 @@ impl Config {
93124
);
94125

95126
// Forwards TCP streams that cannot be decoded as HTTP.
96-
let tcp_forward = svc::stack(tcp_connect)
127+
let tcp_forward = svc::stack(connect)
128+
.push(metrics.transport.layer_connect())
97129
.push_make_thunk()
98130
.push_on_response(
99131
svc::layers()
@@ -115,34 +147,9 @@ impl Config {
115147
self.build_tls_accept(accept, tcp_forward, local_identity, metrics)
116148
}
117149

118-
pub fn build_tcp_connect(
119-
&self,
120-
prevent_loop: PreventLoop,
121-
metrics: &metrics::Proxy,
122-
) -> impl svc::Service<
123-
TcpEndpoint,
124-
Response = impl io::AsyncRead + io::AsyncWrite + Send,
125-
Error = Error,
126-
Future = impl Send,
127-
> + svc::Service<
128-
HttpEndpoint,
129-
Response = impl io::AsyncRead + io::AsyncWrite + Send,
130-
Error = Error,
131-
Future = impl Send,
132-
> + Clone {
133-
// Establishes connections to remote peers (for both TCP
134-
// forwarding and HTTP proxying).
135-
svc::stack(transport::ConnectTcp::new(self.proxy.connect.keepalive))
136-
// Limits the time we wait for a connection to be established.
137-
.push_timeout(self.proxy.connect.timeout)
138-
.push(metrics.transport.layer_connect())
139-
.push_request_filter(prevent_loop)
140-
.into_inner()
141-
}
142-
143150
pub fn build_http_router<C, P, L, LSvc>(
144151
&self,
145-
tcp_connect: C,
152+
connect: C,
146153
prevent_loop: impl Into<PreventLoop>,
147154
loopback: L,
148155
profiles_client: P,
@@ -159,7 +166,7 @@ impl Config {
159166
> + Clone,
160167
> + Clone
161168
where
162-
C: svc::Service<HttpEndpoint> + Clone + Send + Sync + Unpin + 'static,
169+
C: svc::Service<TcpEndpoint> + Clone + Send + Sync + Unpin + 'static,
163170
C::Response: io::AsyncRead + io::AsyncWrite + Send + Unpin + 'static,
164171
C::Error: Into<Error>,
165172
C::Future: Send + Unpin,
@@ -174,45 +181,31 @@ impl Config {
174181
LSvc::Error: Into<Error>,
175182
LSvc::Future: Send,
176183
{
177-
let Config {
178-
allow_discovery,
179-
proxy:
180-
ProxyConfig {
181-
connect,
182-
buffer_capacity,
183-
cache_max_idle_age,
184-
dispatch_timeout,
185-
..
186-
},
187-
..
188-
} = self.clone();
189-
190184
let prevent_loop = prevent_loop.into();
191185

192186
// Creates HTTP clients for each inbound port & HTTP settings.
193-
let endpoint = svc::stack(tcp_connect)
187+
let endpoint = svc::stack(connect)
188+
.push(metrics.transport.layer_connect())
189+
.push_map_target(TcpEndpoint::from)
194190
.push(http::client::layer(
195-
connect.h1_settings,
196-
connect.h2_settings,
191+
self.proxy.connect.h1_settings,
192+
self.proxy.connect.h2_settings,
197193
))
198194
.push(reconnect::layer({
199-
let backoff = connect.backoff;
195+
let backoff = self.proxy.connect.backoff;
200196
move |_| Ok(backoff.stream())
201197
}))
202198
.check_new_service::<HttpEndpoint, http::Request<_>>();
203199

204-
let observe = svc::layers()
200+
let target = endpoint
201+
.push_map_target(HttpEndpoint::from)
205202
// Registers the stack to be tapped.
206203
.push(tap::NewTapHttp::layer(tap))
207204
// Records metrics for each `Target`.
208205
.push(metrics.http_endpoint.to_layer::<classify::Response, _>())
209206
.push_on_response(TraceContext::layer(
210207
span_sink.map(|span_sink| SpanConverter::client(span_sink, trace_labels())),
211-
));
212-
213-
let target = endpoint
214-
.push_map_target(HttpEndpoint::from)
215-
.push(observe)
208+
))
216209
.push_on_response(http::BoxResponse::layer())
217210
.check_new_service::<Target, http::Request<_>>();
218211

@@ -242,7 +235,7 @@ impl Config {
242235
.push_map_target(endpoint::Logical::from)
243236
.push(profiles::discover::layer(
244237
profiles_client,
245-
AllowProfile(allow_discovery),
238+
AllowProfile(self.allow_discovery.clone()),
246239
))
247240
.push_on_response(http::BoxResponse::layer())
248241
.instrument(|_: &Target| debug_span!("profile"))
@@ -257,11 +250,11 @@ impl Config {
257250
.check_new_service::<Target, http::Request<http::BoxBody>>()
258251
.push_on_response(
259252
svc::layers()
260-
.push(svc::FailFast::layer("Logical", dispatch_timeout))
261-
.push_spawn_buffer(buffer_capacity)
253+
.push(svc::FailFast::layer("Logical", self.proxy.dispatch_timeout))
254+
.push_spawn_buffer(self.proxy.buffer_capacity)
262255
.push(metrics.stack.layer(stack_labels("http", "logical"))),
263256
)
264-
.push_cache(cache_max_idle_age)
257+
.push_cache(self.proxy.cache_max_idle_age)
265258
.push_on_response(
266259
svc::layers()
267260
.push(http::Retain::layer())
@@ -324,15 +317,17 @@ impl Config {
324317
prevent_loop.into(),
325318
// If the connection targets the inbound port, try to detect an
326319
// opaque transport header and rewrite the target port
327-
// accordingly. If there was no opaque transport header, the
328-
// forwarding will fail when the tcp connect stack applies loop
329-
// prevention.
320+
// accordingly. If there was no opaque transport header, fail
321+
// the connection with a ConnectionRefused error.
330322
svc::stack(tcp_forward)
331-
.push_map_target(TcpEndpoint::from)
323+
.push_map_target(|(h, _): (opaque_transport::Header, _)| TcpEndpoint::from(h))
324+
.push(svc::NewUnwrapOr::layer(
325+
svc::Fail::<_, NonOpaqueRefused>::default(),
326+
))
332327
.push(transport::NewDetectService::layer(
333328
transport::detect::DetectTimeout::new(
334329
self.proxy.detect_protocol_timeout,
335-
DetectHeader::default(),
330+
opaque_transport::DetectHeader::default(),
336331
),
337332
)),
338333
)
@@ -448,3 +443,14 @@ impl svc::stack::Switch<listen::Addrs> for SkipByPort {
448443
!self.0.contains(&t.target_addr().port())
449444
}
450445
}
446+
447+
// === impl NonOpaqueRefused ===
448+
449+
impl Into<Error> for NonOpaqueRefused {
450+
fn into(self) -> Error {
451+
Error::from(io::Error::new(
452+
io::ErrorKind::ConnectionRefused,
453+
"Non-opaque-transport connection refused",
454+
))
455+
}
456+
}

0 commit comments

Comments
 (0)