diff --git a/src/app/app_execution.rs b/src/app/app_execution.rs index a94951bf..b4b7c8fd 100644 --- a/src/app/app_execution.rs +++ b/src/app/app_execution.rs @@ -18,25 +18,42 @@ //! [`AppExecution`]: Handles executing queries for the TUI application. use crate::app::state::tabs::sql::Query; -use crate::app::AppEvent; +use crate::app::{AppEvent, ExecutionError, ExecutionResultsBatch}; use crate::execution::ExecutionContext; use color_eyre::eyre::Result; +use datafusion::execution::context::SessionContext; +use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::execute_stream; use futures::StreamExt; use log::{error, info}; use std::sync::Arc; use std::time::Duration; use tokio::sync::mpsc::UnboundedSender; +use tokio::sync::Mutex; /// Handles executing queries for the TUI application, formatting results /// and sending them to the UI. -pub(crate) struct AppExecution { +pub struct AppExecution { inner: Arc, + result_stream: Arc>>, } impl AppExecution { /// Create a new instance of [`AppExecution`]. pub fn new(inner: Arc) -> Self { - Self { inner } + Self { + inner, + result_stream: Arc::new(Mutex::new(None)), + } + } + + pub fn session_ctx(&self) -> &SessionContext { + self.inner.session_ctx() + } + + pub async fn set_result_stream(&self, stream: SendableRecordBatchStream) { + let mut s = self.result_stream.lock().await; + *s = Some(stream) } /// Run the sequence of SQL queries, sending the results as [`AppEvent::QueryResult`] via the sender. @@ -60,33 +77,53 @@ impl AppExecution { let start = std::time::Instant::now(); if i == statement_count - 1 { info!("Executing last query and display results"); - match self.inner.execute_sql(sql).await { - Ok(mut stream) => { - let mut batches = Vec::new(); - while let Some(maybe_batch) = stream.next().await { - match maybe_batch { - Ok(batch) => { - batches.push(batch); - } - Err(e) => { - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); - break; + sender.send(AppEvent::NewExecution)?; + match self.inner.create_physical_plan(sql).await { + Ok(plan) => match execute_stream(plan, self.inner.session_ctx().task_ctx()) { + Ok(stream) => { + self.set_result_stream(stream).await; + let mut stream = self.result_stream.lock().await; + if let Some(s) = stream.as_mut() { + if let Some(b) = s.next().await { + match b { + Ok(b) => { + let duration = start.elapsed(); + let results = ExecutionResultsBatch { + query: sql.to_string(), + batch: b, + duration, + }; + sender.send(AppEvent::ExecutionResultsNextPage( + results, + ))?; + } + Err(e) => { + error!("Error getting RecordBatch: {:?}", e); + } + } } } } + Err(stream_err) => { + error!("Error creating physical plan: {:?}", stream_err); + let elapsed = start.elapsed(); + let e = ExecutionError { + query: sql.to_string(), + error: stream_err.to_string(), + duration: elapsed, + }; + sender.send(AppEvent::ExecutionResultsError(e))?; + } + }, + Err(plan_err) => { + error!("Error creating physical plan: {:?}", plan_err); let elapsed = start.elapsed(); - let rows: usize = batches.iter().map(|r| r.num_rows()).sum(); - query.set_results(Some(batches)); - query.set_num_rows(Some(rows)); - query.set_execution_time(elapsed); - } - Err(e) => { - error!("Error creating dataframe: {:?}", e); - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); + let e = ExecutionError { + query: sql.to_string(), + error: plan_err.to_string(), + duration: elapsed, + }; + sender.send(AppEvent::ExecutionResultsError(e))?; } } } else { @@ -107,4 +144,27 @@ impl AppExecution { } Ok(()) } + + pub async fn next_batch(&self, sql: String, sender: UnboundedSender) { + let mut stream = self.result_stream.lock().await; + if let Some(s) = stream.as_mut() { + let start = std::time::Instant::now(); + if let Some(b) = s.next().await { + match b { + Ok(b) => { + let duration = start.elapsed(); + let results = ExecutionResultsBatch { + query: sql, + batch: b, + duration, + }; + let _ = sender.send(AppEvent::ExecutionResultsNextPage(results)); + } + Err(e) => { + error!("Error getting RecordBatch: {:?}", e); + } + } + } + } + } } diff --git a/src/app/handlers/flightsql.rs b/src/app/handlers/flightsql.rs index c849e129..a2b98dd6 100644 --- a/src/app/handlers/flightsql.rs +++ b/src/app/handlers/flightsql.rs @@ -66,74 +66,74 @@ pub fn normal_mode_handler(app: &mut App, key: KeyEvent) { } } - KeyCode::Enter => { - info!("Run FS query"); - let sql = app.state.flightsql_tab.editor().lines().join(""); - info!("SQL: {}", sql); - let execution = Arc::clone(&app.execution); - let _event_tx = app.event_tx(); - tokio::spawn(async move { - let client = execution.flightsql_client(); - let mut query = - FlightSQLQuery::new(sql.clone(), None, None, None, Duration::default(), None); - let start = Instant::now(); - if let Some(ref mut c) = *client.lock().await { - info!("Sending query"); - match c.execute(sql, None).await { - Ok(flight_info) => { - for endpoint in flight_info.endpoint { - if let Some(ticket) = endpoint.ticket { - match c.do_get(ticket.into_request()).await { - Ok(mut stream) => { - let mut batches: Vec = Vec::new(); - // temporarily only show the first batch to avoid - // buffering massive result sets. Eventually there should - // be some sort of paging logic - // see https://github.com/datafusion-contrib/datafusion-tui/pull/133#discussion_r1756680874 - // while let Some(maybe_batch) = stream.next().await { - if let Some(maybe_batch) = stream.next().await { - match maybe_batch { - Ok(batch) => { - info!("Batch rows: {}", batch.num_rows()); - batches.push(batch); - } - Err(e) => { - error!("Error getting batch: {:?}", e); - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); - } - } - } - let elapsed = start.elapsed(); - let rows: usize = - batches.iter().map(|r| r.num_rows()).sum(); - query.set_results(Some(batches)); - query.set_num_rows(Some(rows)); - query.set_execution_time(elapsed); - } - Err(e) => { - error!("Error getting response: {:?}", e); - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); - } - } - } - } - } - Err(e) => { - error!("Error getting response: {:?}", e); - let elapsed = start.elapsed(); - query.set_error(Some(e.to_string())); - query.set_execution_time(elapsed); - } - } - } - - let _ = _event_tx.send(AppEvent::FlightSQLQueryResult(query)); - }); - } + // KeyCode::Enter => { + // info!("Run FS query"); + // let sql = app.state.flightsql_tab.editor().lines().join(""); + // info!("SQL: {}", sql); + // let execution = Arc::clone(&app.execution); + // let _event_tx = app.event_tx(); + // tokio::spawn(async move { + // let client = execution.flightsql_client(); + // let mut query = + // FlightSQLQuery::new(sql.clone(), None, None, None, Duration::default(), None); + // let start = Instant::now(); + // if let Some(ref mut c) = *client.lock().await { + // info!("Sending query"); + // match c.execute(sql, None).await { + // Ok(flight_info) => { + // for endpoint in flight_info.endpoint { + // if let Some(ticket) = endpoint.ticket { + // match c.do_get(ticket.into_request()).await { + // Ok(mut stream) => { + // let mut batches: Vec = Vec::new(); + // // temporarily only show the first batch to avoid + // // buffering massive result sets. Eventually there should + // // be some sort of paging logic + // // see https://github.com/datafusion-contrib/datafusion-tui/pull/133#discussion_r1756680874 + // // while let Some(maybe_batch) = stream.next().await { + // if let Some(maybe_batch) = stream.next().await { + // match maybe_batch { + // Ok(batch) => { + // info!("Batch rows: {}", batch.num_rows()); + // batches.push(batch); + // } + // Err(e) => { + // error!("Error getting batch: {:?}", e); + // let elapsed = start.elapsed(); + // query.set_error(Some(e.to_string())); + // query.set_execution_time(elapsed); + // } + // } + // } + // let elapsed = start.elapsed(); + // let rows: usize = + // batches.iter().map(|r| r.num_rows()).sum(); + // query.set_results(Some(batches)); + // query.set_num_rows(Some(rows)); + // query.set_execution_time(elapsed); + // } + // Err(e) => { + // error!("Error getting response: {:?}", e); + // let elapsed = start.elapsed(); + // query.set_error(Some(e.to_string())); + // query.set_execution_time(elapsed); + // } + // } + // } + // } + // } + // Err(e) => { + // error!("Error getting response: {:?}", e); + // let elapsed = start.elapsed(); + // query.set_error(Some(e.to_string())); + // query.set_execution_time(elapsed); + // } + // } + // } + // + // let _ = _event_tx.send(AppEvent::FlightSQLQueryResult(query)); + // }); + // } _ => {} } } diff --git a/src/app/handlers/mod.rs b/src/app/handlers/mod.rs index 62e825b3..c77bba68 100644 --- a/src/app/handlers/mod.rs +++ b/src/app/handlers/mod.rs @@ -25,6 +25,7 @@ use ratatui::crossterm::event::{self, KeyCode, KeyEvent}; use tui_logger::TuiWidgetEvent; use crate::app::state::tabs::history::Context; +use crate::app::ExecutionResultsBatch; #[cfg(feature = "flightsql")] use arrow_flight::sql::client::FlightSqlServiceClient; @@ -148,8 +149,6 @@ fn context_tab_app_event_handler(app: &mut App, event: AppEvent) { } pub fn app_event_handler(app: &mut App, event: AppEvent) -> Result<()> { - // TODO: AppEvent::QueryResult can probably be handled here rather than duplicating in - // each tab trace!("Tui::Event: {:?}", event); let now = std::time::Instant::now(); match event { @@ -180,17 +179,35 @@ pub fn app_event_handler(app: &mut App, event: AppEvent) -> Result<()> { } }); } - AppEvent::QueryResult(r) => { - app.state.sql_tab.set_query(r.clone()); - app.state.sql_tab.refresh_query_results_state(); + AppEvent::NewExecution => { + app.state.sql_tab.reset_execution_results(); + } + AppEvent::ExecutionResultsError(e) => { + app.state.sql_tab.set_execution_error(e.clone()); let history_query = HistoryQuery::new( Context::Local, - r.sql().clone(), - *r.execution_time(), - r.execution_stats().clone(), + e.query().to_string(), + *e.duration(), + None, + Some(e.error().to_string()), ); + info!("Adding to history: {:?}", history_query); app.state.history_tab.add_to_history(history_query); - app.state.history_tab.refresh_history_table_state() + app.state.history_tab.refresh_history_table_state(); + } + AppEvent::ExecutionResultsNextPage(r) => { + let ExecutionResultsBatch { + query, + duration, + batch, + } = r; + app.state.sql_tab.add_batch(batch); + app.state.sql_tab.next_page(); + app.state.sql_tab.refresh_query_results_state(); + let history_query = + HistoryQuery::new(Context::Local, query.to_string(), duration, None, None); + app.state.history_tab.add_to_history(history_query); + app.state.history_tab.refresh_history_table_state(); } #[cfg(feature = "flightsql")] AppEvent::FlightSQLQueryResult(r) => { @@ -201,33 +218,34 @@ pub fn app_event_handler(app: &mut App, event: AppEvent) -> Result<()> { r.sql().clone(), *r.execution_time(), r.execution_stats().clone(), + None, ); app.state.history_tab.add_to_history(history_query); app.state.history_tab.refresh_history_table_state() } - #[cfg(feature = "flightsql")] - AppEvent::EstablishFlightSQLConnection => { - let url = app.state.config.flightsql.connection_url.clone(); - info!("Connection to FlightSQL host: {}", url); - let url: &'static str = Box::leak(url.into_boxed_str()); - let execution = Arc::clone(&app.execution); - tokio::spawn(async move { - let client = execution.flightsql_client(); - let maybe_channel = Channel::from_static(url).connect().await; - info!("Created channel"); - match maybe_channel { - Ok(channel) => { - let flightsql_client = FlightSqlServiceClient::new(channel); - let mut locked_client = client.lock().await; - *locked_client = Some(flightsql_client); - info!("Connected to FlightSQL host"); - } - Err(e) => { - info!("Error creating channel for FlightSQL: {:?}", e); - } - } - }); - } + // #[cfg(feature = "flightsql")] + // AppEvent::EstablishFlightSQLConnection => { + // let url = app.state.config.flightsql.connection_url.clone(); + // info!("Connection to FlightSQL host: {}", url); + // let url: &'static str = Box::leak(url.into_boxed_str()); + // let execution = Arc::clone(&app.execution); + // tokio::spawn(async move { + // let client = execution.flightsql_client(); + // let maybe_channel = Channel::from_static(url).connect().await; + // info!("Created channel"); + // match maybe_channel { + // Ok(channel) => { + // let flightsql_client = FlightSqlServiceClient::new(channel); + // let mut locked_client = client.lock().await; + // *locked_client = Some(flightsql_client); + // info!("Connected to FlightSQL host"); + // } + // Err(e) => { + // info!("Error creating channel for FlightSQL: {:?}", e); + // } + // } + // }); + // } _ => { match app.state.tabs.selected { SelectedTab::SQL => sql::app_event_handler(app, event), diff --git a/src/app/handlers/sql.rs b/src/app/handlers/sql.rs index 4e6fa5f7..4812dfd4 100644 --- a/src/app/handlers/sql.rs +++ b/src/app/handlers/sql.rs @@ -15,17 +15,13 @@ // specific language governing permissions and limitations // under the License. -use std::{sync::Arc, time::Instant}; +use std::sync::Arc; -use datafusion::{arrow::array::RecordBatch, physical_plan::execute_stream}; -use log::{error, info}; +use log::info; use ratatui::crossterm::event::{KeyCode, KeyEvent, KeyModifiers}; -use tokio_stream::StreamExt; use super::App; -use crate::app::app_execution::AppExecution; -use crate::app::{handlers::tab_navigation_handler, state::tabs::sql::Query, AppEvent}; -use crate::execution::collect_plan_stats; +use crate::app::{handlers::tab_navigation_handler, AppEvent}; pub fn normal_mode_handler(app: &mut App, key: KeyEvent) { match key.code { @@ -62,19 +58,43 @@ pub fn normal_mode_handler(app: &mut App, key: KeyEvent) { } KeyCode::Enter => { - info!("Run query"); let sql = app.state.sql_tab.editor().lines().join(""); - info!("SQL: {}", sql); - let app_execution = AppExecution::new(Arc::clone(&app.execution)); + info!("Running query: {}", sql); let _event_tx = app.event_tx().clone(); - // TODO: Maybe this should be on a separate runtime to prevent blocking main thread / - // runtime - // TODO: Extract this into function to be used in both normal and editable handler + let execution = Arc::clone(&app.execution); + // TODO: Extract this into function to be used in both normal and editable handler. + // Only useful if we get Ctrl / Cmd + Enter to work in editable mode though. tokio::spawn(async move { let sqls: Vec<&str> = sql.split(';').collect(); - let _ = app_execution.run_sqls(sqls, _event_tx).await; + let _ = execution.run_sqls(sqls, _event_tx).await; }); } + KeyCode::Right => { + let _event_tx = app.event_tx().clone(); + if let (Some(p), c) = ( + app.state().sql_tab.results_page(), + app.state().sql_tab.batches_count(), + ) { + // We don't need to fetch the next batch if moving forward a page and we're not + // on the last page since we would have already fetched it. + if p < c - 1 { + app.state.sql_tab.next_page(); + app.state.sql_tab.refresh_query_results_state(); + return; + } + } + if let Some(p) = app.state.history_tab.history().last() { + let execution = Arc::clone(&app.execution); + let sql = p.sql().clone(); + tokio::spawn(async move { + execution.next_batch(sql, _event_tx).await; + }); + } + } + KeyCode::Left => { + app.state.sql_tab.previous_page(); + app.state.sql_tab.refresh_query_results_state(); + } _ => {} } } @@ -85,54 +105,6 @@ pub fn editable_handler(app: &mut App, key: KeyEvent) { (KeyCode::Right, KeyModifiers::ALT) => app.state.sql_tab.next_word(), (KeyCode::Backspace, KeyModifiers::ALT) => app.state.sql_tab.delete_word(), (KeyCode::Esc, _) => app.state.sql_tab.exit_edit(), - (KeyCode::Enter, KeyModifiers::CONTROL) => { - let query = app.state.sql_tab.editor().lines().join(""); - let ctx = app.execution.session_ctx().clone(); - let _event_tx = app.event_tx(); - // TODO: Maybe this should be on a separate runtime to prevent blocking main thread / - // runtime - tokio::spawn(async move { - // TODO: Turn this into a match and return the error somehow - let start = Instant::now(); - if let Ok(df) = ctx.sql(&query).await { - let plan = df.create_physical_plan().await; - match plan { - Ok(p) => { - let task_ctx = ctx.task_ctx(); - let stream = execute_stream(Arc::clone(&p), task_ctx); - let mut batches: Vec = Vec::new(); - match stream { - Ok(mut s) => { - while let Some(b) = s.next().await { - match b { - Ok(b) => batches.push(b), - Err(e) => { - error!("Error getting RecordBatch: {:?}", e) - } - } - } - - let elapsed = start.elapsed(); - let stats = collect_plan_stats(p); - info!("Got stats: {:?}", stats); - let query = - Query::new(query, Some(batches), None, None, elapsed, None); - let _ = _event_tx.send(AppEvent::QueryResult(query)); - } - Err(e) => { - error!("Error creating RecordBatchStream: {:?}", e) - } - } - } - Err(e) => { - error!("Error creating physical plan: {:?}", e) - } - } - } else { - error!("Error creating dataframe") - } - }); - } _ => app.state.sql_tab.update_editor_content(key), } } diff --git a/src/app/mod.rs b/src/app/mod.rs index 4030a671..a95ae85d 100644 --- a/src/app/mod.rs +++ b/src/app/mod.rs @@ -23,6 +23,7 @@ pub mod ui; use color_eyre::eyre::eyre; use color_eyre::Result; use crossterm::event as ct; +use datafusion::arrow::array::RecordBatch; use futures::FutureExt; use log::{debug, error, info, trace}; use ratatui::backend::CrosstermBackend; @@ -32,12 +33,14 @@ use ratatui::crossterm::{ }; use ratatui::{prelude::*, style::palette::tailwind, widgets::*}; use std::sync::Arc; +use std::time::Duration; use strum::IntoEnumIterator; use tokio::sync::mpsc::{self, UnboundedReceiver, UnboundedSender}; use tokio::task::JoinHandle; use tokio_stream::StreamExt; use tokio_util::sync::CancellationToken; +use self::app_execution::AppExecution; use self::handlers::{app_event_handler, crossterm_event_handler}; use self::state::tabs::sql::Query; use crate::execution::ExecutionContext; @@ -46,6 +49,64 @@ use crate::execution::ExecutionContext; use self::state::tabs::flightsql::FlightSQLQuery; #[derive(Clone, Debug)] +pub struct ExecutionError { + query: String, + error: String, + duration: Duration, +} + +#[derive(Clone, Debug)] +pub struct ExecutionResultsBatch { + query: String, + batch: RecordBatch, + duration: Duration, +} + +impl ExecutionResultsBatch { + pub fn new(query: String, batch: RecordBatch, duration: Duration) -> Self { + Self { + query, + batch, + duration, + } + } + + pub fn query(&self) -> &str { + &self.query + } + + pub fn batch(&self) -> &RecordBatch { + &self.batch + } + + pub fn duration(&self) -> &Duration { + &self.duration + } +} + +impl ExecutionError { + pub fn new(query: String, error: String, duration: Duration) -> Self { + Self { + query, + error, + duration, + } + } + + pub fn query(&self) -> &str { + &self.query + } + + pub fn error(&self) -> &str { + &self.error + } + + pub fn duration(&self) -> &Duration { + &self.duration + } +} + +#[derive(Debug)] pub enum AppEvent { Key(event::KeyEvent), Error, @@ -59,7 +120,11 @@ pub enum AppEvent { Mouse(event::MouseEvent), Resize(u16, u16), ExecuteDDL(String), + NewExecution, QueryResult(Query), + ExecutionResultsNextPage(ExecutionResultsBatch), + ExecutionResultsPreviousPage, + ExecutionResultsError(ExecutionError), #[cfg(feature = "flightsql")] EstablishFlightSQLConnection, #[cfg(feature = "flightsql")] @@ -68,7 +133,7 @@ pub enum AppEvent { pub struct App<'app> { state: state::AppState<'app>, - execution: Arc, + execution: Arc, event_tx: UnboundedSender, event_rx: UnboundedReceiver, cancellation_token: CancellationToken, @@ -80,6 +145,7 @@ impl<'app> App<'app> { let (event_tx, event_rx) = mpsc::unbounded_channel(); let cancellation_token = CancellationToken::new(); let task = tokio::spawn(async {}); + let app_execution = Arc::new(AppExecution::new(Arc::new(execution))); Self { state, @@ -87,7 +153,7 @@ impl<'app> App<'app> { event_rx, event_tx, cancellation_token, - execution: Arc::new(execution), + execution: app_execution, } } @@ -99,7 +165,7 @@ impl<'app> App<'app> { &mut self.event_rx } - pub fn execution(&self) -> Arc { + pub fn execution(&self) -> Arc { Arc::clone(&self.execution) } @@ -318,7 +384,7 @@ impl App<'_> { loop { let event = app.next().await?; - if let AppEvent::Render = event.clone() { + if let AppEvent::Render = &event { terminal.draw(|f| f.render_widget(&app, f.area()))?; }; diff --git a/src/app/state/tabs/history.rs b/src/app/state/tabs/history.rs index 76461206..7ed6d77e 100644 --- a/src/app/state/tabs/history.rs +++ b/src/app/state/tabs/history.rs @@ -43,6 +43,7 @@ pub struct HistoryQuery { sql: String, execution_time: Duration, execution_stats: Option, + _error: Option, } impl HistoryQuery { @@ -51,12 +52,14 @@ impl HistoryQuery { sql: String, execution_time: Duration, execution_stats: Option, + _error: Option, ) -> Self { Self { context, sql, execution_time, execution_stats, + _error, } } pub fn sql(&self) -> &String { @@ -71,13 +74,13 @@ impl HistoryQuery { &self.execution_stats } - pub fn scanned_bytes(&self) -> usize { - if let Some(stats) = &self.execution_stats { - stats.bytes_scanned() - } else { - 0 - } - } + // pub fn scanned_bytes(&self) -> usize { + // if let Some(stats) = &self.execution_stats { + // stats.bytes_scanned() + // } else { + // 0 + // } + // } pub fn context(&self) -> &Context { &self.context diff --git a/src/app/state/tabs/sql.rs b/src/app/state/tabs/sql.rs index 8b4c35c4..8155cc7e 100644 --- a/src/app/state/tabs/sql.rs +++ b/src/app/state/tabs/sql.rs @@ -26,6 +26,7 @@ use ratatui::style::{Modifier, Style}; use ratatui::widgets::TableState; use tui_textarea::TextArea; +use crate::app::ExecutionError; use crate::config::AppConfig; use crate::execution::ExecutionStats; @@ -130,6 +131,9 @@ pub struct SQLTabState<'app> { editor_editable: bool, query: Option, query_results_state: Option>, + result_batches: Option>, + results_page: Option, + execution_error: Option, } impl<'app> SQLTabState<'app> { @@ -147,6 +151,9 @@ impl<'app> SQLTabState<'app> { editor_editable: false, query: None, query_results_state: None, + result_batches: None, + results_page: None, + execution_error: None, } } @@ -158,6 +165,13 @@ impl<'app> SQLTabState<'app> { self.query_results_state = Some(RefCell::new(TableState::default())); } + pub fn reset_execution_results(&mut self) { + self.result_batches = None; + self.results_page = None; + self.execution_error = None; + self.refresh_query_results_state(); + } + pub fn editor(&self) -> TextArea { // TODO: Figure out how to do this without clone. Probably need logic in handler to make // updates to the Widget and then pass a ref @@ -223,4 +237,55 @@ impl<'app> SQLTabState<'app> { pub fn delete_word(&mut self) { self.editor.delete_word(); } + + pub fn add_batch(&mut self, batch: RecordBatch) { + if let Some(batches) = self.result_batches.as_mut() { + batches.push(batch); + } else { + self.result_batches = Some(vec![batch]); + } + } + + pub fn current_batch(&self) -> Option<&RecordBatch> { + match (self.results_page, self.result_batches.as_ref()) { + (Some(page), Some(batches)) => batches.get(page), + _ => None, + } + } + + pub fn batches_count(&self) -> usize { + if let Some(batches) = &self.result_batches { + batches.len() + } else { + 0 + } + } + + pub fn execution_error(&self) -> &Option { + &self.execution_error + } + + pub fn set_execution_error(&mut self, error: ExecutionError) { + self.execution_error = Some(error); + } + + pub fn results_page(&self) -> Option { + self.results_page + } + + pub fn next_page(&mut self) { + if let Some(page) = self.results_page { + self.results_page = Some(page + 1); + } else { + self.results_page = Some(0); + } + } + + pub fn previous_page(&mut self) { + if let Some(page) = self.results_page { + if page > 0 { + self.results_page = Some(page - 1); + } + } + } } diff --git a/src/app/ui/convert.rs b/src/app/ui/convert.rs index 1d129df3..d3a01a28 100644 --- a/src/app/ui/convert.rs +++ b/src/app/ui/convert.rs @@ -166,7 +166,7 @@ pub fn empty_results_table<'frame>() -> Table<'frame> { } pub fn record_batches_to_table<'frame, 'results>( - record_batches: &'results [RecordBatch], + record_batches: &'results [&RecordBatch], ) -> Result> where // The results come from sql_tab state which persists until the next query is run which is diff --git a/src/app/ui/tabs/flightsql.rs b/src/app/ui/tabs/flightsql.rs index daf9ba58..b8bdc9a2 100644 --- a/src/app/ui/tabs/flightsql.rs +++ b/src/app/ui/tabs/flightsql.rs @@ -54,25 +54,25 @@ pub fn render_sql_results(area: Rect, buf: &mut Buffer, app: &App) { )) .fg(tailwind::WHITE); let block = block.title_bottom(stats).fg(tailwind::ORANGE.c500); - let maybe_table = record_batches_to_table(r); - match maybe_table { - Ok(table) => { - let table = table - .highlight_style( - Style::default().bg(tailwind::WHITE).fg(tailwind::BLACK), - ) - .block(block); - - let mut s = s.borrow_mut(); - StatefulWidget::render(table, area, buf, &mut s); - } - Err(e) => { - let row = Row::new(vec![e.to_string()]); - let widths = vec![Constraint::Percentage(100)]; - let table = Table::new(vec![row], widths).block(block); - Widget::render(table, area, buf); - } - } + // let maybe_table = record_batches_to_table(r); + // match maybe_table { + // Ok(table) => { + // let table = table + // .highlight_style( + // Style::default().bg(tailwind::WHITE).fg(tailwind::BLACK), + // ) + // .block(block); + // + // let mut s = s.borrow_mut(); + // StatefulWidget::render(table, area, buf, &mut s); + // } + // Err(e) => { + // let row = Row::new(vec![e.to_string()]); + // let widths = vec![Constraint::Percentage(100)]; + // let table = Table::new(vec![row], widths).block(block); + // Widget::render(table, area, buf); + // } + // } } } else if let Some(e) = q.error() { let row = Row::new(vec![e.to_string()]); diff --git a/src/app/ui/tabs/history.rs b/src/app/ui/tabs/history.rs index 3f0988b5..a9ee3aa9 100644 --- a/src/app/ui/tabs/history.rs +++ b/src/app/ui/tabs/history.rs @@ -62,7 +62,9 @@ pub fn render_query_history(area: Rect, buf: &mut Buffer, app: &App) { .title(" Query History ") .borders(Borders::ALL); let history = app.state.history_tab.history(); + info!("History: {:?}", history); let history_table_state = app.state.history_tab.history_table_state(); + info!("History Table State: {:?}", history_table_state); match (history.is_empty(), history_table_state) { (true, _) | (_, None) => { let row = Row::new(vec!["Your query history will show here"]); @@ -85,7 +87,14 @@ pub fn render_query_history(area: Rect, buf: &mut Buffer, app: &App) { Cell::from(q.context().as_str()), Cell::from(q.sql().as_str()), Cell::from(q.execution_time().as_millis().to_string()), - Cell::from(q.scanned_bytes().to_string()), + // Not sure showing scanned_bytes is useful anymore in the context of + // paginated queries. Hard coding to zero for now but this will need to be + // revisted. One option I have is removing these type of stats from the + // query history table (so we only show execution time) and then + // _anything_ ExecutionPlan statistics related is shown in the lower pane + // and their is a `analyze` mode that runs the query to completion and + // collects all stats to show in a table next to the query. + Cell::from(0.to_string()), ]) }) .collect(); @@ -96,8 +105,9 @@ pub fn render_query_history(area: Rect, buf: &mut Buffer, app: &App) { Cell::from("Execution Time(ms)"), Cell::from("Scanned Bytes"), ]) - .bg(tailwind::WHITE) - .fg(tailwind::BLACK); + .bg(tailwind::ORANGE.c300) + .fg(tailwind::BLACK) + .bold(); let table = Table::new(rows, widths).header(header).block(block.clone()); let table = table diff --git a/src/app/ui/tabs/sql.rs b/src/app/ui/tabs/sql.rs index 16ee0cb8..10f656de 100644 --- a/src/app/ui/tabs/sql.rs +++ b/src/app/ui/tabs/sql.rs @@ -20,7 +20,7 @@ use ratatui::{ layout::{Alignment, Constraint, Direction, Layout, Rect}, style::{palette::tailwind, Style, Stylize}, text::Span, - widgets::{Block, Borders, Paragraph, Row, StatefulWidget, Table, Widget}, + widgets::{block::Title, Block, Borders, Paragraph, Row, StatefulWidget, Table, Widget}, }; use crate::app::ui::convert::record_batches_to_table; @@ -44,48 +44,62 @@ pub fn render_sql_editor(area: Rect, buf: &mut Buffer, app: &App) { } pub fn render_sql_results(area: Rect, buf: &mut Buffer, app: &App) { - let block = Block::default().title(" Results ").borders(Borders::ALL); - if let Some(q) = app.state.sql_tab.query() { - if let Some(r) = q.results() { - if let Some(s) = app.state.sql_tab.query_results_state() { - let stats = Span::from(format!( - " {} rows in {}ms ", - q.num_rows().unwrap_or(0), - q.execution_time().as_millis() - )) - .fg(tailwind::WHITE); - let block = block.title_bottom(stats).fg(tailwind::ORANGE.c500); - let maybe_table = record_batches_to_table(r); - match maybe_table { - Ok(table) => { - let table = table - .highlight_style( - Style::default().bg(tailwind::WHITE).fg(tailwind::BLACK), - ) - .block(block); + // TODO: Change this to a match on state and batch + let sql_tab = &app.state.sql_tab; + match ( + sql_tab.current_batch(), + sql_tab.results_page(), + sql_tab.query_results_state(), + sql_tab.execution_error(), + ) { + (Some(batch), Some(p), Some(s), None) => { + let block = Block::default() + .title(" Results ") + .borders(Borders::ALL) + .title(Title::from(format!(" Page {p} ")).alignment(Alignment::Right)); + let batches = vec![batch]; + let maybe_table = record_batches_to_table(&batches); - let mut s = s.borrow_mut(); - StatefulWidget::render(table, area, buf, &mut s); - } - Err(e) => { - let row = Row::new(vec![e.to_string()]); - let widths = vec![Constraint::Percentage(100)]; - let table = Table::new(vec![row], widths).block(block); - Widget::render(table, area, buf); - } + let block = block.title_bottom("Stats").fg(tailwind::ORANGE.c500); + match maybe_table { + Ok(table) => { + let table = table + .highlight_style(Style::default().bg(tailwind::WHITE).fg(tailwind::BLACK)) + .block(block); + + let mut s = s.borrow_mut(); + StatefulWidget::render(table, area, buf, &mut s); + } + Err(e) => { + let row = Row::new(vec![e.to_string()]); + let widths = vec![Constraint::Percentage(100)]; + let table = Table::new(vec![row], widths).block(block); + Widget::render(table, area, buf); } } - } else if let Some(e) = q.error() { - let row = Row::new(vec![e.to_string()]); + } + (_, _, _, Some(e)) => { + let dur = e.duration().as_millis(); + let block = Block::default() + .title(" Results ") + .borders(Borders::ALL) + .title(Title::from(" Page ").alignment(Alignment::Right)) + .title_bottom(format!(" {}ms ", dur)); + let row = Row::new(vec![e.error().to_string()]); + let widths = vec![Constraint::Percentage(100)]; + let table = Table::new(vec![row], widths).block(block); + Widget::render(table, area, buf); + } + _ => { + let block = Block::default() + .title(" Results ") + .borders(Borders::ALL) + .title(Title::from(" Page ").alignment(Alignment::Right)); + let row = Row::new(vec!["Run a query to generate results"]); let widths = vec![Constraint::Percentage(100)]; let table = Table::new(vec![row], widths).block(block); Widget::render(table, area, buf); } - } else { - let row = Row::new(vec!["Run a query to generate results"]); - let widths = vec![Constraint::Percentage(100)]; - let table = Table::new(vec![row], widths).block(block); - Widget::render(table, area, buf); } } diff --git a/src/execution/mod.rs b/src/execution/mod.rs index d0c6b5ea..9db64772 100644 --- a/src/execution/mod.rs +++ b/src/execution/mod.rs @@ -19,10 +19,13 @@ //! mod stats; +use std::sync::Arc; + pub use stats::{collect_plan_stats, ExecutionStats}; use color_eyre::eyre::Result; use datafusion::execution::SendableRecordBatchStream; +use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::*; use datafusion::sql::parser::Statement; use tokio_stream::StreamExt; @@ -54,6 +57,12 @@ pub struct ExecutionContext { flightsql_client: Mutex>>, } +impl std::fmt::Debug for ExecutionContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ExecutionContext").finish() + } +} + impl ExecutionContext { /// Construct a new `ExecutionContext` with the specified configuration pub fn try_new(config: &ExecutionConfig) -> Result { @@ -107,6 +116,16 @@ impl ExecutionContext { Ok(()) } + /// Create a physical plan from the specified SQL string. This is useful if you want to store + /// the plan and collect metrics from it. + pub async fn create_physical_plan( + &self, + sql: &str, + ) -> datafusion::error::Result> { + let df = self.session_ctx.sql(sql).await?; + df.create_physical_plan().await + } + /// Executes the specified sql string, returning the resulting /// [`SendableRecordBatchStream`] of results pub async fn execute_sql( diff --git a/src/extensions/builder.rs b/src/extensions/builder.rs index e3724cb5..095e60e0 100644 --- a/src/extensions/builder.rs +++ b/src/extensions/builder.rs @@ -75,7 +75,7 @@ impl DftSessionStateBuilder { pub fn new() -> Self { let session_config = SessionConfig::default() // TODO why is batch size 1? - .with_batch_size(1) + .with_batch_size(100) .with_information_schema(true); Self { diff --git a/tests/tui.rs b/tests/tui.rs index 5b9b7ced..b66ecb11 100644 --- a/tests/tui.rs +++ b/tests/tui.rs @@ -14,78 +14,15 @@ // KIND, either express or implied. See the License for the // specific language governing permissions and limitations // under the License. +// -//! Tests for the TUI (e.g. user application with keyboard commands) - -use dft::app::state::initialize; -use dft::app::{App, AppEvent}; -use dft::execution::ExecutionContext; -use ratatui::crossterm::event; +use dft::{ + app::{state::initialize, App, AppEvent}, + execution::ExecutionContext, +}; use tempfile::{tempdir, TempDir}; -#[tokio::test] -async fn construct_with_no_args() { - let _test_app = TestApp::new(); -} - -#[tokio::test] -async fn quit_app_from_sql_tab() { - let mut test_app = TestApp::new(); - // SQL Tab - let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); - let app_event = AppEvent::Key(key); - test_app.handle_app_event(app_event).unwrap(); - // Ideally, we figure out a way to check that the app actually quits - assert!(test_app.state().should_quit); -} - -#[tokio::test] -async fn quit_app_from_flightsql_tab() { - let mut test_app = TestApp::new(); - let flightsql_key = event::KeyEvent::new(event::KeyCode::Char('2'), event::KeyModifiers::NONE); - let app_event = AppEvent::Key(flightsql_key); - test_app.handle_app_event(app_event).unwrap(); - let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); - let app_event = AppEvent::Key(key); - test_app.handle_app_event(app_event).unwrap(); - assert!(test_app.state().should_quit); -} - -#[tokio::test] -async fn quit_app_from_history_tab() { - let mut test_app = TestApp::new(); - let history_key = event::KeyEvent::new(event::KeyCode::Char('3'), event::KeyModifiers::NONE); - let app_event = AppEvent::Key(history_key); - test_app.handle_app_event(app_event).unwrap(); - let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); - let app_event = AppEvent::Key(key); - test_app.handle_app_event(app_event).unwrap(); - assert!(test_app.state().should_quit); -} - -#[tokio::test] -async fn quit_app_from_logs_tab() { - let mut test_app = TestApp::new(); - let logs_key = event::KeyEvent::new(event::KeyCode::Char('4'), event::KeyModifiers::NONE); - let app_event = AppEvent::Key(logs_key); - test_app.handle_app_event(app_event).unwrap(); - let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); - let app_event = AppEvent::Key(key); - test_app.handle_app_event(app_event).unwrap(); - assert!(test_app.state().should_quit); -} - -#[tokio::test] -async fn quit_app_from_context_tab() { - let mut test_app = TestApp::new(); - let context_key = event::KeyEvent::new(event::KeyCode::Char('5'), event::KeyModifiers::NONE); - let app_event = AppEvent::Key(context_key); - test_app.handle_app_event(app_event).unwrap(); - let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); - let app_event = AppEvent::Key(key); - test_app.handle_app_event(app_event).unwrap(); - assert!(test_app.state().should_quit); -} +mod tui_cases; /// Fixture with an [`App`] instance and other temporary state struct TestApp<'app> { diff --git a/tests/tui_cases/mod.rs b/tests/tui_cases/mod.rs new file mode 100644 index 00000000..f38f8676 --- /dev/null +++ b/tests/tui_cases/mod.rs @@ -0,0 +1,20 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// + +mod pagination; +mod quit; diff --git a/tests/tui_cases/pagination.rs b/tests/tui_cases/pagination.rs new file mode 100644 index 00000000..0ff85410 --- /dev/null +++ b/tests/tui_cases/pagination.rs @@ -0,0 +1,211 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for the TUI (e.g. user application with keyboard commands) + +use std::sync::Arc; +use std::time::Duration; + +use datafusion::arrow::array::{ArrayRef, RecordBatch, UInt32Array}; +use datafusion::assert_batches_eq; +use dft::app::{AppEvent, ExecutionResultsBatch}; + +use crate::TestApp; + +fn create_batch(adj: u32) -> RecordBatch { + let arr1: ArrayRef = Arc::new(UInt32Array::from(vec![1 + adj, 2 + adj])); + let arr2: ArrayRef = Arc::new(UInt32Array::from(vec![3 + adj, 4 + adj])); + + RecordBatch::try_from_iter(vec![("a", arr1), ("b", arr2)]).unwrap() +} + +fn create_execution_results(query: &str, adj: u32) -> ExecutionResultsBatch { + let duration = Duration::from_secs(1); + let batch = create_batch(adj); + ExecutionResultsBatch::new(query.to_string(), batch, duration) +} + +// Tests that a single page of results is displayed correctly +#[tokio::test] +async fn single_page() { + let mut test_app = TestApp::new(); + let res1 = create_execution_results("SELECT 1", 0); + let event1 = AppEvent::ExecutionResultsNextPage(res1); + + test_app.handle_app_event(AppEvent::NewExecution).unwrap(); + test_app.handle_app_event(event1).unwrap(); + + let state = test_app.state(); + + let page = state.sql_tab.results_page().unwrap(); + assert_eq!(page, 0); + + let batch = state.sql_tab.current_batch(); + assert!(batch.is_some()); + + let batch = batch.unwrap(); + let batches = vec![batch.clone()]; + let expected = [ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 3 |", + "| 2 | 4 |", + "+---+---+", + ]; + assert_batches_eq!(expected, &batches); + let table_state = state.sql_tab.query_results_state(); + assert!(table_state.is_some()); + let table_state = table_state.as_ref().unwrap(); + assert_eq!(table_state.borrow().selected(), None); +} + +// Tests that we can paginate through multiple pages and go back to the first page +#[tokio::test] +async fn multiple_pages_forward_and_back() { + let mut test_app = TestApp::new(); + let res1 = create_execution_results("SELECT 1", 0); + let event1 = AppEvent::ExecutionResultsNextPage(res1); + + test_app.handle_app_event(AppEvent::NewExecution).unwrap(); + test_app.handle_app_event(event1).unwrap(); + + { + let state = test_app.state(); + let page = state.sql_tab.results_page().unwrap(); + assert_eq!(page, 0); + } + + let res2 = create_execution_results("SELECT 1", 1); + let event2 = AppEvent::ExecutionResultsNextPage(res2); + test_app.handle_app_event(event2).unwrap(); + + { + let state = test_app.state(); + let page = state.sql_tab.results_page().unwrap(); + assert_eq!(page, 1); + } + + { + let state = test_app.state(); + let batch = state.sql_tab.current_batch(); + assert!(batch.is_some()); + + let batch = batch.unwrap(); + let batches = vec![batch.clone()]; + let expected = [ + "+---+---+", + "| a | b |", + "+---+---+", + "| 2 | 4 |", + "| 3 | 5 |", + "+---+---+", + ]; + assert_batches_eq!(expected, &batches); + } + + let left_key = crossterm::event::KeyEvent::new( + crossterm::event::KeyCode::Left, + crossterm::event::KeyModifiers::NONE, + ); + let event3 = AppEvent::Key(left_key); + test_app.handle_app_event(event3).unwrap(); + + { + let state = test_app.state(); + let page = state.sql_tab.results_page().unwrap(); + assert_eq!(page, 0); + } + + { + let state = test_app.state(); + let batch = state.sql_tab.current_batch(); + assert!(batch.is_some()); + + let batch = batch.unwrap(); + let batches = vec![batch.clone()]; + let expected = [ + "+---+---+", + "| a | b |", + "+---+---+", + "| 1 | 3 |", + "| 2 | 4 |", + "+---+---+", + ]; + assert_batches_eq!(expected, &batches); + } +} + +// Tests that we can still paginate when we already have the batch because we previously viewed the +// page +#[tokio::test] +async fn multiple_pages_forward_and_back_and_forward() { + let mut test_app = TestApp::new(); + let res1 = create_execution_results("SELECT 1", 0); + let event1 = AppEvent::ExecutionResultsNextPage(res1); + + test_app.handle_app_event(AppEvent::NewExecution).unwrap(); + test_app.handle_app_event(event1).unwrap(); + + { + let state = test_app.state(); + let page = state.sql_tab.results_page().unwrap(); + assert_eq!(page, 0); + } + + let res2 = create_execution_results("SELECT 1", 1); + let event2 = AppEvent::ExecutionResultsNextPage(res2); + test_app.handle_app_event(event2).unwrap(); + + let left_key = crossterm::event::KeyEvent::new( + crossterm::event::KeyCode::Left, + crossterm::event::KeyModifiers::NONE, + ); + let event3 = AppEvent::Key(left_key); + test_app.handle_app_event(event3).unwrap(); + + let right_key = crossterm::event::KeyEvent::new( + crossterm::event::KeyCode::Right, + crossterm::event::KeyModifiers::NONE, + ); + let event4 = AppEvent::Key(right_key); + test_app.handle_app_event(event4).unwrap(); + + { + let state = test_app.state(); + let page = state.sql_tab.results_page().unwrap(); + assert_eq!(page, 1); + } + + { + let state = test_app.state(); + let batch = state.sql_tab.current_batch(); + assert!(batch.is_some()); + + let batch = batch.unwrap(); + let batches = vec![batch.clone()]; + let expected = [ + "+---+---+", + "| a | b |", + "+---+---+", + "| 2 | 4 |", + "| 3 | 5 |", + "+---+---+", + ]; + assert_batches_eq!(expected, &batches); + } +} diff --git a/tests/tui_cases/quit.rs b/tests/tui_cases/quit.rs new file mode 100644 index 00000000..5b9b7ced --- /dev/null +++ b/tests/tui_cases/quit.rs @@ -0,0 +1,121 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tests for the TUI (e.g. user application with keyboard commands) + +use dft::app::state::initialize; +use dft::app::{App, AppEvent}; +use dft::execution::ExecutionContext; +use ratatui::crossterm::event; +use tempfile::{tempdir, TempDir}; + +#[tokio::test] +async fn construct_with_no_args() { + let _test_app = TestApp::new(); +} + +#[tokio::test] +async fn quit_app_from_sql_tab() { + let mut test_app = TestApp::new(); + // SQL Tab + let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); + let app_event = AppEvent::Key(key); + test_app.handle_app_event(app_event).unwrap(); + // Ideally, we figure out a way to check that the app actually quits + assert!(test_app.state().should_quit); +} + +#[tokio::test] +async fn quit_app_from_flightsql_tab() { + let mut test_app = TestApp::new(); + let flightsql_key = event::KeyEvent::new(event::KeyCode::Char('2'), event::KeyModifiers::NONE); + let app_event = AppEvent::Key(flightsql_key); + test_app.handle_app_event(app_event).unwrap(); + let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); + let app_event = AppEvent::Key(key); + test_app.handle_app_event(app_event).unwrap(); + assert!(test_app.state().should_quit); +} + +#[tokio::test] +async fn quit_app_from_history_tab() { + let mut test_app = TestApp::new(); + let history_key = event::KeyEvent::new(event::KeyCode::Char('3'), event::KeyModifiers::NONE); + let app_event = AppEvent::Key(history_key); + test_app.handle_app_event(app_event).unwrap(); + let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); + let app_event = AppEvent::Key(key); + test_app.handle_app_event(app_event).unwrap(); + assert!(test_app.state().should_quit); +} + +#[tokio::test] +async fn quit_app_from_logs_tab() { + let mut test_app = TestApp::new(); + let logs_key = event::KeyEvent::new(event::KeyCode::Char('4'), event::KeyModifiers::NONE); + let app_event = AppEvent::Key(logs_key); + test_app.handle_app_event(app_event).unwrap(); + let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); + let app_event = AppEvent::Key(key); + test_app.handle_app_event(app_event).unwrap(); + assert!(test_app.state().should_quit); +} + +#[tokio::test] +async fn quit_app_from_context_tab() { + let mut test_app = TestApp::new(); + let context_key = event::KeyEvent::new(event::KeyCode::Char('5'), event::KeyModifiers::NONE); + let app_event = AppEvent::Key(context_key); + test_app.handle_app_event(app_event).unwrap(); + let key = event::KeyEvent::new(event::KeyCode::Char('q'), event::KeyModifiers::NONE); + let app_event = AppEvent::Key(key); + test_app.handle_app_event(app_event).unwrap(); + assert!(test_app.state().should_quit); +} + +/// Fixture with an [`App`] instance and other temporary state +struct TestApp<'app> { + /// Temporary directory for configuration files + /// + /// The directory is removed when the object is dropped so this + /// field must remain alive while the app is running + #[allow(dead_code)] + config_path: TempDir, + /// The [`App`] instance + app: App<'app>, +} + +impl<'app> TestApp<'app> { + /// Create a new [`TestApp`] instance configured with a temporary directory + fn new() -> Self { + let config_path = tempdir().unwrap(); + let state = initialize(config_path.path().to_path_buf()); + let execution = ExecutionContext::try_new(&state.config.execution).unwrap(); + let app = App::new(state, execution); + Self { config_path, app } + } + + /// Call app.event_handler with the given event + pub fn handle_app_event(&mut self, event: AppEvent) -> color_eyre::Result<()> { + self.app.handle_app_event(event) + } + + /// Return the app state + pub fn state(&self) -> &dft::app::state::AppState { + self.app.state() + } +}