diff --git a/pkg-r/DESCRIPTION b/pkg-r/DESCRIPTION index 9d87b99e..ce63e3a2 100644 --- a/pkg-r/DESCRIPTION +++ b/pkg-r/DESCRIPTION @@ -17,6 +17,8 @@ Depends: Imports: bslib, DBI, + dbplyr, + dplyr, duckdb, ellmer (>= 0.3.0), htmltools, diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE index cbc584f3..8a966ed9 100644 --- a/pkg-r/NAMESPACE +++ b/pkg-r/NAMESPACE @@ -5,6 +5,7 @@ S3method(create_system_prompt,querychat_data_source) S3method(execute_query,dbi_source) S3method(get_db_type,data_frame_source) S3method(get_db_type,dbi_source) +S3method(get_lazy_data,dbi_source) S3method(get_schema,dbi_source) S3method(querychat_data_source,DBIConnection) S3method(querychat_data_source,data.frame) @@ -13,6 +14,7 @@ export(cleanup_source) export(create_system_prompt) export(execute_query) export(get_db_type) +export(get_lazy_data) export(get_schema) export(querychat_app) export(querychat_data_source) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R index 9a1282ef..1318f0a8 100644 --- a/pkg-r/R/data_source.R +++ b/pkg-r/R/data_source.R @@ -120,108 +120,64 @@ execute_query.dbi_source <- function(source, query, ...) { DBI::dbGetQuery(source$conn, query) } -#' Test a SQL query on a data source. -#' -#' @param source A querychat_data_source object -#' @param query SQL query string -#' @param ... Additional arguments passed to methods -#' @return Result of the query, limited to one row of data. -#' @export -test_query <- function(source, query, ...) { - UseMethod("test_query") -} - -#' @export -test_query.dbi_source <- function(source, query, ...) { - rs <- DBI::dbSendQuery(source$conn, query) - df <- DBI::dbFetch(rs, n = 1) - DBI::dbClearResult(rs) - df -} - -#' Get type information for a data source +#' Get a lazy representation of a data source #' #' @param source A querychat_data_source object +#' @param query SQL query string #' @param ... 
#' Get a lazy representation of a data source
#'
#' @param source A querychat_data_source object
#' @param query SQL query string. `NULL`, `NA`, an empty vector, or a blank
#'   string all mean "no query" and select the whole table.
#' @param ... Additional arguments passed to methods
#' @return A lazy representation (typically a dbplyr tbl)
#' @export
get_lazy_data <- function(source, query = NULL, ...) {
  UseMethod("get_lazy_data")
}

#' @export
get_lazy_data.dbi_source <- function(
  source,
  query = NULL,
  ...
) {
  # Normalise the input to a single string before testing it. The previous
  # `is.null(query) || query == ""` check errored on NA and on zero-length or
  # multi-element input, because `||` and `if` require one non-missing value.
  query_text <- as.character(query)
  query_text <- query_text[!is.na(query_text)]
  query_text <- paste(query_text, collapse = "\n")

  if (!nzchar(trimws(query_text))) {
    # No query: default to returning the whole table (ie SELECT *)
    return(dplyr::tbl(source$conn, source$table_name))
  }

  # Clean the SQL query to avoid dbplyr issues with syntax problems
  cleaned_query <- clean_sql(query_text, enforce_select = TRUE)

  if (is.null(cleaned_query)) {
    # Cleaning left nothing to run (e.g. the query was only comments)
    rlang::abort(c(
      "Query cleaning resulted in an empty query.",
      "i" = "Check the original query for proper syntax.",
      "i" = "Query may consist only of comments or invalid SQL."
    ))
  }

  # Wrap the cleaned text with dbplyr::sql() so tbl() treats it as a query
  # rather than a table name. There is no fallback to the full table on error;
  # errors propagate to the caller.
  dplyr::tbl(source$conn, dbplyr::sql(cleaned_query))
}
#' Test a SQL query on a data source.
#'
#' @param source A querychat_data_source object
#' @param query SQL query string
#' @param ... Additional arguments passed to methods
#' @return Result of the query, limited to one row of data.
#' @export
test_query <- function(source, query, ...) {
  UseMethod("test_query")
}

#' @export
test_query.dbi_source <- function(source, query, ...) {
  rs <- DBI::dbSendQuery(source$conn, query)
  # Release the result set even when dbFetch() fails; otherwise the open
  # result stays attached to the connection.
  on.exit(DBI::dbClearResult(rs), add = TRUE)
  # Fetch at most one row: enough to validate the query without pulling the
  # whole result.
  DBI::dbFetch(rs, n = 1)
}
Additional arguments passed to methods -#' @return A character string describing the schema -#' @export -get_schema <- function(source, ...) { - UseMethod("get_schema") -} - -#' @export -get_schema.dbi_source <- function(source, ...) { - conn <- source$conn - table_name <- source$table_name - categorical_threshold <- source$categorical_threshold - - # Get column information - columns <- DBI::dbListFields(conn, table_name) - - schema_lines <- c( - paste("Table:", DBI::dbQuoteIdentifier(conn, table_name)), - "Columns:" - ) - - # Build single query to get column statistics - select_parts <- character(0) - numeric_columns <- character(0) - text_columns <- character(0) - - # Get sample of data to determine types - sample_query <- paste0( - "SELECT * FROM ", - DBI::dbQuoteIdentifier(conn, table_name), - " LIMIT 1" - ) - sample_data <- DBI::dbGetQuery(conn, sample_query) - - for (col in columns) { - col_class <- class(sample_data[[col]])[1] - - if ( - col_class %in% - c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt") - ) { - numeric_columns <- c(numeric_columns, col) - select_parts <- c( - select_parts, - paste0( - "MIN(", - DBI::dbQuoteIdentifier(conn, col), - ") as ", - DBI::dbQuoteIdentifier(conn, paste0(col, '__min')) - ), - paste0( - "MAX(", - DBI::dbQuoteIdentifier(conn, col), - ") as ", - DBI::dbQuoteIdentifier(conn, paste0(col, '__max')) - ) - ) - } else if (col_class %in% c("character", "factor")) { - text_columns <- c(text_columns, col) - select_parts <- c( - select_parts, - paste0( - "COUNT(DISTINCT ", - DBI::dbQuoteIdentifier(conn, col), - ") as ", - DBI::dbQuoteIdentifier(conn, paste0(col, '__distinct_count')) - ) - ) - } - } - - # Execute statistics query - column_stats <- list() - if (length(select_parts) > 0) { - tryCatch( - { - stats_query <- paste0( - "SELECT ", - paste0(select_parts, collapse = ", "), - " FROM ", - DBI::dbQuoteIdentifier(conn, table_name) - ) - result <- DBI::dbGetQuery(conn, stats_query) - if (nrow(result) > 0) { - 
column_stats <- as.list(result[1, ]) - } - }, - error = function(e) { - # Fall back to no statistics if query fails - } - ) - } - - # Get categorical values for text columns below threshold - categorical_values <- list() - text_cols_to_query <- character(0) - - for (col_name in text_columns) { - distinct_count_key <- paste0(col_name, "__distinct_count") - if ( - distinct_count_key %in% - names(column_stats) && - !is.na(column_stats[[distinct_count_key]]) && - column_stats[[distinct_count_key]] <= categorical_threshold - ) { - text_cols_to_query <- c(text_cols_to_query, col_name) - } - } - - # Remove duplicates - text_cols_to_query <- unique(text_cols_to_query) - - # Get categorical values - if (length(text_cols_to_query) > 0) { - for (col_name in text_cols_to_query) { - tryCatch( - { - cat_query <- paste0( - "SELECT DISTINCT ", - DBI::dbQuoteIdentifier(conn, col_name), - " FROM ", - DBI::dbQuoteIdentifier(conn, table_name), - " WHERE ", - DBI::dbQuoteIdentifier(conn, col_name), - " IS NOT NULL ORDER BY ", - DBI::dbQuoteIdentifier(conn, col_name) - ) - result <- DBI::dbGetQuery(conn, cat_query) - if (nrow(result) > 0) { - categorical_values[[col_name]] <- result[[1]] - } - }, - error = function(e) { - # Skip categorical values if query fails - } - ) - } - } - - # Build schema description - for (col in columns) { - col_class <- class(sample_data[[col]])[1] - sql_type <- r_class_to_sql_type(col_class) - - column_info <- paste0("- ", col, " (", sql_type, ")") - - # Add range info for numeric columns - if (col %in% numeric_columns) { - min_key <- paste0(col, "__min") - max_key <- paste0(col, "__max") - if ( - min_key %in% - names(column_stats) && - max_key %in% names(column_stats) && - !is.na(column_stats[[min_key]]) && - !is.na(column_stats[[max_key]]) - ) { - range_info <- paste0( - " Range: ", - column_stats[[min_key]], - " to ", - column_stats[[max_key]] - ) - column_info <- paste(column_info, range_info, sep = "\n") - } - } - - # Add categorical values for text 
columns - if (col %in% names(categorical_values)) { - values <- categorical_values[[col]] - if (length(values) > 0) { - values_str <- paste0("'", values, "'", collapse = ", ") - cat_info <- paste0(" Categorical values: ", values_str) - column_info <- paste(column_info, cat_info, sep = "\n") - } - } - - schema_lines <- c(schema_lines, column_info) - } - - paste(schema_lines, collapse = "\n") -} - - -# Helper function to map R classes to SQL types -r_class_to_sql_type <- function(r_class) { - switch( - r_class, - "integer" = "INTEGER", - "numeric" = "FLOAT", - "double" = "FLOAT", - "logical" = "BOOLEAN", - "Date" = "DATE", - "POSIXct" = "TIMESTAMP", - "POSIXt" = "TIMESTAMP", - "character" = "TEXT", - "factor" = "TEXT", - "TEXT" # default - ) -} diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index f35e5b9e..e7e917de 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -187,6 +187,9 @@ querychat_ui <- function(id) { #' - `sql`: A reactive that returns the current SQL query. #' - `title`: A reactive that returns the current title. #' - `df`: A reactive that returns the filtered data as a data.frame. +#' - `tbl`: A reactive that returns a lazy `dbplyr::tbl()` object that supports +#' lazy evaluation and query chaining. This can be further manipulated with +#' dplyr verbs before calling `collect()` to materialize the results. #' - `chat`: The [ellmer::Chat] object that powers the chat interface. #' #' @export @@ -204,6 +207,9 @@ querychat_server <- function(id, querychat_config) { filtered_df <- shiny::reactive({ execute_query(data_source, query = DBI::SQL(current_query())) }) + filtered_tbl <- shiny::reactive({ + get_lazy_data(data_source, query = current_query()) + }) append_output <- function(...) { txt <- paste0(...) 
#' Get type information for a data source
#'
#' @param source A querychat_data_source object
#' @param ... Additional arguments passed to methods
#' @return A character string containing the type information
#' @export
get_db_type <- function(source, ...) {
  UseMethod("get_db_type")
}

