Skip to content

Commit c66ccb7

Browse files
committed
gateway: gRPC proxy support
1 parent 38eaf93 commit c66ccb7

File tree

3 files changed

+61
-27
lines changed

3 files changed

+61
-27
lines changed

gateway/src/main_service.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ pub struct ProxyInner {
6060
notify_state_updated: Notify,
6161
auth_client: AuthClient,
6262
pub(crate) acceptor: RwLock<TlsAcceptor>,
63+
pub(crate) h2_acceptor: RwLock<TlsAcceptor>,
6364
}
6465

6566
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -141,15 +142,19 @@ impl ProxyInner {
141142
}
142143
false => None,
143144
};
144-
let acceptor =
145-
RwLock::new(create_acceptor(&config.proxy).context("Failed to create acceptor")?);
145+
let acceptor = RwLock::new(
146+
create_acceptor(&config.proxy, false).context("Failed to create acceptor")?,
147+
);
148+
let h2_acceptor =
149+
RwLock::new(create_acceptor(&config.proxy, true).context("Failed to create acceptor")?);
146150
Ok(Self {
147151
config,
148152
state,
149153
notify_state_updated: Notify::new(),
150154
my_app_id,
151155
auth_client,
152156
acceptor,
157+
h2_acceptor,
153158
certbot,
154159
})
155160
}

gateway/src/proxy.rs

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ struct DstInfo {
6565
app_id: String,
6666
port: u16,
6767
is_tls: bool,
68+
is_h2: bool,
6869
}
6970

