Skip to content

Commit 36df03d

Browse files
authored
Merge pull request #34191 from teskje/ctp-unbounded
ctp: make channels unbounded
2 parents 6b72de9 + 52ee700 commit 36df03d

File tree

2 files changed

+43
-42
lines changed

2 files changed

+43
-42
lines changed

src/service/src/transport.rs

Lines changed: 9 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ use std::convert::Infallible;
2626
use std::fmt::Debug;
2727
use std::time::Duration;
2828

29-
use anyhow::{anyhow, bail};
29+
use anyhow::bail;
3030
use async_trait::async_trait;
3131
use bincode::Options;
3232
use futures::future;
@@ -243,9 +243,9 @@ where
243243
#[derive(Debug)]
244244
struct Connection<Out, In> {
245245
/// Message sender connected to the send task.
246-
msg_tx: mpsc::Sender<Out>,
246+
msg_tx: mpsc::UnboundedSender<Out>,
247247
/// Message receiver connected to the receive task.
248-
msg_rx: mpsc::Receiver<In>,
248+
msg_rx: mpsc::UnboundedReceiver<In>,
249249
/// Receiver for errors encountered by connection tasks.
250250
error_rx: watch::Receiver<String>,
251251

@@ -289,8 +289,8 @@ impl<Out: Message, In: Message> Connection<Out, In> {
289289

290290
handshake(&mut reader, &mut writer, version, server_fqdn).await?;
291291

292-
let (out_tx, out_rx) = mpsc::channel(1024);
293-
let (in_tx, in_rx) = mpsc::channel(1024);
292+
let (out_tx, out_rx) = mpsc::unbounded_channel();
293+
let (in_tx, in_rx) = mpsc::unbounded_channel();
294294
// Initialize the error channel with a default error to return if none of the tasks
295295
// produced an error.
296296
let (error_tx, error_rx) = watch::channel("connection closed".into());
@@ -314,7 +314,7 @@ impl<Out: Message, In: Message> Connection<Out, In> {
314314

315315
/// Enqueue a message for sending.
316316
async fn send(&mut self, msg: Out) -> anyhow::Result<()> {
317-
match self.msg_tx.send(msg).await {
317+
match self.msg_tx.send(msg) {
318318
Ok(()) => Ok(()),
319319
Err(_) => bail!(self.collect_error().await),
320320
}
@@ -347,7 +347,7 @@ impl<Out: Message, In: Message> Connection<Out, In> {
347347
/// Run a connection's send task.
348348
async fn run_send_task<W: AsyncWrite + Unpin>(
349349
mut writer: W,
350-
mut msg_rx: mpsc::Receiver<Out>,
350+
mut msg_rx: mpsc::UnboundedReceiver<Out>,
351351
error_tx: watch::Sender<String>,
352352
mut metrics: impl Metrics<Out, In>,
353353
) {
@@ -383,7 +383,7 @@ impl<Out: Message, In: Message> Connection<Out, In> {
383383
/// Run a connection's recv task.
384384
async fn run_recv_task<R: AsyncRead + Unpin>(
385385
mut reader: R,
386-
msg_tx: mpsc::Sender<In>,
386+
msg_tx: mpsc::UnboundedSender<In>,
387387
error_tx: watch::Sender<String>,
388388
mut metrics: impl Metrics<Out, In>,
389389
) {
@@ -393,7 +393,7 @@ impl<Out: Message, In: Message> Connection<Out, In> {
393393
trace!(?msg, "ctp: received message");
394394
metrics.message_received(&msg);
395395

396-
if msg_tx.send(msg).await.is_err() {
396+
if msg_tx.send(msg).is_err() {
397397
break;
398398
}
399399
}
@@ -407,38 +407,6 @@ impl<Out: Message, In: Message> Connection<Out, In> {
407407
}
408408
}
409409

410-
/// A connection handler that simply forwards messages over channels.
411-
#[derive(Debug)]
412-
pub struct ChannelHandler<In, Out> {
413-
tx: mpsc::UnboundedSender<In>,
414-
rx: mpsc::UnboundedReceiver<Out>,
415-
}
416-
417-
impl<In, Out> ChannelHandler<In, Out> {
418-
pub fn new(tx: mpsc::UnboundedSender<In>, rx: mpsc::UnboundedReceiver<Out>) -> Self {
419-
Self { tx, rx }
420-
}
421-
}
422-
423-
#[async_trait]
424-
impl<In: Message, Out: Message> GenericClient<In, Out> for ChannelHandler<In, Out> {
425-
async fn send(&mut self, cmd: In) -> anyhow::Result<()> {
426-
let result = self.tx.send(cmd);
427-
result.map_err(|_| anyhow!("client channel disconnected"))
428-
}
429-
430-
/// # Cancel safety
431-
///
432-
/// This method is cancel safe.
433-
async fn recv(&mut self) -> anyhow::Result<Option<Out>> {
434-
// `mpsc::UnboundedReceiver::recv` is cancel safe.
435-
match self.rx.recv().await {
436-
Some(resp) => Ok(Some(resp)),
437-
None => bail!("client channel disconnected"),
438-
}
439-
}
440-
}
441-
442410
/// Perform the CTP handshake.
443411
///
444412
/// To perform the handshake, each endpoint sends the protocol magic number, followed by a

src/service/tests/transport.rs

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,14 @@ use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
1313
use std::sync::{Arc, Mutex, Once};
1414
use std::time::Duration;
1515

16+
use anyhow::{anyhow, bail};
1617
use async_trait::async_trait;
1718
use futures::future;
1819
use mz_ore::assert_none;
1920
use mz_ore::netio::Listener;
2021
use mz_ore::retry::Retry;
2122
use mz_service::client::GenericClient;
22-
use mz_service::transport::{self, ChannelHandler, Message, NoopMetrics};
23+
use mz_service::transport::{self, Message, NoopMetrics};
2324
use semver::Version;
2425
use tokio::io::AsyncWriteExt;
2526
use tokio::sync::{mpsc, oneshot};
@@ -556,6 +557,38 @@ fn test_metrics() {
556557
sim.run().unwrap();
557558
}
558559

560+
/// A connection handler that simply forwards messages over channels.
561+
#[derive(Debug)]
562+
pub struct ChannelHandler<In, Out> {
563+
tx: mpsc::UnboundedSender<In>,
564+
rx: mpsc::UnboundedReceiver<Out>,
565+
}
566+
567+
impl<In, Out> ChannelHandler<In, Out> {
568+
pub fn new(tx: mpsc::UnboundedSender<In>, rx: mpsc::UnboundedReceiver<Out>) -> Self {
569+
Self { tx, rx }
570+
}
571+
}
572+
573+
#[async_trait]
574+
impl<In: Message, Out: Message> GenericClient<In, Out> for ChannelHandler<In, Out> {
575+
async fn send(&mut self, cmd: In) -> anyhow::Result<()> {
576+
let result = self.tx.send(cmd);
577+
result.map_err(|_| anyhow!("client channel disconnected"))
578+
}
579+
580+
/// # Cancel safety
581+
///
582+
/// This method is cancel safe.
583+
async fn recv(&mut self) -> anyhow::Result<Option<Out>> {
584+
// `mpsc::Receiver::recv` is cancel safe.
585+
match self.rx.recv().await {
586+
Some(resp) => Ok(Some(resp)),
587+
None => bail!("client channel disconnected"),
588+
}
589+
}
590+
}
591+
559592
/// A connection handler that produces a single outbound message and then becomes silent.
560593
#[derive(Debug)]
561594
struct OneOutputHandler {

0 commit comments

Comments
 (0)