Skip to content

Commit dab7670

Browse files
committed
Draft: spawn workers correctly in different tokio runtimes
1 parent 6a17cca commit dab7670

File tree

1 file changed

+60
-45
lines changed

1 file changed

+60
-45
lines changed

benchmarks/src/tpch/run.rs

Lines changed: 60 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ use datafusion::physical_plan::display::DisplayableExecutionPlan;
4242
use datafusion::physical_plan::{collect, displayable};
4343
use datafusion::prelude::*;
4444
use datafusion_distributed::test_utils::localhost::{
45-
get_free_ports, spawn_flight_service, start_localhost_context, LocalHostChannelResolver,
45+
get_free_ports, spawn_flight_service, LocalHostChannelResolver,
4646
};
4747
use datafusion_distributed::MappedDistributedSessionBuilderExt;
4848
use datafusion_distributed::{
@@ -55,7 +55,6 @@ use std::path::PathBuf;
5555
use std::sync::Arc;
5656
use structopt::StructOpt;
5757
use tokio::net::TcpListener;
58-
use tokio::task::JoinHandle;
5958

6059
/// Run the tpch benchmark.
6160
///
@@ -160,51 +159,64 @@ impl DistributedSessionBuilder for RunOpt {
160159
}
161160

162161
impl RunOpt {
163-
pub fn spawn_workers(self) -> Vec<(tokio::runtime::Runtime, JoinHandle<()>)> {
162+
pub fn run(self) -> Result<()> {
164163
let ports = get_free_ports(self.workers);
165-
let channel_resolver = LocalHostChannelResolver::new(ports.clone());
164+
165+
let _handle = self.clone().spawn_workers(ports.clone());
166+
drop(_handle);
167+
168+
let rt = tokio::runtime::Builder::new_multi_thread()
169+
.worker_threads(self.threads.unwrap_or(get_available_parallelism()))
170+
.enable_all()
171+
.build()?;
172+
173+
rt.block_on(async move { self.run_local(ports).await })
174+
}
175+
176+
pub fn spawn_workers(self, ports: Vec<u16>) -> Vec<std::thread::JoinHandle<()>> {
166177
let threads_per_worker = self.threads;
178+
let ports_copy = ports.clone();
167179
let session_builder = self.map(move |builder: SessionStateBuilder| {
168-
let channel_resolver = channel_resolver.clone();
180+
let channel_resolver = LocalHostChannelResolver::new(ports.clone());
169181
Ok(builder
170182
.with_distributed_channel_resolver(channel_resolver)
171183
.build())
172184
});
173185
let mut handles = vec![];
174-
for port in ports {
186+
for port in ports_copy {
175187
let session_builder = session_builder.clone();
176-
let rt = tokio::runtime::Builder::new_multi_thread()
177-
.worker_threads(threads_per_worker.unwrap_or(get_available_parallelism()))
178-
.enable_all()
179-
.build()
180-
.unwrap();
181-
let handle = rt.spawn(async move {
182-
let listener = TcpListener::bind(format!("127.0.0.1:{port}"))
183-
.await
184-
.unwrap();
185-
spawn_flight_service(session_builder, listener)
186-
.await
188+
let handle = std::thread::spawn(move || {
189+
let rt = tokio::runtime::Builder::new_multi_thread()
190+
.worker_threads(threads_per_worker.unwrap_or(get_available_parallelism()))
191+
.enable_all()
192+
.build()
187193
.unwrap();
194+
rt.block_on(async move {
195+
let listener = TcpListener::bind(format!("127.0.0.1:{port}"))
196+
.await
197+
.unwrap();
198+
spawn_flight_service(session_builder, listener)
199+
.await
200+
.unwrap();
201+
})
188202
});
189203

190-
handles.push((rt, handle));
204+
handles.push(handle);
191205
}
192206
handles
193207
}
194208

195-
pub fn run(self) -> Result<()> {
196-
let _handle = self.clone().spawn_workers();
197-
198-
let rt = tokio::runtime::Builder::new_multi_thread()
199-
.worker_threads(self.threads.unwrap_or(get_available_parallelism()))
200-
.enable_all()
201-
.build()?;
202-
203-
rt.block_on(async move { self._run().await })
204-
}
205-
206-
pub async fn _run(mut self) -> Result<()> {
207-
let (ctx, _guard) = start_localhost_context(1, self.clone()).await;
209+
async fn run_local(mut self, ports: Vec<u16>) -> Result<()> {
210+
let session_builder = self.clone().map(move |builder: SessionStateBuilder| {
211+
let channel_resolver = LocalHostChannelResolver::new(ports.clone());
212+
Ok(builder
213+
.with_distributed_channel_resolver(channel_resolver)
214+
.build())
215+
});
216+
let state = session_builder
217+
.build_session_state(DistributedSessionBuilderContext::default())
218+
.await?;
219+
let ctx = SessionContext::new_with_state(state);
208220
println!("Running benchmarks with the following options: {self:?}");
209221
let query_range = match self.query {
210222
Some(query_id) => query_id..=query_id,
@@ -215,6 +227,22 @@ impl RunOpt {
215227
.get_or_insert(self.get_path()?.join("results.json"));
216228
let mut benchmark_run = BenchmarkRun::new();
217229

230+
// Warmup the cache for the in-memory mode.
231+
if self.mem_table {
232+
for query_id in query_range.clone() {
233+
// put the WarmingUpMarker in the context, otherwise, queries will fail as the
234+
// InMemoryCacheExec node will think they should already be warmed up.
235+
let sql = &get_query_sql(query_id)?;
236+
let ctx = ctx
237+
.clone()
238+
.with_distributed_option_extension(WarmingUpMarker::warming_up())?;
239+
for query in sql.iter() {
240+
self.execute_query(&ctx, query).await?;
241+
}
242+
println!("Query {query_id} data loaded in memory");
243+
}
244+
}
245+
218246
for query_id in query_range {
219247
benchmark_run.start_new_case(&format!("Query {query_id}"));
220248
let query_run = self.benchmark_query(query_id, &ctx).await;
@@ -226,7 +254,7 @@ impl RunOpt {
226254
}
227255
Err(e) => {
228256
benchmark_run.mark_failed();
229-
eprintln!("Query {query_id} failed: {e}");
257+
eprintln!("Query {query_id} failed: {e:?}");
230258
}
231259
}
232260
}
@@ -247,19 +275,6 @@ impl RunOpt {
247275

248276
let sql = &get_query_sql(query_id)?;
249277

250-
// Warmup the cache for the in-memory mode.
251-
if self.mem_table {
252-
// put the WarmingUpMarker in the context, otherwise, queries will fail as the
253-
// InMemoryCacheExec node will think they should already be warmed up.
254-
let ctx = ctx
255-
.clone()
256-
.with_distributed_option_extension(WarmingUpMarker::warming_up())?;
257-
for query in sql.iter() {
258-
self.execute_query(&ctx, query).await?;
259-
}
260-
println!("Query {query_id} data loaded in memory");
261-
}
262-
263278
for i in 0..self.iterations() {
264279
let start = Instant::now();
265280
let mut result = vec![];

0 commit comments

Comments
 (0)