#' @export
get_db_type.data_frame_source <- function(source, ...) {
  # Data-frame sources always report DuckDB as their engine.
  "DuckDB"
}

#' @export
get_db_type.dbi_source <- function(source, ...) {
  connection <- source$conn
  info <- DBI::dbGetInfo(connection)
  # Use "POSIX" when the driver does not report a dbms name.
  reported_name <- purrr::pluck(info, "dbms.name", .default = "POSIX")

  if (inherits(connection, "SQLiteConnection")) {
    "SQLite"
  } else {
    # Strip " SQL" from the name; the prompt already says SQL.
    gsub(" SQL", "", reported_name)
  }
}


#' Create a system prompt for the data source
#'
#' @param source A querychat_data_source object
#' @param data_description Optional description of the data
#' @param extra_instructions Optional additional instructions
#' @param ... Additional arguments passed to methods
#' @return A string with the system prompt
#' @export
create_system_prompt <- function(
  source,
  data_description = NULL,
  extra_instructions = NULL,
  ...
) {
  UseMethod("create_system_prompt")
}

#' @export
create_system_prompt.querychat_data_source <- function(
  source,
  data_description = NULL,
  extra_instructions = NULL,
  ...
) {
  # Join multi-element text into one newline-separated string; keep NULL as
  # NULL so it is passed to the template unchanged.
  join_lines <- function(text) {
    if (is.null(text)) {
      return(NULL)
    }
    paste(text, collapse = "\n")
  }
  data_description <- join_lines(data_description)
  extra_instructions <- join_lines(extra_instructions)

  # Load the prompt template shipped with the package.
  template_file <- system.file("prompt", "prompt.md", package = "querychat")
  template <- paste(readLines(template_file, warn = FALSE), collapse = "\n")

  # Gather the schema and database type that the template is rendered with.
  template_data <- list(
    schema = get_schema(source),
    data_description = data_description,
    extra_instructions = extra_instructions,
    db_type = get_db_type(source)
  )

  whisker::whisker.render(template, template_data)
}

#' Get schema for a data source
#'
#' @param source A querychat_data_source object
#' @param ... Additional arguments passed to methods
#' @return A character string describing the schema
#' @export
get_schema <- function(source, ...) {
  UseMethod("get_schema")
}
# Build a plain-text schema description for the table behind a DBI source.
#
# Reads `conn`, `table_name` and `categorical_threshold` from `source`, then:
#   1. lists the table's columns and fetches one sample row to learn each
#      column's R class;
#   2. runs one aggregate query for MIN/MAX of numeric and date-like columns
#      and COUNT(DISTINCT) of text columns;
#   3. fetches the distinct values of text columns whose distinct count is at
#      or below `categorical_threshold`;
#   4. returns one newline-joined string with a line per column, plus range or
#      categorical-value lines where available.
# Failures of the statistics and categorical queries are swallowed, so the
# schema is still returned without those details.
#' @export
get_schema.dbi_source <- function(source, ...) {
  conn <- source$conn
  table_name <- source$table_name
  categorical_threshold <- source$categorical_threshold

  # Get column information
  columns <- DBI::dbListFields(conn, table_name)

  schema_lines <- c(
    paste("Table:", DBI::dbQuoteIdentifier(conn, table_name)),
    "Columns:"
  )

  # Build single query to get column statistics
  select_parts <- character(0)
  numeric_columns <- character(0)
  text_columns <- character(0)

  # Get sample of data to determine types
  # NOTE(review): `LIMIT` is not accepted by every SQL dialect - confirm the
  # set of supported backends.
  sample_query <- paste0(
    "SELECT * FROM ",
    DBI::dbQuoteIdentifier(conn, table_name),
    " LIMIT 1"
  )
  sample_data <- DBI::dbGetQuery(conn, sample_query)

  # Classify each column by the first class of its sampled R vector and queue
  # the matching aggregate expressions. Result aliases are "<col>__min",
  # "<col>__max" and "<col>__distinct_count" so they can be looked up by name
  # further down.
  for (col in columns) {
    col_class <- class(sample_data[[col]])[1]

    if (
      col_class %in%
        c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt")
    ) {
      numeric_columns <- c(numeric_columns, col)
      select_parts <- c(
        select_parts,
        paste0(
          "MIN(",
          DBI::dbQuoteIdentifier(conn, col),
          ") as ",
          DBI::dbQuoteIdentifier(conn, paste0(col, '__min'))
        ),
        paste0(
          "MAX(",
          DBI::dbQuoteIdentifier(conn, col),
          ") as ",
          DBI::dbQuoteIdentifier(conn, paste0(col, '__max'))
        )
      )
    } else if (col_class %in% c("character", "factor")) {
      text_columns <- c(text_columns, col)
      select_parts <- c(
        select_parts,
        paste0(
          "COUNT(DISTINCT ",
          DBI::dbQuoteIdentifier(conn, col),
          ") as ",
          DBI::dbQuoteIdentifier(conn, paste0(col, '__distinct_count'))
        )
      )
    }
    # Columns of any other class get no statistics.
  }

  # Execute statistics query
  column_stats <- list()
  if (length(select_parts) > 0) {
    tryCatch(
      {
        stats_query <- paste0(
          "SELECT ",
          paste0(select_parts, collapse = ", "),
          " FROM ",
          DBI::dbQuoteIdentifier(conn, table_name)
        )
        result <- DBI::dbGetQuery(conn, stats_query)
        if (nrow(result) > 0) {
          # Keep the single aggregate row as a named list keyed by alias.
          column_stats <- as.list(result[1, ])
        }
      },
      error = function(e) {
        # Fall back to no statistics if query fails
      }
    )
  }

  # Get categorical values for text columns below threshold
  categorical_values <- list()
  text_cols_to_query <- character(0)

  for (col_name in text_columns) {
    distinct_count_key <- paste0(col_name, "__distinct_count")
    # Only columns with a known, non-NA distinct count at or below the
    # threshold are listed value by value.
    if (
      distinct_count_key %in%
        names(column_stats) &&
        !is.na(column_stats[[distinct_count_key]]) &&
        column_stats[[distinct_count_key]] <= categorical_threshold
    ) {
      text_cols_to_query <- c(text_cols_to_query, col_name)
    }
  }

  # Remove duplicates
  text_cols_to_query <- unique(text_cols_to_query)

  # Get categorical values
  if (length(text_cols_to_query) > 0) {
    for (col_name in text_cols_to_query) {
      tryCatch(
        {
          # Distinct non-NULL values, ordered by the column itself.
          cat_query <- paste0(
            "SELECT DISTINCT ",
            DBI::dbQuoteIdentifier(conn, col_name),
            " FROM ",
            DBI::dbQuoteIdentifier(conn, table_name),
            " WHERE ",
            DBI::dbQuoteIdentifier(conn, col_name),
            " IS NOT NULL ORDER BY ",
            DBI::dbQuoteIdentifier(conn, col_name)
          )
          result <- DBI::dbGetQuery(conn, cat_query)
          if (nrow(result) > 0) {
            categorical_values[[col_name]] <- result[[1]]
          }
        },
        error = function(e) {
          # Skip categorical values if query fails
        }
      )
    }
  }

  # Build schema description
  for (col in columns) {
    col_class <- class(sample_data[[col]])[1]
    sql_type <- r_class_to_sql_type(col_class)

    column_info <- paste0("- ", col, " (", sql_type, ")")

    # Add range info for numeric columns
    if (col %in% numeric_columns) {
      min_key <- paste0(col, "__min")
      max_key <- paste0(col, "__max")
      # Emit a range only when both bounds came back and neither is NA.
      if (
        min_key %in%
          names(column_stats) &&
          max_key %in% names(column_stats) &&
          !is.na(column_stats[[min_key]]) &&
          !is.na(column_stats[[max_key]])
      ) {
        range_info <- paste0(
          " Range: ",
          column_stats[[min_key]],
          " to ",
          column_stats[[max_key]]
        )
        column_info <- paste(column_info, range_info, sep = "\n")
      }
    }

    # Add categorical values for text columns
    if (col %in% names(categorical_values)) {
      values <- categorical_values[[col]]
      if (length(values) > 0) {
        # Each value is wrapped in single quotes and the list is
        # comma-separated; quotes inside values are not escaped.
        values_str <- paste0("'", values, "'", collapse = ", ")
        cat_info <- paste0(" Categorical values: ", values_str)
        column_info <- paste(column_info, cat_info, sep = "\n")
      }
    }

    schema_lines <- c(schema_lines, column_info)
  }

  paste(schema_lines, collapse = "\n")
}
# Helper function to map R classes to SQL types
r_class_to_sql_type <- function(r_class) {
  switch(
    r_class,
    "integer" = "INTEGER",
    "numeric" = "FLOAT",
    "double" = "FLOAT",
    "logical" = "BOOLEAN",
    "Date" = "DATE",
    "POSIXct" = "TIMESTAMP",
    "POSIXt" = "TIMESTAMP",
    "character" = "TEXT",
    "factor" = "TEXT",
    "TEXT" # default
  )
}