7071
fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result<DstInfo> {
@@ -83,22 +84,28 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result<DstInfo> {
8384
let last_part = parts.next();
8485
let is_tls;
8586
let port;
87+
let is_h2;
8688
match last_part {
8789
None => {
8890
is_tls = false;
91+
is_h2 = false;
8992
port = None;
9093
}
9194
Some(last_part) => {
92-
let port_str = match last_part.strip_suffix('s') {
93-
None => {
94-
is_tls = false;
95-
last_part
96-
}
97-
Some(last_part) => {
98-
is_tls = true;
99-
last_part
100-
}
95+
let (port_str, has_g) = match last_part.strip_suffix('g') {
96+
Some(without_g) => (without_g, true),
97+
None => (last_part, false),
10198
};
99+
100+
let (port_str, has_s) = match port_str.strip_suffix('s') {
101+
Some(without_s) => (without_s, true),
102+
None => (port_str, false),
103+
};
104+
if has_g && has_s {
105+
bail!("invalid sni format: `gs` is not allowed");
106+
}
107+
is_h2 = has_g;
108+
is_tls = has_s;
102109
port = if port_str.is_empty() {
103110
None
104111
} else {
@@ -114,6 +121,7 @@ fn parse_destination(sni: &str, dotted_base_domain: &str) -> Result<DstInfo> {
114121
app_id,
115122
port,
116123
is_tls,
124+
is_h2,
117125
})
118126
}
119127

@@ -138,7 +146,9 @@ async fn handle_connection(
138146
if dst.is_tls {
139147
tls_passthough::proxy_to_app(state, inbound, buffer, &dst.app_id, dst.port).await
140148
} else {
141-
state.proxy(inbound, buffer, &dst.app_id, dst.port).await
149+
state
150+
.proxy(inbound, buffer, &dst.app_id, dst.port, dst.is_h2)
151+
.await
142152
}
143153
} else {
144154
tls_passthough::proxy_with_sni(state, inbound, buffer, &sni).await

gateway/src/proxy/tls_terminate.rs

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ where
9393
}
9494
}
9595

96-
pub(crate) fn create_acceptor(config: &ProxyConfig) -> Result<TlsAcceptor> {
96+
pub(crate) fn create_acceptor(config: &ProxyConfig, h2: bool) -> Result<TlsAcceptor> {
9797
let cert_pem = fs::read(&config.cert_chain).context("failed to read certificate")?;
9898
let key_pem = fs::read(&config.cert_key).context("failed to read private key")?;
9999
let certs = CertificateDer::pem_slice_iter(cert_pem.as_slice())
@@ -114,12 +114,16 @@ pub(crate) fn create_acceptor(config: &ProxyConfig) -> Result<TlsAcceptor> {
114114
TlsVersion::Tls13 => &TLS13,
115115
})
116116
.collect::<Vec<_>>();
117-
let config = rustls::ServerConfig::builder_with_provider(Arc::new(provider))
117+
let mut config = rustls::ServerConfig::builder_with_provider(Arc::new(provider))
118118
.with_protocol_versions(&supported_versions)
119119
.context("Failed to build TLS config")?
120120
.with_no_client_auth()
121121
.with_single_cert(certs, key)?;
122122

123+
if h2 {
124+
config.alpn_protocols = vec![b"h2".to_vec()];
125+
}
126+
123127
let acceptor = TlsAcceptor::from(Arc::new(config));
124128

125129
Ok(acceptor)
@@ -145,11 +149,16 @@ impl Proxy {
145149
/// Reload the TLS acceptor with fresh certificates
146150
pub fn reload_certificates(&self) -> Result<()> {
147151
info!("Reloading TLS certificates");
148-
let new_acceptor = create_acceptor(&self.config.proxy)?;
149-
150152
// Replace the acceptor with the new one
151153
if let Ok(mut acceptor) = self.acceptor.write() {
152-
*acceptor = new_acceptor;
154+
*acceptor = create_acceptor(&self.config.proxy, false)?;
155+
info!("TLS certificates successfully reloaded");
156+
} else {
157+
bail!("Failed to acquire write lock for TLS acceptor");
158+
}
159+
160+
if let Ok(mut acceptor) = self.h2_acceptor.write() {
161+
*acceptor = create_acceptor(&self.config.proxy, true)?;
153162
info!("TLS certificates successfully reloaded");
154163
} else {
155164
bail!("Failed to acquire write lock for TLS acceptor");
@@ -163,11 +172,12 @@ impl Proxy {
163172
inbound: TcpStream,
164173
buffer: Vec<u8>,
165174
port: u16,
175+
h2: bool,
166176
) -> Result<()> {
167177
if port != 80 {
168178
bail!("Only port 80 is supported for this node");
169179
}
170-
let stream = self.tls_accept(inbound, buffer).await?;
180+
let stream = self.tls_accept(inbound, buffer, h2).await?;
171181
let io = TokioIo::new(stream);
172182

173183
let service = service_fn(|req: Request<Incoming>| async move {
@@ -218,11 +228,12 @@ impl Proxy {
218228
inbound: TcpStream,
219229
buffer: Vec<u8>,
220230
port: u16,
231+
h2: bool,
221232
) -> Result<()> {
222233
if port != 80 {
223234
bail!("Only port 80 is supported for health checks");
224235
}
225-
let stream = self.tls_accept(inbound, buffer).await?;
236+
let stream = self.tls_accept(inbound, buffer, h2).await?;
226237

227238
// Wrap the TLS stream with TokioIo to make it compatible with hyper 1.x
228239
let io = TokioIo::new(stream);
@@ -253,17 +264,24 @@ impl Proxy {
253264
&self,
254265
inbound: TcpStream,
255266
buffer: Vec<u8>,
267+
h2: bool,
256268
) -> Result<TlsStream<MergedStream>> {
257269
let stream = MergedStream {
258270
buffer,
259271
buffer_cursor: 0,
260272
inbound,
261273
};
262-
let acceptor = self
263-
.acceptor
264-
.read()
265-
.expect("Failed to acquire read lock for TLS acceptor")
266-
.clone();
274+
let acceptor = if h2 {
275+
self.h2_acceptor
276+
.read()
277+
.expect("Failed to acquire read lock for TLS acceptor")
278+
.clone()
279+
} else {
280+
self.acceptor
281+
.read()
282+
.expect("Failed to acquire read lock for TLS acceptor")
283+
.clone()
284+
};
267285
let tls_stream = timeout(
268286
self.config.proxy.timeouts.handshake,
269287
acceptor.accept(stream),
@@ -280,19 +298,20 @@ impl Proxy {
280298
buffer: Vec<u8>,
281299
app_id: &str,
282300
port: u16,
301+
h2: bool,
283302
) -> Result<()> {
284303
if app_id == "health" {
285-
return self.handle_health_check(inbound, buffer, port).await;
304+
return self.handle_health_check(inbound, buffer, port, h2).await;
286305
}
287306
if app_id == "gateway" {
288-
return self.handle_this_node(inbound, buffer, port).await;
307+
return self.handle_this_node(inbound, buffer, port, h2).await;
289308
}
290309
let addresses = self
291310
.lock()
292311
.select_top_n_hosts(app_id)
293312
.with_context(|| format!("app {app_id} not found"))?;
294313
debug!("selected top n hosts: {addresses:?}");
295-
let tls_stream = self.tls_accept(inbound, buffer).await?;
314+
let tls_stream = self.tls_accept(inbound, buffer, h2).await?;
296315
let (outbound, _counter) = timeout(
297316
self.config.proxy.timeouts.connect,
298317
connect_multiple_hosts(addresses, port),

0 commit comments

Comments
 (0)