@@ -34,17 +34,21 @@ use datafusion::datasource::file_format::FileFormat;
3434use datafusion:: datasource:: listing:: {
3535 ListingOptions , ListingTable , ListingTableConfig , ListingTableUrl ,
3636} ;
37- use datafusion:: datasource:: { MemTable , TableProvider } ;
37+ use datafusion:: datasource:: TableProvider ;
3838use datafusion:: error:: { DataFusionError , Result } ;
3939use datafusion:: execution:: { SessionState , SessionStateBuilder } ;
4040use datafusion:: physical_plan:: display:: DisplayableExecutionPlan ;
4141use datafusion:: physical_plan:: { collect, displayable} ;
4242use datafusion:: prelude:: * ;
4343
44- use crate :: util:: { print_memory_stats, BenchmarkRun , CommonOpt , QueryResult } ;
44+ use crate :: util:: {
45+ BenchmarkRun , CommonOpt , InMemoryCacheExecCodec , InMemoryDataSourceRule , QueryResult ,
46+ WarmingUpMarker ,
47+ } ;
4548use datafusion_distributed:: test_utils:: localhost:: start_localhost_context;
4649use datafusion_distributed:: {
47- DistributedPhysicalOptimizerRule , DistributedSessionBuilder , DistributedSessionBuilderContext ,
50+ DistributedExt , DistributedPhysicalOptimizerRule , DistributedSessionBuilder ,
51+ DistributedSessionBuilderContext ,
4852} ;
4953use log:: info;
5054use structopt:: StructOpt ;
@@ -115,18 +119,23 @@ pub struct RunOpt {
115119impl DistributedSessionBuilder for RunOpt {
116120 async fn build_session_state (
117121 & self ,
118- _ctx : DistributedSessionBuilderContext ,
122+ ctx : DistributedSessionBuilderContext ,
119123 ) -> Result < SessionState , DataFusionError > {
120124 let mut builder = SessionStateBuilder :: new ( ) . with_default_features ( ) ;
121125
122126 let config = self
123127 . common
124128 . config ( ) ?
125129 . with_collect_statistics ( !self . disable_statistics )
130+ . with_distributed_user_codec ( InMemoryCacheExecCodec )
131+ . with_distributed_option_extension_from_headers :: < WarmingUpMarker > ( & ctx. headers ) ?
126132 . with_target_partitions ( self . partitions ( ) ) ;
127133
128134 let rt_builder = self . common . runtime_env_builder ( ) ?;
129135
136+ if self . mem_table {
137+ builder = builder. with_physical_optimizer_rule ( Arc :: new ( InMemoryDataSourceRule ) ) ;
138+ }
130139 if self . distributed {
131140 let mut rule = DistributedPhysicalOptimizerRule :: new ( ) ;
132141 if let Some ( partitions_per_task) = self . partitions_per_task {
@@ -191,8 +200,18 @@ impl RunOpt {
191200
192201 let sql = & get_query_sql ( query_id) ?;
193202
194- let single_node_ctx = SessionContext :: new ( ) ;
195- self . register_tables ( & single_node_ctx) . await ?;
203+ // Warmup the cache for the in-memory mode.
204+ if self . mem_table {
205+ // put the WarmingUpMarker in the context, otherwise, queries will fail as the
206+ // InMemoryCacheExec node will think they should already be warmed up.
207+ let ctx = ctx
208+ . clone ( )
209+ . with_distributed_option_extension ( WarmingUpMarker :: warming_up ( ) ) ?;
210+ for query in sql. iter ( ) {
211+ self . execute_query ( & ctx, query) . await ?;
212+ }
213+ println ! ( "Query {query_id} data loaded in memory" ) ;
214+ }
196215
197216 for i in 0 ..self . iterations ( ) {
198217 let start = Instant :: now ( ) ;
@@ -225,30 +244,12 @@ impl RunOpt {
225244 let avg = millis. iter ( ) . sum :: < f64 > ( ) / millis. len ( ) as f64 ;
226245 println ! ( "Query {query_id} avg time: {avg:.2} ms" ) ;
227246
228- // Print memory stats using mimalloc (only when compiled with --features mimalloc_extended)
229- print_memory_stats ( ) ;
230-
231247 Ok ( query_results)
232248 }
233249
234250 async fn register_tables ( & self , ctx : & SessionContext ) -> Result < ( ) > {
235251 for table in TPCH_TABLES {
236- let table_provider = { self . get_table ( ctx, table) . await ? } ;
237-
238- if self . mem_table {
239- println ! ( "Loading table '{table}' into memory" ) ;
240- let start = Instant :: now ( ) ;
241- let memtable =
242- MemTable :: load ( table_provider, Some ( self . partitions ( ) ) , & ctx. state ( ) ) . await ?;
243- println ! (
244- "Loaded table '{}' into memory in {} ms" ,
245- table,
246- start. elapsed( ) . as_millis( )
247- ) ;
248- ctx. register_table ( * table, Arc :: new ( memtable) ) ?;
249- } else {
250- ctx. register_table ( * table, table_provider) ?;
251- }
252+ ctx. register_table ( * table, self . get_table ( ctx, table) . await ?) ?;
252253 }
253254 Ok ( ( ) )
254255 }
0 commit comments