Skip to content

Commit aa7fc0d

Browse files
committed
Add --threads and --workers options in tpch benchmarks
1 parent b23c0c7 commit aa7fc0d

File tree

3 files changed

+82
-27
lines changed

3 files changed

+82
-27
lines changed

benchmarks/src/bin/dfbench.rs

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,12 +30,14 @@ enum Options {
3030
}
3131

3232
// Main benchmark runner entrypoint
33-
#[tokio::main]
34-
pub async fn main() -> Result<()> {
33+
pub fn main() -> Result<()> {
3534
env_logger::init();
3635

3736
match Options::from_args() {
38-
Options::Tpch(opt) => Box::pin(opt.run()).await,
39-
Options::TpchConvert(opt) => opt.run().await,
37+
Options::Tpch(opt) => opt.run(),
38+
Options::TpchConvert(opt) => {
39+
let rt = tokio::runtime::Runtime::new()?;
40+
rt.block_on(async { opt.run().await })
41+
}
4042
}
4143
}

benchmarks/src/tpch/run.rs

Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,6 @@ use super::{
2020
TPCH_QUERY_START_ID, TPCH_TABLES,
2121
};
2222
use async_trait::async_trait;
23-
use std::path::PathBuf;
24-
use std::sync::Arc;
25-
2623
use datafusion::arrow::record_batch::RecordBatch;
2724
use datafusion::arrow::util::pretty::{self, pretty_format_batches};
2825
use datafusion::common::instant::Instant;
@@ -40,18 +37,25 @@ use datafusion::execution::{SessionState, SessionStateBuilder};
4037
use datafusion::physical_plan::display::DisplayableExecutionPlan;
4138
use datafusion::physical_plan::{collect, displayable};
4239
use datafusion::prelude::*;
40+
use datafusion_distributed::MappedDistributedSessionBuilderExt;
41+
use std::path::PathBuf;
42+
use std::sync::Arc;
4343

4444
use crate::util::{
4545
BenchmarkRun, CommonOpt, InMemoryCacheExecCodec, InMemoryDataSourceRule, QueryResult,
4646
WarmingUpMarker,
4747
};
48-
use datafusion_distributed::test_utils::localhost::start_localhost_context;
48+
use datafusion_distributed::test_utils::localhost::{
49+
get_free_ports, spawn_flight_service, start_localhost_context, LocalHostChannelResolver,
50+
};
4951
use datafusion_distributed::{
5052
DistributedExt, DistributedPhysicalOptimizerRule, DistributedSessionBuilder,
5153
DistributedSessionBuilderContext,
5254
};
5355
use log::info;
5456
use structopt::StructOpt;
57+
use tokio::net::TcpListener;
58+
use tokio::task::JoinHandle;
5559

5660
// hack to avoid `default_value is meaningless for bool` errors
5761
type BoolDefaultTrue = bool;
@@ -113,6 +117,14 @@ pub struct RunOpt {
113117
/// Number of partitions per task.
114118
#[structopt(long = "ppt")]
115119
partitions_per_task: Option<usize>,
120+
121+
/// Number of physical threads per worker (default 1)
122+
#[structopt(long, default_value = "1")]
123+
workers: usize,
124+
125+
/// Number of physical threads per worker
126+
#[structopt(long)]
127+
threads: Option<usize>,
116128
}
117129

118130
#[async_trait]
@@ -156,7 +168,50 @@ impl DistributedSessionBuilder for RunOpt {
156168
}
157169

158170
impl RunOpt {
159-
pub async fn run(mut self) -> Result<()> {
171+
pub fn spawn_workers(self) -> Vec<(tokio::runtime::Runtime, JoinHandle<()>)> {
172+
let ports = get_free_ports(self.workers);
173+
let channel_resolver = LocalHostChannelResolver::new(ports.clone());
174+
let threads_per_worker = self.threads;
175+
let session_builder = self.map(move |builder: SessionStateBuilder| {
176+
let channel_resolver = channel_resolver.clone();
177+
Ok(builder
178+
.with_distributed_channel_resolver(channel_resolver)
179+
.build())
180+
});
181+
let mut handles = vec![];
182+
for port in ports {
183+
let session_builder = session_builder.clone();
184+
let rt = tokio::runtime::Builder::new_multi_thread()
185+
.worker_threads(threads_per_worker.unwrap_or(get_available_parallelism()))
186+
.enable_all()
187+
.build()
188+
.unwrap();
189+
let handle = rt.spawn(async move {
190+
let listener = TcpListener::bind(format!("127.0.0.1:{port}"))
191+
.await
192+
.unwrap();
193+
spawn_flight_service(session_builder, listener)
194+
.await
195+
.unwrap();
196+
});
197+
198+
handles.push((rt, handle));
199+
}
200+
handles
201+
}
202+
203+
pub fn run(self) -> Result<()> {
204+
let _handle = self.clone().spawn_workers();
205+
206+
let rt = tokio::runtime::Builder::new_multi_thread()
207+
.worker_threads(self.threads.unwrap_or(get_available_parallelism()))
208+
.enable_all()
209+
.build()?;
210+
211+
rt.block_on(async move { self._run().await })
212+
}
213+
214+
pub async fn _run(mut self) -> Result<()> {
160215
let (ctx, _guard) = start_localhost_context(1, self.clone()).await;
161216
println!("Running benchmarks with the following options: {self:?}");
162217
let query_range = match self.query {

src/test_utils/localhost.rs

Lines changed: 16 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,17 @@ use tokio::net::TcpListener;
1717
use tonic::transport::{Channel, Server};
1818
use url::Url;
1919

20+
pub fn get_free_ports(n: usize) -> Vec<u16> {
21+
let listeners = (0..n)
22+
.map(|_| std::net::TcpListener::bind("127.0.0.1:0"))
23+
.collect::<Result<Vec<_>, _>>()
24+
.expect("Failed to bind to address");
25+
listeners
26+
.iter()
27+
.map(|listener| listener.local_addr().unwrap().port())
28+
.collect()
29+
}
30+
2031
pub async fn start_localhost_context<B>(
2132
num_workers: usize,
2233
session_builder: B,
@@ -25,23 +36,7 @@ where
2536
B: DistributedSessionBuilder + Send + Sync + 'static,
2637
B: Clone,
2738
{
28-
let listeners = futures::future::try_join_all(
29-
(0..num_workers)
30-
.map(|_| TcpListener::bind("127.0.0.1:0"))
31-
.collect::<Vec<_>>(),
32-
)
33-
.await
34-
.expect("Failed to bind to address");
35-
36-
let ports: Vec<u16> = listeners
37-
.iter()
38-
.map(|listener| {
39-
listener
40-
.local_addr()
41-
.expect("Failed to get local address")
42-
.port()
43-
})
44-
.collect();
39+
let ports = get_free_ports(num_workers);
4540

4641
let channel_resolver = LocalHostChannelResolver::new(ports.clone());
4742
let session_builder = session_builder.map(move |builder: SessionStateBuilder| {
@@ -51,8 +46,11 @@ where
5146
.build())
5247
});
5348
let mut join_set = JoinSet::new();
54-
for listener in listeners {
49+
for port in ports {
5550
let session_builder = session_builder.clone();
51+
let listener = TcpListener::bind(format!("127.0.0.1:{port}"))
52+
.await
53+
.unwrap();
5654
join_set.spawn(async move {
5755
spawn_flight_service(session_builder, listener)
5856
.await

0 commit comments

Comments
 (0)