diff --git a/hyperactor/Cargo.toml b/hyperactor/Cargo.toml index 008164ec5..fece880bf 100644 --- a/hyperactor/Cargo.toml +++ b/hyperactor/Cargo.toml @@ -1,4 +1,4 @@ -# @generated by autocargo from //monarch/hyperactor:[channel_benchmarks,hyperactor,hyperactor-example-derive,hyperactor-example-stream] +# @generated by autocargo from //monarch/hyperactor:[channel_benchmarks,hyperactor,hyperactor-example-channel,hyperactor-example-derive,hyperactor-example-stream] [package] name = "hyperactor" @@ -13,6 +13,10 @@ license = "BSD-3-Clause" name = "channel_benchmarks" path = "benches/main.rs" +[[bin]] +name = "hyperactor_example_channel" +path = "example/channel.rs" + [[bin]] name = "hyperactor_example_derive" path = "example/derive.rs" diff --git a/hyperactor/benches/main.rs b/hyperactor/benches/main.rs index 9dfa03430..92b95159c 100644 --- a/hyperactor/benches/main.rs +++ b/hyperactor/benches/main.rs @@ -11,6 +11,7 @@ use std::time::Duration; use std::time::Instant; +use bytes::Bytes; use criterion::BenchmarkId; use criterion::Criterion; use criterion::Throughput; @@ -18,6 +19,7 @@ use criterion::criterion_group; use criterion::criterion_main; use futures::future::join_all; use hyperactor::Named; +use hyperactor::channel; use hyperactor::channel::ChannelAddr; use hyperactor::channel::ChannelTransport; use hyperactor::channel::Rx; @@ -26,10 +28,18 @@ use hyperactor::channel::dial; use hyperactor::channel::serve; use serde::Deserialize; use serde::Serialize; +use tokio::runtime; use tokio::runtime::Runtime; use tokio::select; use tokio::sync::oneshot; +fn new_runtime() -> Runtime { + runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap() +} + #[derive(Debug, Clone, Serialize, Deserialize, Named, PartialEq)] struct Message { id: u64, @@ -62,7 +72,7 @@ fn bench_message_sizes(c: &mut Criterion) { group.sampling_mode(criterion::SamplingMode::Flat); group.sample_size(10); group.bench_function(BenchmarkId::from_parameter(size), move |b| { - let mut b = b.to_async(Runtime::new().unwrap()); + let mut b = b.to_async(new_runtime()); let tt = &transport; b.iter_custom(|iters| async move { let addr = ChannelAddr::any(tt.clone()); @@ -106,7 +116,7 @@ fn bench_message_rates(c: &mut Criterion) { let rate = *rate; group.bench_function(format!("rate_{}_{}mps", transport_name, rate), move |b| { - let mut b = b.to_async(Runtime::new().unwrap()); + let mut b = b.to_async(new_runtime()); b.iter_custom(|iters| async move { let total_msgs = iters * rate; let addr = ChannelAddr::any(transport.clone()); @@ -169,6 +179,66 @@ fn bench_message_rates(c: &mut Criterion) { group.finish(); } -criterion_group!(benches, bench_message_sizes, bench_message_rates); +// Try to replicate https://www.internalfb.com/phabricator/paste/view/P1903314366 +fn bench_channel_ping_pong(c: &mut Criterion) { + let transport = ChannelTransport::Unix; + + for size in [1usize, 1_000_000usize] { + let mut group = c.benchmark_group("channel_ping_pong".to_string()); + let transport = transport.clone(); + group.throughput(Throughput::Bytes((size * 2) as u64)); // send and receive + group.sampling_mode(criterion::SamplingMode::Flat); + group.sample_size(100); + group.bench_function(BenchmarkId::from_parameter(size), move |b| { + let mut b = b.to_async(new_runtime()); + b.iter_custom(|iters| channel_ping_pong(transport.clone(), size, iters as usize)); + }); + group.finish(); + } +} + +async fn channel_ping_pong( + transport: ChannelTransport, + message_size: usize, + num_iter: usize, +) -> Duration { + let (client_addr, mut client_rx) = channel::serve::(ChannelAddr::any(transport.clone())) + .await + .unwrap(); + let (server_addr, mut server_rx) = channel::serve::(ChannelAddr::any(transport.clone())) + .await + .unwrap(); + + let _server_handle: tokio::task::JoinHandle> = + tokio::spawn(async move { + let client_tx = channel::dial(client_addr)?; + loop { + let message = server_rx.recv().await?; + client_tx.post(message); + } + }); + + let client_handle: tokio::task::JoinHandle> = + tokio::spawn(async move { + let server_tx = channel::dial(server_addr)?; + let message = Bytes::from(vec![0u8; message_size]); + for _ in 0..num_iter { + server_tx.post(message.clone() /*cheap */); + client_rx.recv().await?; + } + Ok(()) + }); + + let start = Instant::now(); + client_handle.await.unwrap().unwrap(); + start.elapsed() +} + +criterion_group!( + benches, + bench_message_sizes, + bench_message_rates, + bench_channel_ping_pong +); criterion_main!(benches); diff --git a/hyperactor/example/channel.rs b/hyperactor/example/channel.rs new file mode 100644 index 000000000..0e5b562ec --- /dev/null +++ b/hyperactor/example/channel.rs @@ -0,0 +1,98 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +use bytes::Bytes; +use hyperactor::channel; +use hyperactor::channel::ChannelAddr; +use hyperactor::channel::ChannelRx; +use hyperactor::channel::ChannelTransport; +use hyperactor::channel::Rx; +use hyperactor::channel::Tx; +use tokio::time::Duration; +use tokio::time::Instant; + +async fn server( + mut server_rx: ChannelRx, + client_addr: ChannelAddr, +) -> Result<(), anyhow::Error> { + let client_tx = channel::dial(client_addr)?; + loop { + let message = server_rx.recv().await?; + client_tx.post(message); + } +} + +// Analog of https://www.internalfb.com/phabricator/paste/view/P1903314366, using Channel APIs. +// Possibly we should create separate threads for the client and server to also make the OS-level +// setup equivalent. +#[tokio::main(flavor = "current_thread")] +async fn main() -> Result<(), anyhow::Error> { + let transport = ChannelTransport::Tcp; + // let transport = ChannelTransport::Local; + let message_size = 1_000_000; + let num_iter = 100; + + let (client_addr, mut client_rx) = channel::serve::(ChannelAddr::any(transport.clone())) + .await + .unwrap(); + let (server_addr, server_rx) = channel::serve::(ChannelAddr::any(transport.clone())) + .await + .unwrap(); + + let _server_handle = tokio::spawn(server(server_rx, client_addr)); + + let server_tx = channel::dial(server_addr)?; + let message = Bytes::from(vec![0u8; message_size]); + + for _ in 0..10 { + // Warmup + let t = Instant::now(); + server_tx.post(message.clone() /*cheap */); + client_rx.recv().await?; + } + + let mut latencies = vec![]; + let mut total_bytes_sent = 0usize; + let mut total_bytes_received = 0usize; + + let start = Instant::now(); + for _ in 0..num_iter { + total_bytes_sent += message.len(); + let start = Instant::now(); + server_tx.post(message.clone() /*cheap */); + total_bytes_received += client_rx.recv().await?.len(); + latencies.push(start.elapsed()); + } + let elapsed = start.elapsed(); + + let avg_latency = ((latencies.iter().sum::().as_micros() as f64) / 1000f64) + / (latencies.len() as f64); + let min_latency = (latencies.iter().min().unwrap().as_micros() as f64) / 1000f64; + let max_latency = (latencies.iter().max().unwrap().as_micros() as f64) / 1000f64; + + let total_bytes_transferred = total_bytes_sent + total_bytes_received; + let bandwidth_bytes_per_sec = + (total_bytes_transferred as f64) / ((elapsed.as_millis() as f64) / 1000f64); + let bandwidth_mbps = (bandwidth_bytes_per_sec * 8f64) / (1024f64 * 1024f64); + + println!("Results:"); + println!("Average latency: {} ms", avg_latency); + println!("Min latency: {} ms", min_latency); + println!("Max latency: {} ms", max_latency); + println!("Total iterations: {}", latencies.len()); + println!("Total time: {} seconds", elapsed.as_secs()); + println!("Bytes sent: {} bytes", total_bytes_sent); + println!("Bytes received: {} bytes", total_bytes_received); + println!("Total bytes transferred: {} bytes", total_bytes_transferred); + println!( + "Bandwidth: {} bytes/sec ({} Mbps)", + bandwidth_bytes_per_sec, bandwidth_mbps + ); + + Ok(()) +} diff --git a/hyperactor/src/data.rs b/hyperactor/src/data.rs index c5c4df6ff..d8c9ab32f 100644 --- a/hyperactor/src/data.rs +++ b/hyperactor/src/data.rs @@ -105,6 +105,12 @@ impl Named for std::time::Duration { } } +impl Named for bytes::Bytes { + fn typename() -> &'static str { + "bytes::Bytes" + } +} + // A macro that implements type-keyed interning of typenames. This is useful // for implementing [`Named`] for generic types. #[doc(hidden)] // not part of the public API