# Split SQL text into its top-level statements, dropping comments.
#
# One left-to-right regex scan recognises string literals, quoted identifiers,
# comments and statement separators. Literals and quoted identifiers are
# matched as whole tokens, so a `;`, `--`, `/*` or `GO` inside them is never
# mistaken for a separator or a comment. Returns a character vector of
# trimmed, non-empty statements (possibly of length zero).
split_sql_statements <- function(query) {
  token_pattern <- paste0(
    "'(?:[^']|'')*'", # string literal ('' is an escaped quote)
    "|\"(?:[^\"]|\"\")*\"", # double-quoted identifier
    "|/\\*[^*]*\\*+(?:[^/*][^*]*\\*+)*/", # block comment (not nested)
    "|--[^\n]*", # line comment
    "|(?im:^[ \\t]*GO[ \\t\\r]*$)", # T-SQL batch separator alone on a line
    "|;" # statement terminator
  )

  hits <- gregexpr(token_pattern, query, perl = TRUE)[[1]]
  if (hits[1] == -1L) {
    statements <- trimws(query)
    return(statements[nzchar(statements)])
  }

  starts <- as.integer(hits)
  ends <- starts + attr(hits, "match.length") - 1L
  tokens <- substring(query, starts, ends)

  is_comment <- startsWith(tokens, "/*") | startsWith(tokens, "--")
  is_separator <- grepl("^(;|[ \t]*GO[ \t\r]*)$", tokens, ignore.case = TRUE)
  # Comments and separators contribute no text to the output.
  tokens[is_comment | is_separator] <- ""

  # `gaps` is the plain text between tokens: gap i precedes token i, and the
  # final gap follows the last token.
  n_tokens <- length(tokens)
  gaps <- substring(query, c(1L, ends + 1L), c(starts - 1L, nchar(query)))

  # A piece belongs to the statement numbered by how many separators precede it.
  separators_seen <- cumsum(is_separator)
  gap_statement <- c(0L, separators_seen)

  pieces <- c(rbind(gaps[seq_len(n_tokens)], tokens), gaps[n_tokens + 1L])
  statement_id <- c(
    rbind(gap_statement[seq_len(n_tokens)], separators_seen),
    gap_statement[n_tokens + 1L]
  )

  statements <- vapply(
    split(pieces, statement_id),
    paste0,
    character(1),
    collapse = ""
  )
  statements <- trimws(unname(statements))
  statements[nzchar(statements)]
}


#' Clean SQL query for safe execution with dbplyr
#'
#' This function cleans an SQL query by removing comments, trailing semicolons,
#' and handling other syntax issues that can cause problems with dbplyr's tbl() function.
#'
#' Cleaning is quote-aware: text inside string literals and double-quoted
#' identifiers is never treated as a comment, statement separator, or
#' parenthesis, and non-ASCII characters are preserved. Only control
#' characters other than tab, newline and carriage return are removed.
#'
#' When several statements are present, only the first non-empty one is kept
#' and a warning is raised. An unterminated quote or missing closing
#' parentheses produce a warning and are closed at the end of the query.
#'
#' @param query A character string containing an SQL query. A multi-element
#'   character vector is treated as the lines of one query.
#' @param enforce_select Logical, whether to validate that the query is a SELECT statement
#' @return A cleaned SQL query string or NULL if the query is empty or invalid
#' @keywords internal
clean_sql <- function(query, enforce_select = TRUE) {
  # Check input
  if (!is.character(query)) {
    query <- as.character(query)
  }
  if (length(query) == 0L || anyNA(query)) {
    return(NULL)
  }
  query <- paste(query, collapse = "\n")

  # Save original query for error messages
  original_query <- query

  # Drop comments and keep only the first non-empty statement
  statements <- split_sql_statements(query)
  if (length(statements) == 0L) {
    return(NULL)
  }
  query <- statements[[1]]

  if (length(statements) > 1L) {
    rlang::warn(
      c(
        "Multiple SQL statements detected. Only the first statement will be used:",
        "i" = paste0(
          "Using: ",
          substr(query, 1, 60),
          if (nchar(query) > 60) "..." else ""
        ),
        "i" = paste0(
          "Ignoring ",
          length(statements) - 1L,
          " additional statement(s)"
        )
      )
    )
  }

  # `outside_quotes` is the statement with every complete literal and quoted
  # identifier removed. A quote character left in it opens a quote that never
  # closes.
  outside_quotes <- gsub(
    "'(?:[^']|'')*'|\"(?:[^\"]|\"\")*\"",
    "",
    query,
    perl = TRUE
  )
  open_quote_at <- regexpr("['\"]", outside_quotes)
  if (open_quote_at != -1L) {
    quote_char <- substr(outside_quotes, open_quote_at, open_quote_at)
    rlang::warn(
      c(
        if (quote_char == "'") {
          "SQL contains unbalanced single quotes, which may cause errors:"
        } else {
          "SQL contains unbalanced double quotes, which may cause errors:"
        },
        "i" = substr(original_query, 1, 100)
      )
    )
    # Attempt to fix by adding a closing quote at the end
    query <- paste0(query, quote_char)
    # Everything after the opening quote is now inside the quoted text.
    outside_quotes <- substr(outside_quotes, 1L, open_quote_at - 1L)
  }

  # Check for unbalanced parentheses, ignoring any inside quoted text
  count_fixed <- function(needle) {
    lengths(regmatches(
      outside_quotes,
      gregexpr(needle, outside_quotes, fixed = TRUE)
    ))
  }
  n_open <- count_fixed("(")
  n_close <- count_fixed(")")

  if (n_open != n_close) {
    rlang::warn(
      c(
        "SQL contains unbalanced parentheses, which may cause errors:",
        "i" = substr(original_query, 1, 100)
      )
    )

    # Attempt to fix by adding closing parentheses if there are more open ones
    if (n_open > n_close) {
      query <- paste0(query, strrep(")", n_open - n_close))
    }
  }

  # Remove control characters that might break SQL. Tab, newline and carriage
  # return are kept, as are all printable and non-ASCII characters.
  query <- gsub("[\\x01-\\x08\\x0B\\x0C\\x0E-\\x1F\\x7F]", "", query, perl = TRUE)

  # Validate that it's a SELECT statement if requested
  if (enforce_select) {
    # Check if it starts with SELECT (case insensitive, allowing for whitespace)
    if (!grepl("^\\s*SELECT\\b", query, ignore.case = TRUE)) {
      rlang::abort(
        c(
          "SQL query does not appear to start with SELECT:",
          "x" = substr(query, 1, 100),
          "i" = "dbplyr::tbl() requires a SELECT statement."
        )
      )
    }
  }

  # Final trimming
  trimws(query)
}
querychat_data_source(your_dataframe) + +# With a database connection +db_source <- querychat_data_source(conn, table_name = "your_table") +``` + +### Initialization + +All examples use the same pattern for initialization: + +```r +querychat_config <- querychat_init( + data_source = your_source, + greeting = greeting, + data_description = "Description of your data", + extra_instructions = "Additional instructions for the LLM" +) +``` + +### UI Integration + +The examples show how to integrate querychat into a Shiny UI: + +```r +ui <- bslib::page_sidebar( + sidebar = querychat_sidebar("chat"), + # Your main UI content here +) +``` + +### Server Logic + +The server function follows this pattern: + +```r +server <- function(input, output, session) { + chat <- querychat_server("chat", querychat_config) + + # Access chat outputs: + output$data_table <- DT::renderDT({ + chat$df() # Direct data frame access + }) + + output$sql_query <- renderText({ + chat$sql() # Access generated SQL + }) + + # Advanced: Use tbl() for query chaining + output$chained_results <- DT::renderDT({ + chat$tbl() %>% + filter(your_condition) %>% + collect() + }) +} +``` + +## Running the Examples + +To run any of these examples: + +1. Make sure you have the querychat package installed: + ```r + devtools::install("path/to/querychat/pkg-r") + ``` + +2. Install example-specific dependencies: + ```r + install.packages(c("shiny", "bslib", "DT", "dplyr", "DBI", "RSQLite")) + ``` + +3. Run the specific example: + ```r + shiny::runApp("path/to/querychat/pkg-r/examples/example-name") + ``` + +## Creating Your Own Apps + +These examples are designed to be starting points for your own applications. The core concepts apply regardless of your specific data or use case. 
+ +For more advanced usage, explore: +- Custom LLM settings via the `create_chat_func` parameter +- Adding additional UI elements to work with querychat outputs +- Integrating with other Shiny packages and extensions +- Saving and restoring chat history \ No newline at end of file diff --git a/pkg-r/examples/basic-database/README.md b/pkg-r/examples/basic-database/README.md new file mode 100644 index 00000000..6d36ff90 --- /dev/null +++ b/pkg-r/examples/basic-database/README.md @@ -0,0 +1,47 @@ +# Querychat Basic Database Example + +This example demonstrates how to use querychat with a database connection. The app connects to a SQLite database containing the iris dataset and provides a chat interface for querying the data. + +## Features + +- SQLite database with iris dataset +- Natural language querying using the chat sidebar +- Display of query results in a table +- Display of the generated SQL query +- Basic information about the dataset + +## How It Works + +1. The app creates a temporary SQLite database and loads the iris dataset +2. A querychat data source is created with `querychat_data_source(conn, table_name = "iris")` +3. The chat interface is configured with `querychat_init()` and a custom greeting +4. Users can enter natural language queries in the chat sidebar +5. Results are displayed in a table and the corresponding SQL is shown + +## Running This Example + +To run this example: + +1. Make sure you have all dependencies installed: + - shiny, bslib, querychat, DBI, RSQLite + +2. Run the app with: + ```r + shiny::runApp("path/to/basic-database") + ``` + +3. Try asking questions in the chat sidebar like: + - "Show me the first 10 rows of the iris dataset" + - "What's the average sepal length by species?" + - "Which species has the largest petals?" 
+ +## Connecting to Other Databases + +This example uses an in-memory SQLite database for simplicity, but querychat works with any database supported by the DBI package: + +- PostgreSQL (using RPostgreSQL or RPostgres) +- MySQL (using RMySQL) +- Microsoft SQL Server (using odbc) +- And more + +Replace the connection setup with your preferred database, and querychat will handle the rest! \ No newline at end of file diff --git a/pkg-r/examples/app-database.R b/pkg-r/examples/basic-database/app.R similarity index 61% rename from pkg-r/examples/app-database.R rename to pkg-r/examples/basic-database/app.R index 668b32c5..54e7248d 100644 --- a/pkg-r/examples/app-database.R +++ b/pkg-r/examples/basic-database/app.R @@ -21,19 +21,9 @@ conn <- dbConnect(RSQLite::SQLite(), temp_db) iris_data <- iris dbWriteTable(conn, "iris", iris_data, overwrite = TRUE) -# Define a custom greeting for the database app -greeting <- " -# Welcome to the Database Query Assistant! 📊 - -I can help you explore and analyze the iris dataset from the connected database. -Ask me questions about the iris flowers, and I'll generate SQL queries to get the answers. - -Try asking: -- Show me the first 10 rows of the iris dataset -- What's the average sepal length by species? -- Which species has the largest petals? -- Create a summary of measurements grouped by species -" +# Load greeting from external markdown file +greeting <- readLines("greeting.md", warn = FALSE) +greeting <- paste(greeting, collapse = "\n") # Create data source using querychat_data_source iris_source <- querychat_data_source(conn, table_name = "iris") @@ -46,24 +36,37 @@ querychat_config <- querychat_init( extra_instructions = "When showing results, always explain what the data represents and highlight any interesting patterns you observe." 
) -ui <- page_sidebar( +ui <- bslib::page_sidebar( title = "Database Query Chat", sidebar = querychat_sidebar("chat"), - h2("Current Data View"), - p( - "The table below shows the current filtered data based on your chat queries:" + + bslib::card( + bslib::card_header("Current Data View"), + bslib::card_body( + p( + "The table below shows the current filtered data based on your chat queries:" + ), + DT::DTOutput("data_table", fill = FALSE) + ) ), - DT::DTOutput("data_table", fill = FALSE), - br(), - h3("Current SQL Query"), - verbatimTextOutput("sql_query"), - br(), - h3("Dataset Information"), - p("This demo database contains:"), - tags$ul( - tags$li("iris - Famous iris flower dataset (150 rows, 5 columns)"), - tags$li( - "Columns: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width, Species" + + bslib::card( + bslib::card_header("Current SQL Query"), + bslib::card_body( + verbatimTextOutput("sql_query") + ) + ), + + bslib::card( + bslib::card_header("Dataset Information"), + bslib::card_body( + p("This demo database contains:"), + tags$ul( + tags$li("iris - Famous iris flower dataset (150 rows, 5 columns)"), + tags$li( + "Columns: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width, Species" + ) + ) ) ) ) @@ -89,4 +92,4 @@ server <- function(input, output, session) { }) } -shinyApp(ui = ui, server = server) +shiny::shinyApp(ui = ui, server = server) diff --git a/pkg-r/examples/basic-database/greeting.md b/pkg-r/examples/basic-database/greeting.md new file mode 100644 index 00000000..50bc34f0 --- /dev/null +++ b/pkg-r/examples/basic-database/greeting.md @@ -0,0 +1,10 @@ +# Welcome to the Database Query Assistant! 📊 + +I can help you explore and analyze the iris dataset from the connected database. +Ask me questions about the iris flowers, and I'll generate SQL queries to get the answers. + +Try asking: +- Show me the first 10 rows of the iris dataset +- What's the average sepal length by species? +- Which species has the largest petals? 
+- Create a summary of measurements grouped by species \ No newline at end of file diff --git a/pkg-r/examples/basic-dataframe/README.md b/pkg-r/examples/basic-dataframe/README.md new file mode 100644 index 00000000..2ed14f15 --- /dev/null +++ b/pkg-r/examples/basic-dataframe/README.md @@ -0,0 +1,47 @@ +# Querychat Data Frame Example + +This example demonstrates how to use querychat with a regular R data frame. The app uses the Titanic dataset and provides a chat interface for querying the data. + +## Features + +- Uses a simple data frame (no database required) +- Natural language querying using the chat sidebar +- Display of query results in a table +- Display of the generated SQL query +- Information about the dataset structure + +## How It Works + +1. The app prepares the Titanic dataset as a data frame +2. A querychat data source is created with `querychat_data_source(titanic_expanded)` +3. The chat interface is configured with `querychat_init()` and a custom greeting +4. Users can enter natural language queries in the chat sidebar +5. Results are displayed in a table, and the corresponding SQL is shown + +## Under the Hood + +Even though we're using a data frame, querychat still uses SQL for query execution: + +- When using `querychat_data_source()` with a data frame, it creates a temporary in-memory DuckDB database +- Your natural language queries are converted to SQL +- SQL queries are executed against the DuckDB instance +- Results are returned as R data frames + +This approach ensures consistent behavior between data frames and external databases. + +## Running This Example + +To run this example: + +1. Make sure you have all dependencies installed: + - shiny, bslib, querychat, dplyr, DT, tidyr + +2. Run the app with: + ```r + shiny::runApp("path/to/basic-dataframe") + ``` + +3. Try asking questions in the chat sidebar like: + - "Show me the first 10 passengers" + - "What was the survival rate by class?" 
+ - "Show me all children who survived" \ No newline at end of file diff --git a/pkg-r/examples/basic-dataframe/app.R b/pkg-r/examples/basic-dataframe/app.R new file mode 100644 index 00000000..15adf21f --- /dev/null +++ b/pkg-r/examples/basic-dataframe/app.R @@ -0,0 +1,134 @@ +library(shiny) +library(bslib) +library(querychat) +library(dplyr) +library(DT) + +# Prepare the Titanic dataset +titanic_data <- datasets::Titanic +titanic_df <- as.data.frame(titanic_data) + +# Rename and restructure the data to match a typical passenger list +names(titanic_df) <- c("Class", "Sex", "Age", "Survived", "Count") + +# Expand the data to have one row per passenger +titanic_expanded <- tidyr::uncount(titanic_df, Count) + +# Add a passenger ID column (seq_len() stays correct even for zero rows) +titanic_expanded$PassengerId <- seq_len(nrow(titanic_expanded)) + +# Reorder columns to have ID first +titanic_expanded <- titanic_expanded[, c( + "PassengerId", + "Class", + "Sex", + "Age", + "Survived" +)] + +# Load greeting from external markdown file +greeting_path <- file.path(getwd(), "greeting.md") +greeting <- readLines(greeting_path, warn = FALSE) +greeting <- paste(greeting, collapse = "\n") + +# Create data source using querychat_data_source with data frame +titanic_source <- querychat_data_source(titanic_expanded) + +# Configure querychat +querychat_config <- querychat_init( + data_source = titanic_source, + greeting = greeting, + data_description = "This is the Titanic dataset with information about passengers with columns for passenger ID, class (1st, 2nd, 3rd, or Crew), sex (Male or Female), age category (Adult or Child), and survival (Yes or No).", + extra_instructions = "When showing results, always explain survival patterns across different demographic groups." 
+) + +ui <- bslib::page_sidebar( + title = "Titanic Dataset Query Chat", + sidebar = querychat_sidebar("chat"), + + bslib::layout_column_wrap( + width = 1, + card( + card_header("Dataset Overview"), + card_body( + p( + "This example demonstrates using querychat with a data frame (rather than a database)." + ), + p("Ask questions about the Titanic dataset in the chat sidebar.") + ) + ) + ), + + bslib::layout_column_wrap( + width = 1, + card( + card_header("Query Results"), + card_body( + p("The table below shows the results from your chat queries:"), + DT::DTOutput("data_table", fill = FALSE) + ) + ) + ), + + bslib::layout_column_wrap( + width = 1, + card( + card_header("Generated SQL Query"), + card_body( + p( + "Even though we're using a data frame, querychat translates natural language to SQL under the hood:" + ), + verbatimTextOutput("sql_query") + ) + ) + ), + + bslib::layout_column_wrap( + width = 1, + card( + card_header("Dataset Information"), + card_body( + p("The Titanic dataset contains:"), + tags$ul( + tags$li( + strong("PassengerId:"), + "Unique identifier for each passenger" + ), + tags$li(strong("Class:"), "Passenger class (1st, 2nd, 3rd, or Crew)"), + tags$li(strong("Sex:"), "Passenger sex (Male or Female)"), + tags$li(strong("Age:"), "Age category (Adult or Child)"), + tags$li( + strong("Survived:"), + "Whether the passenger survived (Yes or No)" + ) + ) + ) + ) + ) +) + +server <- function(input, output, session) { + # Initialize querychat + chat <- querychat_server("chat", querychat_config) + + # Display query results + output$data_table <- DT::renderDT( + { + df <- chat$df() + df + }, + options = list(pageLength = 10, scrollX = TRUE) + ) + + # Display the SQL query + output$sql_query <- renderText({ + query <- chat$sql() + if (query == "") { + "No filter applied - showing all data" + } else { + query + } + }) +} + +shiny::shinyApp(ui = ui, server = server) diff --git a/pkg-r/examples/basic-dataframe/greeting.md 
b/pkg-r/examples/basic-dataframe/greeting.md new file mode 100644 index 00000000..87ad01ce --- /dev/null +++ b/pkg-r/examples/basic-dataframe/greeting.md @@ -0,0 +1,11 @@ +# Welcome to the Titanic Dataset Explorer! 🚢 + +I can help you explore and analyze data about Titanic passengers. +Ask me questions about passengers, survival rates, or demographic information. + +Try asking: +- Show me the first 10 passengers +- What was the survival rate for men vs women? +- How many passengers were in each class? +- What age groups had the highest survival rates? +- Show me all children who survived \ No newline at end of file diff --git a/pkg-r/examples/chained-query/README.md b/pkg-r/examples/chained-query/README.md new file mode 100644 index 00000000..ea88b996 --- /dev/null +++ b/pkg-r/examples/chained-query/README.md @@ -0,0 +1,60 @@ +# Querychat Query Chaining Example + +This example demonstrates how to use query chaining with the `querychat_server$tbl()` output. The app allows you to: + +1. Use natural language to query a database through the chat sidebar +2. Chain additional dplyr operations to the query results programmatically + +## Key Features + +- Uses SQLite database with Titanic dataset +- Demonstrates how to apply additional filters, sorts, and aggregations after a chat query +- Shows how to materialize results with `collect()` + +## Understanding Query Chaining + +The querychat package supports query chaining through the `tbl()` reactive output from `querychat_server()`. This reactive returns a lazy `dbplyr::tbl()` object that can be further manipulated with dplyr verbs. + +For example: + +```r +# Start with chat-based query +base_query <- chat$tbl() + +# Chain additional operations +results <- base_query %>% + filter(Age == "Adult") %>% + group_by(Class, Sex) %>% + summarize( + Total = n(), + Survived = sum(Survived == "Yes"), + Survival_Rate = round(sum(Survived == "Yes") / n() * 100, 1) + ) %>% + collect() +``` + +## How It Works + +1. 
When a user asks a question in the chat sidebar, querychat generates a SQL query +2. The `querychat_server$tbl()` function returns a lazy dbplyr table based on this query +3. You can chain additional dplyr operations to further refine or transform the data +4. The operations remain lazy until you call `collect()` to execute the query and retrieve results + +## Running This Example + +To run this example: + +1. Make sure you have all dependencies installed: + - shiny, bslib, querychat, DBI, RSQLite, dplyr, DT, tidyr, ggplot2, plotly + +2. Run the app with: + ```r + shiny::runApp("path/to/chained-query") + ``` + +3. Try asking questions in the chat sidebar like: + - "Show me all passengers" + - "Show me passengers who survived" + - "What was the survival rate by class?" + +4. Notice how the app shows both the direct query results and additional chained transformations \ No newline at end of file diff --git a/pkg-r/examples/chained-query/app.R b/pkg-r/examples/chained-query/app.R new file mode 100644 index 00000000..24c3cfd5 --- /dev/null +++ b/pkg-r/examples/chained-query/app.R @@ -0,0 +1,181 @@ +library(shiny) +library(bslib) +library(querychat) +library(DBI) +library(RSQLite) +library(dplyr) +library(DT) +library(ggplot2) +library(plotly) + +# Create a sample SQLite database with Titanic dataset +temp_db <- tempfile(fileext = ".db") +onStop(function() { + if (file.exists(temp_db)) { + unlink(temp_db) + } +}) + +conn <- dbConnect(RSQLite::SQLite(), temp_db) + +# Load Titanic dataset and prepare it for the database +titanic_data <- datasets::Titanic +titanic_df <- as.data.frame(titanic_data) + +# Rename and restructure the data to match a typical passenger list +names(titanic_df) <- c("Class", "Sex", "Age", "Survived", "Count") + +# Expand the data to have one row per passenger +titanic_expanded <- tidyr::uncount(titanic_df, Count) + +# Add a passenger ID column (seq_len() stays correct even for zero rows) +titanic_expanded$PassengerId <- seq_len(nrow(titanic_expanded)) + +# Reorder columns to have ID first 
+titanic_expanded <- titanic_expanded[, c( + "PassengerId", + "Class", + "Sex", + "Age", + "Survived" +)] + +# Write to SQLite database +dbWriteTable(conn, "titanic", titanic_expanded, overwrite = TRUE) + +# Load greeting from external markdown file +greeting <- readLines("greeting.md", warn = FALSE) +greeting <- paste(greeting, collapse = "\n") + +# Create data source using querychat_data_source +titanic_source <- querychat_data_source(conn, table_name = "titanic") + +# Configure querychat for database +querychat_config <- querychat_init( + data_source = titanic_source, + greeting = greeting, + data_description = "This database contains information about Titanic passengers with columns for passenger ID, class (1st, 2nd, 3rd, or Crew), sex (Male or Female), age category (Adult or Child), and survival (Yes or No).", + extra_instructions = "When showing results, always explain survival patterns across different demographic groups." +) + +ui <- bslib::page_sidebar( + title = "Titanic Query Chaining Demo", + sidebar = querychat_sidebar("chat"), + + # Main content + bslib::layout_column_wrap( + width = 1 / 2, + card( + card_header("Passenger Count Summary"), + card_body(plotlyOutput("passenger_chart")) + ), + card( + card_header("Survival Rate by Class"), + card_body(plotlyOutput("class_survival_chart")) + ) + ), + + bslib::layout_column_wrap( + width = 1, + card( + card_header("Query Results"), + card_body( + p("Basic results from chat query:"), + DTOutput("data_table") + ) + ) + ), + + hr(), + + card( + card_header("Current SQL Query"), + card_body(textOutput("sql_query")) + ), + + hr(), + + card( + card_header("About This Example"), + card_body( + p( + "This example demonstrates how to use querychat_server$tbl() to chain additional dplyr operations after a natural language query." + ), + p( + "The chat sidebar generates a base query, then we apply additional transformations programmatically." 
+ ) + ) + ) +) + +server <- function(input, output, session) { + # Initialize querychat + chat <- querychat_server("chat", querychat_config) + + # Create high-level passenger counts chart + output$passenger_chart <- renderPlotly({ + # Get base data from current query or all passengers if no query + base_data <- chat$tbl() %>% + group_by(Class, Sex) %>% + summarize(Count = n(), .groups = "drop") %>% + collect() + + p <- ggplot(base_data, aes(x = Class, y = Count, fill = Sex)) + + geom_col(position = "dodge") + + theme_minimal() + + labs( + title = "Passenger Count by Class and Sex", + x = "Passenger Class", + y = "Count" + ) + + ggplotly(p) + }) + + # Survival rate (%) by class; the rate is computed in R after collect() so it uses per-class totals and avoids SQL integer division + output$class_survival_chart <- renderPlotly({ + survival_data <- chat$tbl() %>% + group_by(Class) %>% + summarize( + Total = n(), + Survived = sum(Survived == "Yes"), + ) %>% + collect() %>% + mutate( + SurvivalRate = 100 * Survived / Total + ) + + p <- ggplot(survival_data, aes(x = Class, y = SurvivalRate, fill = Class)) + + geom_col() + + theme_minimal() + + labs( + title = "Survival Rate by Class", + x = "Passenger Class", + y = "Survival Rate (%)" + ) + + scale_y_continuous(limits = c(0, 100)) + + ggplotly(p) + }) + + # Basic query results from chat + output$data_table <- DT::renderDT( + { + df <- chat$tbl() %>% collect() + df + }, + options = list(pageLength = 5) + ) + + # Show the current SQL query + output$sql_query <- renderText({ + query <- chat$sql() + if (query == "") { + "No query yet - try asking a question about the Titanic data!" + } else { + query + } + }) +} + +shiny::shinyApp(ui = ui, server = server) diff --git a/pkg-r/examples/chained-query/greeting.md b/pkg-r/examples/chained-query/greeting.md new file mode 100644 index 00000000..3159798b --- /dev/null +++ b/pkg-r/examples/chained-query/greeting.md @@ -0,0 +1,11 @@ +# Welcome to the Titanic Dataset Query Explorer! 🚢 + +I can help you explore and analyze data about Titanic passengers. 
+Ask me questions about passengers, survival rates, or demographic information, +and I'll generate SQL queries to get the answers. + +Try asking: +- Show me the first 10 passengers +- What was the survival rate for men vs women? +- How many passengers were in each class? +- What age groups had the highest survival rates? \ No newline at end of file diff --git a/pkg-r/examples/database-setup.md b/pkg-r/examples/database-setup.md deleted file mode 100644 index 31426f9e..00000000 --- a/pkg-r/examples/database-setup.md +++ /dev/null @@ -1,122 +0,0 @@ -# Database Setup Examples for querychat - -This document provides examples of how to set up querychat with various database types using the new `database_source()` functionality. - -## SQLite - -```r -library(DBI) -library(RSQLite) -library(querychat) - -# Connect to SQLite database -conn <- dbConnect(RSQLite::SQLite(), "path/to/your/database.db") - -# Create database source -db_source <- database_source(conn, "your_table_name") - -# Initialize querychat -config <- querychat_init( - data_source = db_source, - greeting = "Welcome! Ask me about your data.", - data_description = "Description of your data..." 
-) -``` - -## PostgreSQL - -```r -library(DBI) -library(RPostgreSQL) # or library(RPostgres) -library(querychat) - -# Connect to PostgreSQL -conn <- dbConnect( - RPostgreSQL::PostgreSQL(), # or RPostgres::Postgres() - dbname = "your_database", - host = "localhost", - port = 5432, - user = "your_username", - password = "your_password" -) - -# Create database source -db_source <- database_source(conn, "your_table_name") - -# Initialize querychat -config <- querychat_init(data_source = db_source) -``` - -## MySQL - -```r -library(DBI) -library(RMySQL) -library(querychat) - -# Connect to MySQL -conn <- dbConnect( - RMySQL::MySQL(), - dbname = "your_database", - host = "localhost", - user = "your_username", - password = "your_password" -) - -# Create database source -db_source <- database_source(conn, "your_table_name") - -# Initialize querychat -config <- querychat_init(data_source = db_source) -``` - -## Connection Management - -When using database sources in Shiny apps, make sure to properly manage connections: - -```r -server <- function(input, output, session) { - # Your querychat server logic here - chat <- querychat_server("chat", querychat_config) - - # Clean up connection when session ends - session$onSessionEnded(function() { - if (dbIsValid(conn)) { - dbDisconnect(conn) - } - }) -} -``` - -## Configuration Options - -The `database_source()` function accepts a `categorical_threshold` parameter: - -```r -# Columns with <= 50 unique values will be treated as categorical -db_source <- database_source(conn, "table_name", categorical_threshold = 50) -``` - -## Security Considerations - -- Only SELECT queries are allowed - no INSERT, UPDATE, or DELETE operations -- All SQL queries are visible to users for transparency -- Use appropriate database user permissions (read-only recommended) -- Consider connection pooling for production applications -- Validate that users only have access to intended tables - -## Error Handling - -The database source implementation 
includes robust error handling: - -- Validates table existence during creation -- Handles database connection issues gracefully -- Provides informative error messages for invalid queries -- Falls back gracefully when statistical queries fail - -## Performance Tips - -- Use appropriate database indexes for columns commonly used in queries -- Consider limiting row counts for very large tables -- Database connections are reused for better performance -- Schema information is cached to avoid repeated metadata queries \ No newline at end of file diff --git a/pkg-r/man/clean_sql.Rd b/pkg-r/man/clean_sql.Rd new file mode 100644 index 00000000..708a9c68 --- /dev/null +++ b/pkg-r/man/clean_sql.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/sql_utils.R +\name{clean_sql} +\alias{clean_sql} +\title{Clean SQL query for safe execution with dbplyr} +\usage{ +clean_sql(query, enforce_select = TRUE) +} +\arguments{ +\item{query}{A character string containing an SQL query} + +\item{enforce_select}{Logical, whether to validate that the query is a SELECT statement} +} +\value{ +A cleaned SQL query string or NULL if the query is empty or invalid +} +\description{ +This function cleans an SQL query by removing comments, trailing semicolons, +and handling other syntax issues that can cause problems with dbplyr's tbl() function. 
+} +\keyword{internal} diff --git a/pkg-r/man/create_system_prompt.Rd b/pkg-r/man/create_system_prompt.Rd index 34269018..ea78a5a4 100644 --- a/pkg-r/man/create_system_prompt.Rd +++ b/pkg-r/man/create_system_prompt.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data_source.R +% Please edit documentation in R/schema_utils.R \name{create_system_prompt} \alias{create_system_prompt} \title{Create a system prompt for the data source} diff --git a/pkg-r/man/get_db_type.Rd b/pkg-r/man/get_db_type.Rd index e3fd6429..36cc02a5 100644 --- a/pkg-r/man/get_db_type.Rd +++ b/pkg-r/man/get_db_type.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data_source.R +% Please edit documentation in R/schema_utils.R \name{get_db_type} \alias{get_db_type} \title{Get type information for a data source} diff --git a/pkg-r/man/get_lazy_data.Rd b/pkg-r/man/get_lazy_data.Rd new file mode 100644 index 00000000..4c2a75f4 --- /dev/null +++ b/pkg-r/man/get_lazy_data.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{get_lazy_data} +\alias{get_lazy_data} +\title{Get a lazy representation of a data source} +\usage{ +get_lazy_data(source, query, ...) 
+} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{query}{SQL query string} + +\item{...}{Additional arguments passed to methods} +} +\value{ +A lazy representation (typically a dbplyr tbl) +} +\description{ +Get a lazy representation of a data source +} diff --git a/pkg-r/man/get_schema.Rd b/pkg-r/man/get_schema.Rd index 22d24ff1..495a2fca 100644 --- a/pkg-r/man/get_schema.Rd +++ b/pkg-r/man/get_schema.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data_source.R +% Please edit documentation in R/schema_utils.R \name{get_schema} \alias{get_schema} \title{Get schema for a data source} diff --git a/pkg-r/man/querychat_server.Rd b/pkg-r/man/querychat_server.Rd index eec8f892..09d6f68a 100644 --- a/pkg-r/man/querychat_server.Rd +++ b/pkg-r/man/querychat_server.Rd @@ -19,6 +19,9 @@ elements: \item \code{sql}: A reactive that returns the current SQL query. \item \code{title}: A reactive that returns the current title. \item \code{df}: A reactive that returns the filtered data as a data.frame. +\item \code{tbl}: A reactive that returns a lazy \code{dbplyr::tbl()} object that supports +lazy evaluation and query chaining. This can be further manipulated with +dplyr verbs before calling \code{collect()} to materialize the results. \item \code{chat}: The \link[ellmer:Chat]{ellmer::Chat} object that powers the chat interface. 
} } diff --git a/pkg-r/tests/testthat/test-lazy-tbl-errors.R b/pkg-r/tests/testthat/test-lazy-tbl-errors.R new file mode 100644 index 00000000..13b35d84 --- /dev/null +++ b/pkg-r/tests/testthat/test-lazy-tbl-errors.R @@ -0,0 +1,47 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(dplyr) +library(querychat) + +test_that("get_lazy_data properly propagates errors without fallback", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create a data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Syntactically valid SQL that references a column the table does not have + invalid_query <- "SELECT * FROM test_table WHERE non_existent_column = 'value'" + + # Check that the error is propagated rather than falling back to the full table + expect_error(get_lazy_data(df_source, invalid_query)) + + # Clean up + cleanup_source(df_source) +}) + +test_that("get_lazy_data errors on empty query after cleaning", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create a data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Query that will be empty after cleaning (only comments) + comment_only_query <- "-- This is just a comment\n/* Another comment */" + + # Check that an error is raised instead of falling back to the full table + expect_error(get_lazy_data(df_source, comment_only_query), "empty query") + + # Clean up + cleanup_source(df_source) +}) diff --git a/pkg-r/tests/testthat/test-lazy-tbl.R b/pkg-r/tests/testthat/test-lazy-tbl.R new file mode 100644 index 00000000..ae2e50af --- /dev/null +++ b/pkg-r/tests/testthat/test-lazy-tbl.R @@ -0,0 +1,479 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(dplyr) +library(querychat) + +test_that("get_lazy_data returns tbl objects", { + # Test with data frame source + test_df <- data.frame( + id = 
1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + lazy_data <- get_lazy_data(df_source) + expect_s3_class(lazy_data, "tbl") + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + lazy_data <- get_lazy_data(dbi_source) + expect_s3_class(lazy_data, "tbl") + + # Test chaining with dplyr + filtered_data <- lazy_data %>% + dplyr::filter(value > 25) %>% + dplyr::collect() + expect_equal(nrow(filtered_data), 3) # Should return 3 rows (30, 40, 50) + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("get_lazy_data works with empty query", { + # Test with data frame source + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with NULL query + lazy_data_null <- get_lazy_data(df_source, NULL) + expect_s3_class(lazy_data_null, "tbl") + result_null <- dplyr::collect(lazy_data_null) + expect_equal(nrow(result_null), 5) + + # Test with empty string query + lazy_data_empty <- get_lazy_data(df_source, "") + expect_s3_class(lazy_data_empty, "tbl") + result_empty <- dplyr::collect(lazy_data_empty) + expect_equal(nrow(result_empty), 5) + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + + # Test with NULL query + lazy_data_null <- get_lazy_data(dbi_source, NULL) + expect_s3_class(lazy_data_null, "tbl") + result_null <- dplyr::collect(lazy_data_null) + expect_equal(nrow(result_null), 5) + + # Test with empty string query + lazy_data_empty <- 
get_lazy_data(dbi_source, "") + expect_s3_class(lazy_data_empty, "tbl") + result_empty <- dplyr::collect(lazy_data_empty) + expect_equal(nrow(result_empty), 5) + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("get_lazy_data handles problematic SQL with clean_sql", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with inline comments + inline_comment_query <- " + SELECT id, value -- This is a comment + FROM test_table + WHERE value > 25 -- Filter for higher values + " + + lazy_result <- get_lazy_data(df_source, inline_comment_query) + expect_s3_class(lazy_result, "tbl") + result <- dplyr::collect(lazy_result) + expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) + expect_equal(ncol(result), 2) + + # Test with multiple inline comments + multiple_comments_query <- " + SELECT -- Get only these columns + id, -- ID column + value -- Value column + FROM test_table -- Our test table + WHERE value > 25 -- Only higher values + " + + lazy_result <- get_lazy_data(df_source, multiple_comments_query) + expect_s3_class(lazy_result, "tbl") + result <- dplyr::collect(lazy_result) + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with trailing semicolons + query_with_semicolon <- " + SELECT id, value + FROM test_table + WHERE value > 25; + " + + lazy_result <- get_lazy_data(df_source, query_with_semicolon) + expect_s3_class(lazy_result, "tbl") + result <- dplyr::collect(lazy_result) + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with multiple semicolons + query_with_multiple_semicolons <- " + SELECT id, value + FROM test_table + WHERE value > 25;;;; + " + + lazy_result <- get_lazy_data(df_source, query_with_multiple_semicolons) + expect_s3_class(lazy_result, "tbl") + result 
<- dplyr::collect(lazy_result) + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with multiline comments + multiline_comment_query <- " + /* + * This is a multiline comment + * that spans multiple lines + */ + SELECT id, value + FROM test_table + WHERE value > 25 + " + + lazy_result <- get_lazy_data(df_source, multiline_comment_query) + expect_s3_class(lazy_result, "tbl") + result <- dplyr::collect(lazy_result) + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with a mix of comment styles and semicolons + complex_query <- " + /* + * This is a complex query with different comment styles + */ + SELECT + id, -- This is the ID column + value /* Value column */ + FROM + test_table -- Our test table + WHERE + /* Only get higher values */ + value > 25; -- End of query + " + + lazy_result <- get_lazy_data(df_source, complex_query) + expect_s3_class(lazy_result, "tbl") + result <- dplyr::collect(lazy_result) + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) + +test_that("querychat_server has tbl output", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), + age = c(25, 30, 35, 28, 32), + stringsAsFactors = FALSE + ) + + # Create a data source + source <- querychat_data_source(test_df) + + # Create a mock current_query reactiveVal + current_query_val <- "SELECT * FROM test_df WHERE age > 30" + + # Test that get_lazy_data works with this query directly + lazy_result <- get_lazy_data(source, current_query_val) + expect_s3_class(lazy_result, "tbl") + + # Collect the data from the lazy_result + collected <- dplyr::collect(lazy_result) + + # Check that we get the expected results + expect_equal(nrow(collected), 2) # Both Charlie and Eve are over 30 + expect_true("Charlie" %in% collected$name) + expect_true("Eve" %in% collected$name) + + # Clean up + cleanup_source(source) +}) + 
+test_that("get_lazy_data returns tbl that supports full dplyr verb chaining", { + # Create a more complex test dataframe + test_df <- data.frame( + id = 1:10, + name = c( + "Alice", + "Bob", + "Charlie", + "Diana", + "Eve", + "Frank", + "Grace", + "Henry", + "Irene", + "Jack" + ), + age = c(25, 30, 35, 28, 32, 40, 22, 45, 33, 27), + department = c( + "Sales", + "IT", + "HR", + "IT", + "Sales", + "HR", + "Sales", + "IT", + "HR", + "Sales" + ), + salary = c( + 50000, + 65000, + 70000, + 62000, + 55000, + 75000, + 48000, + 80000, + 68000, + 52000 + ), + stringsAsFactors = FALSE + ) + + # Create a data source + source <- querychat_data_source(test_df) + + # Get a lazy tbl from the source + lazy_tbl <- get_lazy_data(source) + expect_s3_class(lazy_tbl, "tbl") + + # Test filter + filtered <- lazy_tbl %>% dplyr::filter(age > 30) + expect_s3_class(filtered, "tbl") + + # Test select + selected <- filtered %>% dplyr::select(name, department, salary) + expect_s3_class(selected, "tbl") + + # Test arrange + arranged <- selected %>% dplyr::arrange(desc(salary)) + expect_s3_class(arranged, "tbl") + + # Test mutate + mutated <- arranged %>% dplyr::mutate(bonus = salary * 0.1) + expect_s3_class(mutated, "tbl") + + # Test group_by and summarize + summarized <- lazy_tbl %>% + dplyr::group_by(department) %>% + dplyr::summarize( + avg_age = mean(age, na.rm = TRUE), + avg_salary = mean(salary, na.rm = TRUE), + count = n() + ) + expect_s3_class(summarized, "tbl") + + # Collect the results of the full chain + result <- mutated %>% dplyr::collect() + + # Check that we got the expected results + expect_equal(ncol(result), 4) # name, department, salary, bonus + expect_true("bonus" %in% colnames(result)) + expect_equal(nrow(result), 5) # Five people over 30 + expect_equal(result$name[1], "Henry") # Highest salary should be first + + # Check the summarized results + summary_result <- summarized %>% dplyr::collect() + expect_equal(nrow(summary_result), 3) # Three departments + 
expect_equal(sum(summary_result$count), 10) # Total count matches original
+
+  # Clean up
+  cleanup_source(source)
+})
+
+test_that("get_lazy_data with query supports full dplyr verb chaining", {
+  # Create a test dataframe
+  test_df <- data.frame(
+    id = 1:10,
+    name = c(
+      "Alice",
+      "Bob",
+      "Charlie",
+      "Diana",
+      "Eve",
+      "Frank",
+      "Grace",
+      "Henry",
+      "Irene",
+      "Jack"
+    ),
+    age = c(25, 30, 35, 28, 32, 40, 22, 45, 33, 27),
+    department = c(
+      "Sales",
+      "IT",
+      "HR",
+      "IT",
+      "Sales",
+      "HR",
+      "Sales",
+      "IT",
+      "HR",
+      "Sales"
+    ),
+    salary = c(
+      50000,
+      65000,
+      70000,
+      62000,
+      55000,
+      75000,
+      48000,
+      80000,
+      68000,
+      52000
+    ),
+    stringsAsFactors = FALSE
+  )
+
+  # Create a data source
+  source <- querychat_data_source(test_df)
+
+  # Get a lazy tbl with a base query
+  query <- "SELECT * FROM test_df WHERE department = 'IT'"
+  lazy_tbl <- get_lazy_data(source, query)
+  expect_s3_class(lazy_tbl, "tbl")
+
+  # Test a complex chain of operations
+  result <- lazy_tbl %>%
+    dplyr::filter(age >= 30) %>%
+    dplyr::select(id, name, age, salary, department) %>% # Include department for the test below
+    dplyr::mutate(
+      bonus = case_when(
+        salary >= 70000 ~ salary * 0.15,
+        salary >= 60000 ~ salary * 0.10,
+        TRUE ~ salary * 0.05
+      )
+    ) %>%
+    dplyr::arrange(desc(bonus)) %>%
+    dplyr::collect()
+
+  # Check the results
+  expect_equal(nrow(result), 2) # Bob (30) and Henry (45): IT staff aged 30 or over
+  expect_equal(ncol(result), 6) # id, name, age, salary, department, bonus
+  expect_equal(result$name[1], "Henry") # Should be first with highest bonus
+  expect_true(all(result$department %in% c("IT"))) # Only IT department
+
+  # Clean up
+  cleanup_source(source)
+})
+
+test_that("get_lazy_data handles complex SQL queries with dplyr chaining", {
+  # Create test dataframe
+  test_df <- data.frame(
+    id = 1:10,
+    name = c(
+      "Alice",
+      "Bob",
+      "Charlie",
+      "Diana",
+      "Eve",
+      "Frank",
+      "Grace",
+      "Henry",
+      "Irene",
+      "Jack"
+    ),
+    age = c(25, 30, 35, 28, 32, 40, 22, 45, 33, 27),
+    department = c(
+      "Sales",
+      "IT",
+      "HR",
+      "IT",
+      "Sales",
+      "HR",
+      "Sales",
+      "IT",
+      "HR",
+      "Sales"
+    ),
+    salary = c(
+      50000,
+      65000,
+      70000,
+      62000,
+      55000,
+      75000,
+      48000,
+      80000,
+      68000,
+      52000
+    ),
+    stringsAsFactors = FALSE
+  )
+
+  # Create a data source
+  source <- querychat_data_source(test_df)
+
+  # Complex query with comments, subqueries, and functions
+  complex_query <- "
+    /* This is a complex query with subqueries and functions */
+    SELECT
+      id,
+      name,
+      age,
+      department,
+      salary, -- original salary
+      ROUND(salary * 1.05) AS projected_salary -- with 5% increase
+    FROM test_df
+    WHERE
+      age > (SELECT AVG(age) FROM test_df) -- only above average age
+      AND department IN ('IT', 'HR') -- only certain departments
+    ORDER BY salary DESC; -- sort by salary
+  "
+
+  # Get lazy tbl with complex query
+  lazy_tbl <- get_lazy_data(source, complex_query)
+  expect_s3_class(lazy_tbl, "tbl")
+
+  # Chain dplyr operations
+  result <- lazy_tbl %>%
+    dplyr::filter(projected_salary > 70000) %>%
+    dplyr::mutate(bonus_eligible = projected_salary > 75000) %>%
+    dplyr::select(name, department, projected_salary, bonus_eligible) %>%
+    dplyr::collect()
+
+  # Check results
+  expect_s3_class(result, "data.frame")
+  expect_true(all(result$projected_salary > 70000))
+  expect_true("bonus_eligible" %in% colnames(result))
+  expect_true(all(result$department %in% c("IT", "HR")))
+
+  # Clean up
+  cleanup_source(source)
+})
diff --git a/pkg-r/tests/testthat/test-sql-cleaning.R b/pkg-r/tests/testthat/test-sql-cleaning.R
new file mode 100644
index 00000000..a2f54f4b
--- /dev/null
+++ b/pkg-r/tests/testthat/test-sql-cleaning.R
@@ -0,0 +1,229 @@
+library(testthat)
+library(querychat)
+
+# Access the internal clean_sql function for testing
+clean_sql <- querychat:::clean_sql
+
+test_that("clean_sql handles comments correctly", {
+  # Inline comments
+  expect_equal(
+    clean_sql("SELECT * FROM table -- This is a comment"),
+    "SELECT * FROM table"
+  )
+
+  # Multiline comments
+  expect_equal(
+    clean_sql("SELECT * FROM /* this is a comment */ table"),
+    "SELECT * FROM table"
+  )
+
+  # Nested multiline comments are not supported: removal stops at the first */
+  expect_equal(
+    clean_sql("SELECT * FROM /* outer /* nested */ comment */ table"),
+    "SELECT * FROM comment */ table"
+  )
+
+  # Comment with asterisks
+  expect_equal(
+    clean_sql("SELECT * FROM table /* ** multiple ** asterisks ** */"),
+    "SELECT * FROM table"
+  )
+
+  # Comment at the beginning
+  expect_equal(
+    clean_sql("-- This is a comment\nSELECT * FROM table"),
+    "SELECT * FROM table"
+  )
+
+  # Comment only query
+  expect_null(clean_sql("-- just a comment"))
+  expect_null(clean_sql("/* just a comment */"))
+})
+
+test_that("clean_sql handles semicolons correctly", {
+  # Trailing semicolon
+  result <- clean_sql("SELECT * FROM table;")
+  expect_equal(
+    result,
+    "SELECT * FROM table"
+  )
+
+  # Multiple trailing semicolons (should be treated as multiple statements with empty ones)
+  expect_warning(
+    result <- clean_sql("SELECT * FROM table;;;"),
+    "Multiple SQL statements detected"
+  )
+  expect_equal(result, "SELECT * FROM table")
+
+  # Multiple statements (should keep only the first one and warn)
+  expect_warning(
+    result <- clean_sql("SELECT * FROM table1; SELECT * FROM table2;"),
+    "Multiple SQL statements detected"
+  )
+  expect_equal(result, "SELECT * FROM table1")
+
+  # Warning message includes information about the SQL statement
+  # No need to test this explicitly again since we've already verified the warning is thrown
+  # and the warning format is consistent
+
+  # Semicolon in quoted string (should be preserved)
+  sql_with_quoted_semicolon <- clean_sql(
+    "SELECT * FROM table WHERE col = 'text;with;semicolons'"
+  )
+  expect_match(sql_with_quoted_semicolon, "text;with;semicolons", fixed = TRUE)
+
+  # Complex case with comments and quoted semicolons in a single statement
+  complex_sql <- "
+    /* Comment at start */
+    SELECT *
+    FROM table1
+    WHERE col = 'text;with;semicolons' -- inline comment
+    AND col2 > 10
+  "
+
+  cleaned_sql <- clean_sql(complex_sql)
+  expect_match(cleaned_sql, "text;with;semicolons", fixed = TRUE)
+})
+
+test_that("clean_sql detects and handles unbalanced quotes", {
+  # Unbalanced single quotes
+  expect_warning(
+    result <- clean_sql("SELECT * FROM table WHERE col = 'unbalanced"),
+    "unbalanced single quotes"
+  )
+  expect_equal(result, "SELECT * FROM table WHERE col = 'unbalanced'")
+
+  # Unbalanced double quotes
+  expect_warning(
+    result <- clean_sql('SELECT * FROM table WHERE col = "unbalanced'),
+    "unbalanced double quotes"
+  )
+  expect_equal(result, 'SELECT * FROM table WHERE col = "unbalanced"')
+
+  # Balanced quotes should not trigger warnings
+  expect_silent(
+    clean_sql("SELECT * FROM table WHERE col = 'balanced'")
+  )
+})
+
+test_that("clean_sql detects and handles unbalanced parentheses", {
+  # Unbalanced open parentheses
+  expect_warning(
+    result <- clean_sql("SELECT * FROM table WHERE (col = 10 AND (col2 = 20"),
+    "unbalanced parentheses"
+  )
+  expect_equal(result, "SELECT * FROM table WHERE (col = 10 AND (col2 = 20))")
+
+  # Unbalanced close parentheses
+  expect_warning(
+    clean_sql("SELECT * FROM table WHERE (col = 10))"),
+    "unbalanced parentheses"
+  )
+
+  # Balanced parentheses should not trigger warnings
+  expect_silent(
+    clean_sql("SELECT * FROM table WHERE (col = 10)")
+  )
+})
+
+test_that("clean_sql handles GO statements", {
+  # Simple GO statement
+  expect_equal(
+    clean_sql("SELECT * FROM table\nGO"),
+    "SELECT * FROM table"
+  )
+
+  # GO with subsequent commands (should be removed)
+  sql_with_go <- "
+    SELECT * FROM table1
+    GO
+    SELECT * FROM table2
+  "
+  cleaned_sql <- clean_sql(sql_with_go)
+  expect_false(grepl("GO", cleaned_sql, fixed = TRUE))
+  # The second SELECT might still be there, but the GO is gone
+})
+
+test_that("clean_sql filters non-standard characters", {
+  # Non-ASCII characters
+  expect_equal(
+    clean_sql("SELECT * FROM table WHERE col = 'special\u2013char'"),
+    "SELECT * FROM table WHERE col = 'specialchar'"
+  )
+
+  # Control characters: a tab is expected to be preserved
+  expect_equal(
+    clean_sql("SELECT * FROM\ttable"),
+    "SELECT * FROM\ttable"
+  )
+
+  # Non-printable characters
+  expect_equal(
+    clean_sql("SELECT * FROM table\x01name"),
+    "SELECT * FROM tablename"
+  )
+})
+
+test_that("clean_sql validates SELECT statements", {
+  # Valid SELECT
+  expect_silent(
+    clean_sql("SELECT * FROM table", enforce_select = TRUE)
+  )
+
+  # Not a SELECT statement
+  expect_error(
+    clean_sql("UPDATE table SET col = 10", enforce_select = TRUE),
+    "not appear to start with SELECT"
+  )
+
+  # Non-SELECT but enforce_select = FALSE
+  expect_silent(
+    clean_sql("UPDATE table SET col = 10", enforce_select = FALSE)
+  )
+})
+
+test_that("clean_sql works with edge cases", {
+  # Empty query
+  expect_null(clean_sql(""))
+  expect_null(clean_sql(" "))
+
+  # Only comments
+  expect_null(clean_sql("-- comment only"))
+
+  # Query with just whitespace after cleaning
+  expect_null(clean_sql("/* everything is a comment */"))
+})
+
+test_that("clean_sql handles complex SQL queries", {
+  # Complex query with subqueries, joins, and functions
+  complex_query <- "
+    /* This tests a complex query */
+    SELECT
+      t1.col1,
+      t2.col2,
+      COALESCE(t1.col3, 'N/A') AS col3,
+      (SELECT COUNT(*) FROM table3 WHERE table3.id = t1.id) AS subquery_result
+    FROM
+      table1 t1
+    LEFT JOIN
+      table2 t2 ON t1.id = t2.id
+    WHERE
+      t1.col1 IN (
+        SELECT col1
+        FROM table4
+        WHERE active = 1
+      )
+    GROUP BY
+      t1.col1, t2.col2
+    HAVING
+      COUNT(*) > 0
+    ORDER BY
+      t1.col1 DESC
+  "
+
+  cleaned_sql <- clean_sql(complex_query)
+  expect_match(cleaned_sql, "LEFT JOIN", fixed = TRUE)
+  expect_match(cleaned_sql, "GROUP BY", fixed = TRUE)
+  expect_match(cleaned_sql, "HAVING", fixed = TRUE)
+  expect_match(cleaned_sql, "subquery_result", fixed = TRUE)
+})