diff --git a/.github/workflows/py-test.yml b/.github/workflows/py-test.yml index 4486a805a..2e5742330 100644 --- a/.github/workflows/py-test.yml +++ b/.github/workflows/py-test.yml @@ -37,8 +37,8 @@ jobs: - name: πŸ“¦ Install the project run: uv sync --python ${{matrix.config.python-version }} --all-extras --all-groups - # - name: πŸ§ͺ Check tests - # run: make py-check-tests + - name: πŸ§ͺ Check tests + run: make py-check-tests - name: πŸ“ Check types run: make py-check-types diff --git a/.gitignore b/.gitignore index 67ecb914f..7a68cf392 100644 --- a/.gitignore +++ b/.gitignore @@ -250,8 +250,13 @@ po/*~ # RStudio Connect folder rsconnect/ +python-package/CLAUDE.md uv.lock _dev +# R ignores /.quarto/ +.Rprofile +renv/ +renv.lock diff --git a/Makefile b/Makefile index 0a93a448d..a82b5b618 100644 --- a/Makefile +++ b/Makefile @@ -123,12 +123,11 @@ py-check-tox: ## [py] Run python 3.9 - 3.12 checks with tox @echo "πŸ”„ Running tests and type checking with tox for Python 3.9--3.12" uv run tox run-parallel -# .PHONY: py-check-tests -# py-check-tests: ## [py] Run python tests -# @echo "" -# @echo "πŸ§ͺ Running tests with pytest" -# uv run playwright install -# uv run pytest +.PHONY: py-check-tests +py-check-tests: ## [py] Run python tests + @echo "" + @echo "πŸ§ͺ Running tests with pytest" + uv run pytest .PHONY: py-check-types py-check-types: ## [py] Run python type checks diff --git a/README.md b/README.md index 4a07aa781..8ede04f52 100644 --- a/README.md +++ b/README.md @@ -36,11 +36,11 @@ querychat does not have direct access to the raw data; it can _only_ read or fil - **Transparency:** querychat always displays the SQL to the user, so it can be vetted instead of blindly trusted. - **Reproducibility:** The SQL query can be easily copied and reused. -Currently, querychat uses DuckDB for its SQL engine. It's extremely fast and has a surprising number of statistical functions. +Currently, querychat uses DuckDB for its SQL engine when working with data frames. 
For database sources, it uses the native SQL dialect of the connected database. ## Language-specific Documentation For detailed information on how to use querychat in your preferred language, see the language-specific READMEs: - [R Documentation](pkg-r/README.md) -- [Python Documentation](pkg-py/README.md) +- [Python Documentation](pkg-py/README.md) \ No newline at end of file diff --git a/pkg-py/examples/app.py b/pkg-py/examples/app.py index b14777909..5870d21cc 100644 --- a/pkg-py/examples/app.py +++ b/pkg-py/examples/app.py @@ -49,4 +49,4 @@ def data_table(): # Create Shiny app -app = App(app_ui, server) +app = App(app_ui, server) \ No newline at end of file diff --git a/pkg-py/src/querychat/__init__.py b/pkg-py/src/querychat/__init__.py index 660a202f6..71dce11c4 100644 --- a/pkg-py/src/querychat/__init__.py +++ b/pkg-py/src/querychat/__init__.py @@ -1,5 +1,13 @@ -from querychat.querychat import init, sidebar, system_prompt -from querychat.querychat import mod_server as server -from querychat.querychat import mod_ui as ui +from querychat.querychat import ( + init, + sidebar, + system_prompt, +) +from querychat.querychat import ( + mod_server as server, +) +from querychat.querychat import ( + mod_ui as ui, +) __all__ = ["init", "server", "sidebar", "system_prompt", "ui"] diff --git a/pkg-py/src/querychat/datasource.py b/pkg-py/src/querychat/datasource.py index 24c3a30fa..3261e0a93 100644 --- a/pkg-py/src/querychat/datasource.py +++ b/pkg-py/src/querychat/datasource.py @@ -178,7 +178,7 @@ def __init__(self, engine: Engine, table_name: str): if not inspector.has_table(table_name): raise ValueError(f"Table '{table_name}' not found in database") - def get_schema(self, *, categorical_threshold: int) -> str: + def get_schema(self, *, categorical_threshold: int) -> str: # noqa: PLR0912 """ Generate schema information from database table. 
@@ -191,12 +191,15 @@ def get_schema(self, *, categorical_threshold: int) -> str: schema = [f"Table: {self._table_name}", "Columns:"] + # Build a single query to get all column statistics + select_parts = [] + numeric_columns = [] + text_columns = [] + for col in columns: - # Get SQL type name - sql_type = self._get_sql_type_name(col["type"]) - column_info = [f"- {col['name']} ({sql_type})"] + col_name = col["name"] - # For numeric columns, try to get range + # Check if column is numeric if isinstance( col["type"], ( @@ -208,44 +211,103 @@ def get_schema(self, *, categorical_threshold: int) -> str: sqltypes.DateTime, sqltypes.BigInteger, sqltypes.SmallInteger, - # sqltypes.Interval, ), ): - try: - query = text( - f"SELECT MIN({col['name']}), MAX({col['name']}) FROM {self._table_name}", - ) - with self._get_connection() as conn: - result = conn.execute(query).fetchone() - if result and result[0] is not None and result[1] is not None: - column_info.append(f" Range: {result[0]} to {result[1]}") - except Exception: # noqa: S110 - pass # Silently skip range info if query fails - - # For string/text columns, check if categorical + numeric_columns.append(col_name) + select_parts.extend( + [ + f"MIN({col_name}) as {col_name}__min", + f"MAX({col_name}) as {col_name}__max", + ], + ) + + # Check if column is text/string elif isinstance( col["type"], (sqltypes.String, sqltypes.Text, sqltypes.Enum), ): - try: - count_query = text( - f"SELECT COUNT(DISTINCT {col['name']}) FROM {self._table_name}", - ) + text_columns.append(col_name) + select_parts.append( + f"COUNT(DISTINCT {col_name}) as {col_name}__distinct_count", + ) + + # Execute single query to get all statistics + column_stats = {} + if select_parts: + try: + stats_query = text( + f"SELECT {', '.join(select_parts)} FROM {self._table_name}", + ) + with self._get_connection() as conn: + result = conn.execute(stats_query).fetchone() + if result: + # Convert result to dict for easier access + column_stats = 
dict(zip(result._fields, result)) + except Exception: # noqa: S110 + pass # Fall back to no statistics if query fails + + # Get categorical values for text columns that are below threshold + categorical_values = {} + text_cols_to_query = [] + for col_name in text_columns: + distinct_count_key = f"{col_name}__distinct_count" + if ( + distinct_count_key in column_stats + and column_stats[distinct_count_key] + and column_stats[distinct_count_key] <= categorical_threshold + ): + text_cols_to_query.append(col_name) + + # Get categorical values in a single query if needed + if text_cols_to_query: + try: + # Build UNION query for all categorical columns + union_parts = [ + f"SELECT '{col_name}' as column_name, {col_name} as value " + f"FROM {self._table_name} WHERE {col_name} IS NOT NULL " + f"GROUP BY {col_name}" + for col_name in text_cols_to_query + ] + + if union_parts: + categorical_query = text(" UNION ALL ".join(union_parts)) with self._get_connection() as conn: - distinct_count = conn.execute(count_query).scalar() - if distinct_count and distinct_count <= categorical_threshold: - values_query = text( - f"SELECT DISTINCT {col['name']} FROM {self._table_name} " - f"WHERE {col['name']} IS NOT NULL", - ) - values = [ - str(row[0]) - for row in conn.execute(values_query).fetchall() - ] - values_str = ", ".join([f"'{v}'" for v in values]) - column_info.append(f" Categorical values: {values_str}") - except Exception: # noqa: S110 - pass # Silently skip categorical info if query fails + results = conn.execute(categorical_query).fetchall() + for row in results: + col_name, value = row + if col_name not in categorical_values: + categorical_values[col_name] = [] + categorical_values[col_name].append(str(value)) + except Exception: # noqa: S110 + pass # Skip categorical values if query fails + + # Build schema description using collected statistics + for col in columns: + col_name = col["name"] + sql_type = self._get_sql_type_name(col["type"]) + column_info = [f"- {col_name} 
({sql_type})"] + + # Add range info for numeric columns + if col_name in numeric_columns: + min_key = f"{col_name}__min" + max_key = f"{col_name}__max" + if ( + min_key in column_stats + and max_key in column_stats + and column_stats[min_key] is not None + and column_stats[max_key] is not None + ): + column_info.append( + f" Range: {column_stats[min_key]} to {column_stats[max_key]}", + ) + + # Add categorical values for text columns + elif col_name in categorical_values: + values = categorical_values[col_name] + # Remove duplicates and sort + unique_values = sorted(set(values)) + values_str = ", ".join([f"'{v}'" for v in unique_values]) + column_info.append(f" Categorical values: {values_str}") schema.extend(column_info) diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index a999f0c7c..7bca2079d 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -118,19 +118,12 @@ def __getitem__(self, key: str) -> Any: backwards compatibility only; new code should use the attributes directly instead. """ - if key == "chat": # noqa: SIM116 - return self.chat - elif key == "sql": - return self.sql - elif key == "title": - return self.title - elif key == "df": - return self.df - - raise KeyError( - f"`QueryChat` does not have a key `'{key}'`. 
" - "Use the attributes `chat`, `sql`, `title`, or `df` instead.", - ) + return { + "chat": self.chat, + "sql": self.sql, + "title": self.title, + "df": self.df, + }.get(key) def system_prompt( diff --git a/pkg-py/tests/__init__.py b/pkg-py/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/pkg-py/tests/test_datasource.py b/pkg-py/tests/test_datasource.py new file mode 100644 index 000000000..734cc4c7d --- /dev/null +++ b/pkg-py/tests/test_datasource.py @@ -0,0 +1,219 @@ +import sqlite3 +import tempfile +from pathlib import Path + +import pytest +from sqlalchemy import create_engine, text +from src.querychat.datasource import SQLAlchemySource + + +@pytest.fixture +def test_db_engine(): + """Create a temporary SQLite database with test data.""" + # Create temporary database file + temp_db = tempfile.NamedTemporaryFile(delete=False, suffix=".db") # noqa: SIM115 + temp_db.close() + + # Connect and create test table with various data types + conn = sqlite3.connect(temp_db.name) + cursor = conn.cursor() + + # Create table with different column types + cursor.execute(""" + CREATE TABLE test_table ( + id INTEGER PRIMARY KEY, + name TEXT, + age INTEGER, + salary REAL, + is_active BOOLEAN, + join_date DATE, + category TEXT, + score NUMERIC, + description TEXT + ) + """) + + # Insert test data + test_data = [ + (1, "Alice", 30, 75000.50, True, "2023-01-15", "A", 95.5, "Senior developer"), + (2, "Bob", 25, 60000.00, True, "2023-03-20", "B", 87.2, "Junior developer"), + (3, "Charlie", 35, 85000.75, False, "2022-12-01", "A", 92.1, "Team lead"), + ( + 4, + "Diana", + 28, + 70000.25, + True, + "2023-05-10", + "C", + 89.8, + "Mid-level developer", + ), + (5, "Eve", 32, 80000.00, True, "2023-02-28", "A", 91.3, "Senior developer"), + (6, "Frank", 26, 62000.50, False, "2023-04-15", "B", 85.7, "Junior developer"), + (7, "Grace", 29, 72000.75, True, "2023-01-30", "A", 93.4, "Developer"), + (8, "Henry", 31, 78000.25, True, "2023-03-05", "C", 88.9, "Senior 
developer"), + ] + + cursor.executemany( + """ + INSERT INTO test_table + (id, name, age, salary, is_active, join_date, category, score, description) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + """, + test_data, + ) + + conn.commit() + conn.close() + + # Create SQLAlchemy engine + engine = create_engine(f"sqlite:///{temp_db.name}") + + yield engine + + # Cleanup + Path(temp_db.name).unlink() + + +def test_get_schema_numeric_ranges(test_db_engine): + """Test that numeric columns include min/max ranges.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Check that numeric columns have range information + assert "- id (INTEGER)" in schema + assert "Range: 1 to 8" in schema + + assert "- age (INTEGER)" in schema + assert "Range: 25 to 35" in schema + + assert "- salary (FLOAT)" in schema + assert "Range: 60000.0 to 85000.75" in schema + + assert "- score (NUMERIC)" in schema + assert "Range: 85.7 to 95.5" in schema + + +def test_get_schema_categorical_values(test_db_engine): + """Test that text columns with few unique values show categorical values.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Category column should be treated as categorical (3 unique values: A, B, C) + assert "- category (TEXT)" in schema + assert "Categorical values:" in schema + assert "'A'" in schema and "'B'" in schema and "'C'" in schema # noqa: PT018 + + +def test_get_schema_non_categorical_text(test_db_engine): + """Test that text columns with many unique values don't show categorical values.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=3) + + # Name and description columns should not be categorical (8 and 6 unique values respectively) + lines = schema.split("\n") + name_line_idx = next(i for i, line in enumerate(lines) if "- name (TEXT)" in line) + description_line_idx = next( + i for i, 
line in enumerate(lines) if "- description (TEXT)" in line + ) + + # Check that the next line after name column doesn't contain categorical values + if name_line_idx + 1 < len(lines): + assert "Categorical values:" not in lines[name_line_idx + 1] + + # Check that the next line after description column doesn't contain categorical values + if description_line_idx + 1 < len(lines): + assert "Categorical values:" not in lines[description_line_idx + 1] + + +def test_get_schema_different_thresholds(test_db_engine): + """Test that categorical_threshold parameter works correctly.""" + source = SQLAlchemySource(test_db_engine, "test_table") + + # With threshold 2, only category column (3 unique) should not be categorical + schema_low = source.get_schema(categorical_threshold=2) + assert "- category (TEXT)" in schema_low + assert "'A'" not in schema_low # Should not show categorical values + + # With threshold 5, category column should be categorical + schema_high = source.get_schema(categorical_threshold=5) + assert "- category (TEXT)" in schema_high + assert "'A'" in schema_high # Should show categorical values + + +def test_get_schema_table_structure(test_db_engine): + """Test the overall structure of the schema output.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + lines = schema.split("\n") + + # Check header + assert lines[0] == "Table: test_table" + assert lines[1] == "Columns:" + + # Check that all columns are present + expected_columns = [ + "id", + "name", + "age", + "salary", + "is_active", + "join_date", + "category", + "score", + "description", + ] + for col in expected_columns: + assert any(f"- {col} (" in line for line in lines), ( + f"Column {col} not found in schema" + ) + + +def test_get_schema_empty_result_handling(test_db_engine): + """Test handling when statistics queries return empty results.""" + # Create empty table + conn = sqlite3.connect(":memory:") + cursor = conn.cursor() + 
cursor.execute("CREATE TABLE empty_table (id INTEGER, name TEXT)") + conn.commit() + + engine = create_engine("sqlite:///:memory:") + # Recreate table in the new engine + with engine.connect() as connection: + connection.execute(text("CREATE TABLE empty_table (id INTEGER, name TEXT)")) + connection.commit() + + source = SQLAlchemySource(engine, "empty_table") + schema = source.get_schema(categorical_threshold=5) + + # Should still work but without range/categorical info + assert "Table: empty_table" in schema + assert "- id (INTEGER)" in schema + assert "- name (TEXT)" in schema + # Should not have range or categorical information + assert "Range:" not in schema + assert "Categorical values:" not in schema + + +def test_get_schema_boolean_and_date_types(test_db_engine): + """Test handling of boolean and date column types.""" + source = SQLAlchemySource(test_db_engine, "test_table") + schema = source.get_schema(categorical_threshold=5) + + # Boolean column should show range + assert "- is_active (BOOLEAN)" in schema + # SQLite stores booleans as integers, so should show 0 to 1 range + + # Date column should show range + assert "- join_date (DATE)" in schema + assert "Range:" in schema + + +def test_invalid_table_name(): + """Test that invalid table name raises appropriate error.""" + engine = create_engine("sqlite:///:memory:") + + with pytest.raises(ValueError, match="Table 'nonexistent' not found in database"): + SQLAlchemySource(engine, "nonexistent") diff --git a/pkg-r/DESCRIPTION b/pkg-r/DESCRIPTION index d602bde47..cacfb1273 100644 --- a/pkg-r/DESCRIPTION +++ b/pkg-r/DESCRIPTION @@ -19,15 +19,19 @@ Imports: DBI, duckdb, ellmer, - glue, htmltools, - jsonlite, purrr, rlang, shiny, shinychat (>= 0.2.0), whisker, xtable +Suggests: + DT, + RSQLite, + shinytest2, + testthat (>= 3.0.0) +Config/testthat/edition: 3 Encoding: UTF-8 Roxygen: list(markdown = TRUE) RoxygenNote: 7.3.2 diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE index 077a6ed08..8c75247d0 100644 --- 
a/pkg-r/NAMESPACE +++ b/pkg-r/NAMESPACE @@ -1,8 +1,22 @@ # Generated by roxygen2: do not edit by hand -export(df_to_schema) +S3method(cleanup_source,dbi_source) +S3method(create_system_prompt,querychat_data_source) +S3method(execute_query,dbi_source) +S3method(get_db_type,data_frame_source) +S3method(get_db_type,dbi_source) +S3method(get_schema,dbi_source) +S3method(querychat_data_source,DBIConnection) +S3method(querychat_data_source,data.frame) +S3method(test_query,dbi_source) +export(cleanup_source) +export(create_system_prompt) +export(execute_query) +export(get_db_type) +export(get_schema) +export(querychat_data_source) export(querychat_init) export(querychat_server) export(querychat_sidebar) -export(querychat_system_prompt) export(querychat_ui) +export(test_query) diff --git a/pkg-r/R/data_source.R b/pkg-r/R/data_source.R new file mode 100644 index 000000000..9a1282ef3 --- /dev/null +++ b/pkg-r/R/data_source.R @@ -0,0 +1,448 @@ +#' Create a data source for querychat +#' +#' Generic function to create a data source for querychat. This function +#' dispatches to appropriate methods based on input. +#' +#' @param x A data frame or DBI connection +#' @param table_name The name to use for the table in the data source. Can be: +#' - A character string (e.g., "table_name") +#' - Or, for tables contained within catalogs or schemas, a [DBI::Id()] object (e.g., `DBI::Id(schema = "schema_name", table = "table_name")`) +#' @param categorical_threshold For text columns, the maximum number of unique values to consider as a categorical variable +#' @param ... Additional arguments passed to specific methods +#' @return A querychat_data_source object +#' @export +querychat_data_source <- function(x, ...) { + UseMethod("querychat_data_source") +} + +#' @export +#' @rdname querychat_data_source +querychat_data_source.data.frame <- function( + x, + table_name = NULL, + categorical_threshold = 20, + ... 
+) { + if (is.null(table_name)) { + # Infer table name from dataframe name, if not already added + table_name <- deparse(substitute(x)) + if (is.null(table_name) || table_name == "NULL" || table_name == "x") { + rlang::abort( + "Unable to infer table name. Please specify `table_name` argument explicitly." + ) + } + } + + is_table_name_ok <- is.character(table_name) && + length(table_name) == 1 && + grepl("^[a-zA-Z][a-zA-Z0-9_]*$", table_name, perl = TRUE) + if (!is_table_name_ok) { + rlang::abort( + "`table_name` argument must be a string containing a valid table name." + ) + } + + # Create duckdb connection + conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") + duckdb::duckdb_register(conn, table_name, x, experimental = FALSE) + + structure( + list( + conn = conn, + table_name = table_name, + categorical_threshold = categorical_threshold + ), + class = c("data_frame_source", "dbi_source", "querychat_data_source") + ) +} + +#' @export +#' @rdname querychat_data_source +querychat_data_source.DBIConnection <- function( + x, + table_name, + categorical_threshold = 20, + ... +) { + # Handle different types of table_name inputs + if (inherits(table_name, "Id")) { + # DBI::Id object - keep as is + } else if (is.character(table_name) && length(table_name) == 1) { + # Character string - keep as is + } else { + # Invalid input + rlang::abort( + "`table_name` must be a single character string or a DBI::Id object" + ) + } + + # Check if table exists + if (!DBI::dbExistsTable(x, table_name)) { + rlang::abort(paste0( + "Table ", + DBI::dbQuoteIdentifier(x, table_name), + " not found in database. 
If you're using a table in a catalog or schema, pass a DBI::Id", + " object to `table_name`" + )) + } + + structure( + list( + conn = x, + table_name = table_name, + categorical_threshold = categorical_threshold + ), + class = c("dbi_source", "querychat_data_source") + ) +} + +#' Execute a SQL query on a data source +#' +#' @param source A querychat_data_source object +#' @param query SQL query string +#' @param ... Additional arguments passed to methods +#' @return Result of the query as a data frame +#' @export +execute_query <- function(source, query, ...) { + UseMethod("execute_query") +} + +#' @export +execute_query.dbi_source <- function(source, query, ...) { + if (is.null(query) || query == "") { + # For a null or empty query, default to returning the whole table (ie SELECT *) + query <- paste0( + "SELECT * FROM ", + DBI::dbQuoteIdentifier(source$conn, source$table_name) + ) + } + # Execute the query directly + DBI::dbGetQuery(source$conn, query) +} + +#' Test a SQL query on a data source. +#' +#' @param source A querychat_data_source object +#' @param query SQL query string +#' @param ... Additional arguments passed to methods +#' @return Result of the query, limited to one row of data. +#' @export +test_query <- function(source, query, ...) { + UseMethod("test_query") +} + +#' @export +test_query.dbi_source <- function(source, query, ...) { + rs <- DBI::dbSendQuery(source$conn, query) + df <- DBI::dbFetch(rs, n = 1) + DBI::dbClearResult(rs) + df +} + + +#' Get type information for a data source +#' +#' @param source A querychat_data_source object +#' @param ... Additional arguments passed to methods +#' @return A character string containing the type information +#' @export +get_db_type <- function(source, ...) { + UseMethod("get_db_type") +} + +#' @export +get_db_type.data_frame_source <- function(source, ...) { + # Local dataframes are always duckdb! + return("DuckDB") +} + +#' @export +get_db_type.dbi_source <- function(source, ...) 
{ + conn <- source$conn + conn_info <- DBI::dbGetInfo(conn) + # default to 'POSIX' if dbms name not found + dbms_name <- purrr::pluck(conn_info, "dbms.name", .default = "POSIX") + # Special handling for known database types + if (inherits(conn, "SQLiteConnection")) { + return("SQLite") + } + # remove ' SQL', if exists (SQL is already in the prompt) + return(gsub(" SQL", "", dbms_name)) +} + + +#' Create a system prompt for the data source +#' +#' @param source A querychat_data_source object +#' @param data_description Optional description of the data +#' @param extra_instructions Optional additional instructions +#' @param ... Additional arguments passed to methods +#' @return A string with the system prompt +#' @export +create_system_prompt <- function( + source, + data_description = NULL, + extra_instructions = NULL, + ... +) { + UseMethod("create_system_prompt") +} + +#' @export +create_system_prompt.querychat_data_source <- function( + source, + data_description = NULL, + extra_instructions = NULL, + ... +) { + if (!is.null(data_description)) { + data_description <- paste(data_description, collapse = "\n") + } + if (!is.null(extra_instructions)) { + extra_instructions <- paste(extra_instructions, collapse = "\n") + } + + # Read the prompt file + prompt_path <- system.file("prompt", "prompt.md", package = "querychat") + prompt_content <- readLines(prompt_path, warn = FALSE) + prompt_text <- paste(prompt_content, collapse = "\n") + + # Get schema for the data source + schema <- get_schema(source) + + # Examine the data source and get the type for the prompt + db_type <- get_db_type(source) + + whisker::whisker.render( + prompt_text, + list( + schema = schema, + data_description = data_description, + extra_instructions = extra_instructions, + db_type = db_type + ) + ) +} + +#' Clean up a data source (close connections, etc.) +#' +#' @param source A querychat_data_source object +#' @param ... 
Additional arguments passed to methods +#' @return NULL (invisibly) +#' @export +cleanup_source <- function(source, ...) { + UseMethod("cleanup_source") +} + +#' @export +cleanup_source.dbi_source <- function(source, ...) { + if (!is.null(source$conn) && DBI::dbIsValid(source$conn)) { + DBI::dbDisconnect(source$conn) + } + invisible(NULL) +} + + +#' Get schema for a data source +#' +#' @param source A querychat_data_source object +#' @param ... Additional arguments passed to methods +#' @return A character string describing the schema +#' @export +get_schema <- function(source, ...) { + UseMethod("get_schema") +} + +#' @export +get_schema.dbi_source <- function(source, ...) { + conn <- source$conn + table_name <- source$table_name + categorical_threshold <- source$categorical_threshold + + # Get column information + columns <- DBI::dbListFields(conn, table_name) + + schema_lines <- c( + paste("Table:", DBI::dbQuoteIdentifier(conn, table_name)), + "Columns:" + ) + + # Build single query to get column statistics + select_parts <- character(0) + numeric_columns <- character(0) + text_columns <- character(0) + + # Get sample of data to determine types + sample_query <- paste0( + "SELECT * FROM ", + DBI::dbQuoteIdentifier(conn, table_name), + " LIMIT 1" + ) + sample_data <- DBI::dbGetQuery(conn, sample_query) + + for (col in columns) { + col_class <- class(sample_data[[col]])[1] + + if ( + col_class %in% + c("integer", "numeric", "double", "Date", "POSIXct", "POSIXt") + ) { + numeric_columns <- c(numeric_columns, col) + select_parts <- c( + select_parts, + paste0( + "MIN(", + DBI::dbQuoteIdentifier(conn, col), + ") as ", + DBI::dbQuoteIdentifier(conn, paste0(col, '__min')) + ), + paste0( + "MAX(", + DBI::dbQuoteIdentifier(conn, col), + ") as ", + DBI::dbQuoteIdentifier(conn, paste0(col, '__max')) + ) + ) + } else if (col_class %in% c("character", "factor")) { + text_columns <- c(text_columns, col) + select_parts <- c( + select_parts, + paste0( + "COUNT(DISTINCT ", + 
DBI::dbQuoteIdentifier(conn, col), + ") as ", + DBI::dbQuoteIdentifier(conn, paste0(col, '__distinct_count')) + ) + ) + } + } + + # Execute statistics query + column_stats <- list() + if (length(select_parts) > 0) { + tryCatch( + { + stats_query <- paste0( + "SELECT ", + paste0(select_parts, collapse = ", "), + " FROM ", + DBI::dbQuoteIdentifier(conn, table_name) + ) + result <- DBI::dbGetQuery(conn, stats_query) + if (nrow(result) > 0) { + column_stats <- as.list(result[1, ]) + } + }, + error = function(e) { + # Fall back to no statistics if query fails + } + ) + } + + # Get categorical values for text columns below threshold + categorical_values <- list() + text_cols_to_query <- character(0) + + for (col_name in text_columns) { + distinct_count_key <- paste0(col_name, "__distinct_count") + if ( + distinct_count_key %in% + names(column_stats) && + !is.na(column_stats[[distinct_count_key]]) && + column_stats[[distinct_count_key]] <= categorical_threshold + ) { + text_cols_to_query <- c(text_cols_to_query, col_name) + } + } + + # Remove duplicates + text_cols_to_query <- unique(text_cols_to_query) + + # Get categorical values + if (length(text_cols_to_query) > 0) { + for (col_name in text_cols_to_query) { + tryCatch( + { + cat_query <- paste0( + "SELECT DISTINCT ", + DBI::dbQuoteIdentifier(conn, col_name), + " FROM ", + DBI::dbQuoteIdentifier(conn, table_name), + " WHERE ", + DBI::dbQuoteIdentifier(conn, col_name), + " IS NOT NULL ORDER BY ", + DBI::dbQuoteIdentifier(conn, col_name) + ) + result <- DBI::dbGetQuery(conn, cat_query) + if (nrow(result) > 0) { + categorical_values[[col_name]] <- result[[1]] + } + }, + error = function(e) { + # Skip categorical values if query fails + } + ) + } + } + + # Build schema description + for (col in columns) { + col_class <- class(sample_data[[col]])[1] + sql_type <- r_class_to_sql_type(col_class) + + column_info <- paste0("- ", col, " (", sql_type, ")") + + # Add range info for numeric columns + if (col %in% numeric_columns) { 
+ min_key <- paste0(col, "__min") + max_key <- paste0(col, "__max") + if ( + min_key %in% + names(column_stats) && + max_key %in% names(column_stats) && + !is.na(column_stats[[min_key]]) && + !is.na(column_stats[[max_key]]) + ) { + range_info <- paste0( + " Range: ", + column_stats[[min_key]], + " to ", + column_stats[[max_key]] + ) + column_info <- paste(column_info, range_info, sep = "\n") + } + } + + # Add categorical values for text columns + if (col %in% names(categorical_values)) { + values <- categorical_values[[col]] + if (length(values) > 0) { + values_str <- paste0("'", values, "'", collapse = ", ") + cat_info <- paste0(" Categorical values: ", values_str) + column_info <- paste(column_info, cat_info, sep = "\n") + } + } + + schema_lines <- c(schema_lines, column_info) + } + + paste(schema_lines, collapse = "\n") +} + + +# Helper function to map R classes to SQL types +r_class_to_sql_type <- function(r_class) { + switch( + r_class, + "integer" = "INTEGER", + "numeric" = "FLOAT", + "double" = "FLOAT", + "logical" = "BOOLEAN", + "Date" = "DATE", + "POSIXct" = "TIMESTAMP", + "POSIXt" = "TIMESTAMP", + "character" = "TEXT", + "factor" = "TEXT", + "TEXT" # default + ) +} diff --git a/pkg-r/R/prompt.R b/pkg-r/R/prompt.R deleted file mode 100644 index a8e9deb96..000000000 --- a/pkg-r/R/prompt.R +++ /dev/null @@ -1,137 +0,0 @@ -#' Create a system prompt for the chat model -#' -#' This function generates a system prompt for the chat model based on a data frame's -#' schema and optional additional context and instructions. -#' -#' @param df A data frame to generate schema information from. -#' @param table_name A string containing the name of the table in SQL queries. -#' @param data_description Optional string or existing file path. The contents -#' should be in plain text or Markdown format, containing a description of the -#' data frame or any additional context that might be helpful in understanding -#' the data. 
This will be included in the system prompt for the chat model. -#' @param extra_instructions Optional string or existing file path. The contents -#' should be in plain text or Markdown format, containing any additional -#' instructions for the chat model. These will be appended at the end of the -#' system prompt. -#' @param prompt_template Optional string or existing file path. If `NULL`, the -#' default prompt file in the package will be used. The contents should -#' contain a whisker template for the system prompt, with placeholders for -#' `{{schema}}`, `{{data_description}}`, and `{{extra_instructions}}`. -#' @param categorical_threshold The maximum number of unique values for a text -#' column to be considered categorical. -#' @param ... Ignored. Used to allow for future parameters. -#' -#' @return A string containing the system prompt for the chat model. -#' -#' @export -querychat_system_prompt <- function( - df, - table_name, - ..., - data_description = NULL, - extra_instructions = NULL, - prompt_template = NULL, - categorical_threshold = 10 -) { - rlang::check_dots_empty() - - schema <- df_to_schema(df, table_name, categorical_threshold) - - data_description <- read_path_or_string(data_description, "data_description") - extra_instructions <- read_path_or_string( - extra_instructions, - "extra_instructions" - ) - if (is.null(prompt_template)) { - prompt_template <- system.file("prompt", "prompt.md", package = "querychat") - } - prompt_text <- read_path_or_string(prompt_template, "prompt_template") - - processed_template <- - whisker::whisker.render( - prompt_text, - list( - schema = schema, - data_description = data_description, - extra_instructions = extra_instructions - ) - ) - - attr(processed_template, "table_name") <- table_name - - processed_template -} - -read_path_or_string <- function(x, name) { - if (is.null(x)) { - return(NULL) - } - if (!is.character(x)) { - stop(sprintf("`%s=` must be a string or a path to a file.", name)) - } - if 
(file.exists(x)) { - x <- readLines(x, warn = FALSE) - } - return(paste(x, collapse = "\n")) -} - - -#' Generate a schema description from a data frame -#' -#' This function generates a schema description for a data frame, including -#' the column names, their types, and additional information such as ranges for -#' numeric columns and unique values for text columns. -#' -#' @param df A data frame to generate schema information from. -#' @param table_name A string containing the name of the table in SQL queries. -#' @param categorical_threshold The maximum number of unique values for a text column to be considered categorical. -#' -#' @return A string containing the schema description for the data frame. -#' The schema includes the table name, column names, their types, and additional -#' information such as ranges for numeric columns and unique values for text columns. -#' @export -df_to_schema <- function( - df, - table_name = deparse(substitute(df)), - categorical_threshold = 10 -) { - schema <- c(paste("Table:", table_name), "Columns:") - - column_info <- lapply(names(df), function(column) { - # Map R classes to SQL-like types - sql_type <- if (is.integer(df[[column]])) { - "INTEGER" - } else if (is.numeric(df[[column]])) { - "FLOAT" - } else if (is.logical(df[[column]])) { - "BOOLEAN" - } else if (inherits(df[[column]], "POSIXt")) { - "DATETIME" - } else { - "TEXT" - } - - info <- paste0("- ", column, " (", sql_type, ")") - - # For TEXT columns, check if they're categorical - if (sql_type == "TEXT") { - unique_values <- length(unique(df[[column]])) - if (unique_values <= categorical_threshold) { - categories <- unique(df[[column]]) - categories_str <- paste0("'", categories, "'", collapse = ", ") - info <- c(info, paste0(" Categorical values: ", categories_str)) - } - } else if (sql_type %in% c("INTEGER", "FLOAT", "DATETIME")) { - rng <- range(df[[column]], na.rm = TRUE) - if (all(is.na(rng))) { - info <- c(info, " Range: NULL to NULL") - } else { - info <- 
c(info, paste0(" Range: ", rng[1], " to ", rng[2])) - } - } - return(info) - }) - - schema <- c(schema, unlist(column_info)) - return(paste(schema, collapse = "\n")) -} diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index 27601e94a..7a1c76f63 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -3,85 +3,79 @@ #' This will perform one-time initialization that can then be shared by all #' Shiny sessions in the R process. #' -#' @param df A data frame. -#' @param table_name A string containing a valid table name for the data frame, -#' that will appear in SQL queries. Ensure that it begins with a letter, and -#' contains only letters, numbers, and underscores. By default, querychat will -#' try to infer a table name using the name of the `df` argument. +#' @param data_source A querychat_data_source object created by `querychat_data_source()`. +#' To create a data source: +#' - For data frame: `querychat_data_source(df, tbl_name = "my_table")` +#' - For database: `querychat_data_source(conn, "table_name")` #' @param greeting A string in Markdown format, containing the initial message #' to display to the user upon first loading the chatbot. If not provided, the #' LLM will be invoked at the start of the conversation to generate one. -#' @param ... Additional arguments passed to the `querychat_system_prompt()` -#' function, such as `categorical_threshold`. If a -#' `system_prompt` argument is provided, the `...` arguments will be silently -#' ignored. -#' @inheritParams querychat_system_prompt -#' @param system_prompt A string containing the system prompt for the chat model. -#' The default uses `querychat_system_prompt()` to generate a generic prompt, -#' which you can enhance via the `data_description` and `extra_instructions` -#' arguments. +#' @param data_description A string containing a data description for the chat model. We have found +#' that formatting the data description as a markdown bulleted list works best. 
+#' @param extra_instructions A string containing extra instructions for the chat model. #' @param create_chat_func A function that takes a system prompt and returns a #' chat object. The default uses `ellmer::chat_openai()`. +#' @param system_prompt A string containing the system prompt for the chat model. +#' The default generates a generic prompt, which you can enhance via the `data_description` and +#' `extra_instructions` arguments. +#' @param auto_close_data_source Should the data source connection be automatically +#' closed when the shiny app stops? Defaults to TRUE. +#' #' @returns An object that can be passed to `querychat_server()` as the #' `querychat_config` argument. By convention, this object should be named #' `querychat_config`. #' #' @export querychat_init <- function( - df, - ..., - table_name = deparse(substitute(df)), + data_source, greeting = NULL, data_description = NULL, extra_instructions = NULL, - prompt_template = NULL, - system_prompt = querychat_system_prompt( - df, - table_name, - # By default, pass through any params supplied to querychat_init() - ..., - data_description = data_description, - extra_instructions = extra_instructions, - prompt_template = prompt_template - ), - create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o") + create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), + system_prompt = NULL, + auto_close_data_source = TRUE ) { - is_table_name_ok <- is.character(table_name) && - length(table_name) == 1 && - grepl("^[a-zA-Z][a-zA-Z0-9_]*$", table_name, perl = TRUE) - if (!is_table_name_ok) { - if (missing(table_name)) { - rlang::abort( - "Unable to infer table name from `df` argument. Please specify `table_name` argument explicitly." - ) - } else { - rlang::abort( - "`table_name` argument must be a string containing a valid table name." 
- ) - } + force(create_chat_func) + + # If the user passes a data.frame to data_source, create a correct data source for them + if (inherits(data_source, "data.frame")) { + data_source <- querychat_data_source( + data_source, + table_name = deparse(substitute(data_source)) + ) } - force(df) - force(system_prompt) # Have default `...` params evaluated - force(create_chat_func) + # Check that data_source is a querychat_data_source object + if (!inherits(data_source, "querychat_data_source")) { + rlang::abort( + "`data_source` must be a querychat_data_source object. Use querychat_data_source() to create one." + ) + } - # TODO: Provide nicer looking errors here + if (auto_close_data_source) { + # Close the data source when the Shiny app stops (or, if some reason the + # querychat_init call is within a specific session, when the session ends) + shiny::onStop(function() { + message("Closing data source...") + cleanup_source(data_source) + }) + } + + # Generate system prompt if not provided + if (is.null(system_prompt)) { + system_prompt <- create_system_prompt( + data_source, + data_description = data_description, + extra_instructions = extra_instructions + ) + } + + # Validate system prompt and create_chat_func stopifnot( - "df must be a data frame" = is.data.frame(df), - "table_name must be a string" = is.character(table_name), "system_prompt must be a string" = is.character(system_prompt), "create_chat_func must be a function" = is.function(create_chat_func) ) - if ("table_name" %in% names(attributes(system_prompt))) { - # If available, be sure to use the `table_name` argument to `querychat_init()` - # matches the one supplied to the system prompt - if (table_name != attr(system_prompt, "table_name")) { - rlang::abort( - "`querychat_init(table_name=)` must match system prompt `table_name` supplied to `querychat_system_prompt()`." 
- ) - } - } if (!is.null(greeting)) { greeting <- paste(collapse = "\n", greeting) } else { @@ -91,14 +85,9 @@ querychat_init <- function( )) } - conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") - duckdb::duckdb_register(conn, table_name, df, experimental = FALSE) - shiny::onStop(function() DBI::dbDisconnect(conn)) - structure( list( - df = df, - conn = conn, + data_source = data_source, system_prompt = system_prompt, greeting = greeting, create_chat_func = create_chat_func @@ -154,19 +143,15 @@ querychat_ui <- function(id) { #' #' - `sql`: A reactive that returns the current SQL query. #' - `title`: A reactive that returns the current title. -#' - `df`: A reactive that returns the data frame, filtered and sorted by the -#' current SQL query. +#' - `df`: A reactive that returns the filtered data as a data.frame. #' - `chat`: The [ellmer::Chat] object that powers the chat interface. #' -#' By convention, this object should be named `querychat_config`. -#' #' @export querychat_server <- function(id, querychat_config) { shiny::moduleServer(id, function(input, output, session) { # πŸ”„ Reactive state/computation -------------------------------------------- - df <- querychat_config[["df"]] - conn <- querychat_config[["conn"]] + data_source <- querychat_config[["data_source"]] system_prompt <- querychat_config[["system_prompt"]] greeting <- querychat_config[["greeting"]] create_chat_func <- querychat_config[["create_chat_func"]] @@ -174,11 +159,7 @@ querychat_server <- function(id, querychat_config) { current_title <- shiny::reactiveVal(NULL) current_query <- shiny::reactiveVal("") filtered_df <- shiny::reactive({ - if (current_query() == "") { - df - } else { - DBI::dbGetQuery(conn, current_query()) - } + execute_query(data_source, query = DBI::SQL(current_query())) }) append_output <- function(...) 
{ @@ -194,7 +175,7 @@ querychat_server <- function(id, querychat_config) { # Modifies the data presented in the data dashboard, based on the given SQL # query, and also updates the title. - # @param query A DuckDB SQL query; must be a SELECT statement. + # @param query A SQL query; must be a SELECT statement. # @param title A title to display at the top of the data dashboard, # summarizing the intent of the SQL query. update_dashboard <- function(query, title) { @@ -203,7 +184,7 @@ querychat_server <- function(id, querychat_config) { tryCatch( { # Try it to see if it errors; if so, the LLM will see the error - DBI::dbGetQuery(conn, query) + test_query(data_source, query) }, error = function(err) { append_output("> Error: ", conditionMessage(err), "\n\n") @@ -220,26 +201,22 @@ querychat_server <- function(id, querychat_config) { } # Perform a SQL query on the data, and return the results as JSON. - # @param query A DuckDB SQL query; must be a SELECT statement. - # @return The results of the query as a JSON string. + # @param query A SQL query; must be a SELECT statement. + # @return The results of the query as a data frame. query <- function(query) { # Do this before query, in case it errors - append_output("\n```sql\n", query, "\n```\n\n") + append_output("\n```sql\n", query, "\n```\n") tryCatch( { - df <- DBI::dbGetQuery(conn, query) + # Execute the query and return the results + execute_query(data_source, query) }, error = function(e) { append_output("> Error: ", conditionMessage(e), "\n\n") stop(e) } ) - - tbl_html <- df_to_html(df, maxrows = 5) - append_output(tbl_html, "\n\n") - - df |> jsonlite::toJSON(auto_unbox = TRUE) } # Preload the conversation with the system prompt. 
These are instructions for @@ -249,7 +226,7 @@ querychat_server <- function(id, querychat_config) { update_dashboard, "Modifies the data presented in the data dashboard, based on the given SQL query, and also updates the title.", query = ellmer::type_string( - "A DuckDB SQL query; must be a SELECT statement." + "A SQL query; must be a SELECT statement." ), title = ellmer::type_string( "A title to display at the top of the data dashboard, summarizing the intent of the SQL query." @@ -257,9 +234,9 @@ querychat_server <- function(id, querychat_config) { )) chat$register_tool(ellmer::tool( query, - "Perform a SQL query on the data, and return the results as JSON.", + "Perform a SQL query on the data, and return the results.", query = ellmer::type_string( - "A DuckDB SQL query; must be a SELECT statement." + "A SQL query; must be a SELECT statement." ) )) @@ -312,8 +289,12 @@ df_to_html <- function(df, maxrows = 5) { paste(collapse = "\n") if (nrow(df_short) != nrow(df)) { - rows_notice <- glue::glue( - "\n\n(Showing only the first {maxrows} rows out of {nrow(df)}.)\n" + rows_notice <- paste0( + "\n\n(Showing only the first ", + maxrows, + " rows out of ", + nrow(df), + ".)\n" ) } else { rows_notice <- "" diff --git a/pkg-r/README.md b/pkg-r/README.md index 03b5802ad..e73ce98ac 100644 --- a/pkg-r/README.md +++ b/pkg-r/README.md @@ -27,12 +27,14 @@ library(shiny) library(bslib) library(querychat) -# 1. Configure querychat. This is where you specify the dataset and can also -# override options like the greeting message, system prompt, model, etc. -querychat_config <- querychat_init(mtcars) +# 1. Create a data source for querychat +mtcars_source <- querychat_data_source(mtcars) + +# 2. Configure querychat with the data source +querychat_config <- querychat_init(mtcars_source) ui <- page_sidebar( - # 2. Use querychat_sidebar(id) in a bslib::page_sidebar. + # 3. Use querychat_sidebar(id) in a bslib::page_sidebar. 
# Alternatively, use querychat_ui(id) elsewhere if you don't want your # chat interface to live in a sidebar. sidebar = querychat_sidebar("chat"), @@ -41,11 +43,11 @@ ui <- page_sidebar( server <- function(input, output, session) { - # 3. Create a querychat object using the config from step 1. + # 4. Create a querychat object using the config from step 2. querychat <- querychat_server("chat", querychat_config) output$dt <- DT::renderDT({ - # 4. Use the filtered/sorted data frame anywhere you wish, via the + # 5. Use the filtered/sorted data frame anywhere you wish, via the # querychat$df() reactive. DT::datatable(querychat$df()) }) @@ -54,6 +56,29 @@ server <- function(input, output, session) { shinyApp(ui, server) ``` +## Using Database Sources + +In addition to data frames, querychat can connect to external databases via DBI: + +```r +library(shiny) +library(bslib) +library(querychat) +library(DBI) +library(RSQLite) + +# 1. Connect to a database +conn <- DBI::dbConnect(RSQLite::SQLite(), "path/to/database.db") + +# 2. Create a database data source for querychat +db_source <- querychat_data_source(conn, "table_name") + +# 3. Configure querychat with the database source +querychat_config <- querychat_init(db_source) + +# Then use querychat_config in your Shiny app as shown above +``` + ## How it works ### Powered by LLMs @@ -76,7 +101,7 @@ querychat does not have direct access to the raw data; it can _only_ read or fil - **Transparency:** querychat always displays the SQL to the user, so it can be vetted instead of blindly trusted. - **Reproducibility:** The SQL query can be easily copied and reused. -Currently, querychat uses DuckDB for its SQL engine. It's extremely fast and has a surprising number of [statistical functions](https://duckdb.org/docs/stable/sql/functions/aggregates.html#statistical-aggregates). +Currently, querychat uses DuckDB for its SQL engine when working with data frames. 
For database sources, it uses the native SQL dialect of the connected database. DuckDB is extremely fast and has a surprising number of [statistical functions](https://duckdb.org/docs/stable/sql/functions/aggregates.html#statistical-aggregates). ## Customizing querychat @@ -116,7 +141,7 @@ Alternatively, you can completely suppress the greeting by passing `greeting = " In LLM parlance, the _system prompt_ is the set of instructions and specific knowledge you want the model to use during a conversation. querychat automatically creates a system prompt which is comprised of: 1. The basic set of behaviors the LLM must follow in order for querychat to work properly. (See `inst/prompt/prompt.md` if you're curious what this looks like.) -2. The SQL schema of the data frame you provided. +2. The SQL schema of the data source you provided. 3. (Optional) Any additional description of the data you choose to provide. 4. (Optional) Any additional instructions you want to use to guide querychat's behavior. @@ -125,7 +150,7 @@ In LLM parlance, the _system prompt_ is the set of instructions and specific kno If you give querychat your dataset and nothing else, it will provide the LLM with the basic schema of your data: - Column names -- DuckDB data type (integer, float, boolean, datetime, text) +- SQL data type (integer, float, boolean, datetime, text) - For text columns with less than 10 unique values, we assume they are categorical variables and include the list of values - For integer and float columns, we include the range @@ -158,8 +183,12 @@ performance for 32 automobiles (1973–74 models). 
which you can then pass via: ```r +# Create data source first +mtcars_source <- querychat_data_source(mtcars, tbl_name = "cars") + +# Then initialize with the data source and description querychat_config <- querychat_init( - mtcars, + data_source = mtcars_source, data_description = readLines("data_description.md") ) ``` @@ -171,11 +200,18 @@ querychat doesn't need this information in any particular format; just put whate You can add additional instructions of your own to the end of the system prompt, by passing `extra_instructions` into `query_init`. ```r -querychat_config <- querychat_init(mtcars, extra_instructions = c( +# Create data source first +mtcars_source <- querychat_data_source(mtcars, tbl_name = "cars") + +# Then initialize with instructions +querychat_config <- querychat_init( + data_source = mtcars_source, + extra_instructions = c( "You're speaking to a British audience--please use appropriate spelling conventions.", "Use lots of emojis! πŸ˜ƒ Emojis everywhere, 🌍 emojis forever. ♾️", "Stay on topic, only talk about the data dashboard and refuse to answer other questions." -)) + ) +) ``` You can also put these instructions in a separate file and use `readLines()` to load them, as we did for `data_description` above. @@ -204,11 +240,15 @@ my_chat_func <- function(system_prompt) { library(ellmer) library(purrr) +# Create data source first +mtcars_source <- querychat_data_source(mtcars, tbl_name = "cars") + # Option 2: Use partial -querychat_config <- querychat_init(mtcars, +querychat_config <- querychat_init( + data_source = mtcars_source, create_chat_func = purrr::partial(ellmer::chat_claude, model = "claude-3-7-sonnet-latest") ) ``` This would use Claude 3.7 Sonnet instead, which would require you to provide an API key. -See the [instructions from Ellmer](https://ellmer.tidyverse.org/reference/chat_claude.html) for more information on how to authenticate with different providers. 
+See the [instructions from Ellmer](https://ellmer.tidyverse.org/reference/chat_claude.html) for more information on how to authenticate with different providers. \ No newline at end of file diff --git a/pkg-r/examples/app-database.R b/pkg-r/examples/app-database.R new file mode 100644 index 000000000..668b32c5e --- /dev/null +++ b/pkg-r/examples/app-database.R @@ -0,0 +1,92 @@ +library(shiny) +library(bslib) +library(querychat) +library(DBI) +library(RSQLite) + +# Create a sample SQLite database for demonstration +# In a real app, you would connect to your existing database +temp_db <- tempfile(fileext = ".db") +onStop(function() { + if (file.exists(temp_db)) { + unlink(temp_db) + } +}) + +conn <- dbConnect(RSQLite::SQLite(), temp_db) +# The connection will automatically be closed when the app stops, thanks to +# querychat_init + +# Create sample data in the database +iris_data <- iris +dbWriteTable(conn, "iris", iris_data, overwrite = TRUE) + +# Define a custom greeting for the database app +greeting <- " +# Welcome to the Database Query Assistant! πŸ“Š + +I can help you explore and analyze the iris dataset from the connected database. +Ask me questions about the iris flowers, and I'll generate SQL queries to get the answers. + +Try asking: +- Show me the first 10 rows of the iris dataset +- What's the average sepal length by species? +- Which species has the largest petals? 
+- Create a summary of measurements grouped by species +" + +# Create data source using querychat_data_source +iris_source <- querychat_data_source(conn, table_name = "iris") + +# Configure querychat for database +querychat_config <- querychat_init( + data_source = iris_source, + greeting = greeting, + data_description = "This database contains the famous iris flower dataset with measurements of sepal and petal dimensions across three species (setosa, versicolor, and virginica).", + extra_instructions = "When showing results, always explain what the data represents and highlight any interesting patterns you observe." +) + +ui <- page_sidebar( + title = "Database Query Chat", + sidebar = querychat_sidebar("chat"), + h2("Current Data View"), + p( + "The table below shows the current filtered data based on your chat queries:" + ), + DT::DTOutput("data_table", fill = FALSE), + br(), + h3("Current SQL Query"), + verbatimTextOutput("sql_query"), + br(), + h3("Dataset Information"), + p("This demo database contains:"), + tags$ul( + tags$li("iris - Famous iris flower dataset (150 rows, 5 columns)"), + tags$li( + "Columns: Sepal.Length, Sepal.Width, Petal.Length, Petal.Width, Species" + ) + ) +) + +server <- function(input, output, session) { + chat <- querychat_server("chat", querychat_config) + + output$data_table <- DT::renderDT( + { + df <- chat$df() + df + }, + options = list(pageLength = 10, scrollX = TRUE) + ) + + output$sql_query <- renderText({ + query <- chat$sql() + if (query == "") { + "No filter applied - showing all data" + } else { + query + } + }) +} + +shinyApp(ui = ui, server = server) diff --git a/pkg-r/examples/database-setup.md b/pkg-r/examples/database-setup.md new file mode 100644 index 000000000..31426f9e9 --- /dev/null +++ b/pkg-r/examples/database-setup.md @@ -0,0 +1,122 @@ +# Database Setup Examples for querychat + +This document provides examples of how to set up querychat with various database types using the new `database_source()` 
 functionality.
+
+## SQLite
+
+```r
+library(DBI)
+library(RSQLite)
+library(querychat)
+
+# Connect to SQLite database
+conn <- dbConnect(RSQLite::SQLite(), "path/to/your/database.db")
+
+# Create database source
+db_source <- querychat_data_source(conn, "your_table_name")
+
+# Initialize querychat
+config <- querychat_init(
+  data_source = db_source,
+  greeting = "Welcome! Ask me about your data.",
+  data_description = "Description of your data..."
+)
+```
+
+## PostgreSQL
+
+```r
+library(DBI)
+library(RPostgreSQL) # or library(RPostgres)
+library(querychat)
+
+# Connect to PostgreSQL
+conn <- dbConnect(
+  RPostgreSQL::PostgreSQL(), # or RPostgres::Postgres()
+  dbname = "your_database",
+  host = "localhost",
+  port = 5432,
+  user = "your_username",
+  password = "your_password"
+)
+
+# Create database source
+db_source <- querychat_data_source(conn, "your_table_name")
+
+# Initialize querychat
+config <- querychat_init(data_source = db_source)
+```
+
+## MySQL
+
+```r
+library(DBI)
+library(RMySQL)
+library(querychat)
+
+# Connect to MySQL
+conn <- dbConnect(
+  RMySQL::MySQL(),
+  dbname = "your_database",
+  host = "localhost",
+  user = "your_username",
+  password = "your_password"
+)
+
+# Create database source
+db_source <- querychat_data_source(conn, "your_table_name")
+
+# Initialize querychat
+config <- querychat_init(data_source = db_source)
+```
+
+## Connection Management
+
+When using database sources in Shiny apps, make sure to properly manage connections:
+
+```r
+server <- function(input, output, session) {
+  # Your querychat server logic here
+  chat <- querychat_server("chat", querychat_config)
+
+  # Clean up connection when session ends
+  session$onSessionEnded(function() {
+    if (dbIsValid(conn)) {
+      dbDisconnect(conn)
+    }
+  })
+}
+```
+
+## Configuration Options
+
+The `querychat_data_source()` function accepts a `categorical_threshold` parameter:
+
+```r
+# Columns with <= 50 unique values will be treated as categorical
+db_source <- querychat_data_source(conn, "table_name", 
categorical_threshold = 50) +``` + +## Security Considerations + +- Only SELECT queries are allowed - no INSERT, UPDATE, or DELETE operations +- All SQL queries are visible to users for transparency +- Use appropriate database user permissions (read-only recommended) +- Consider connection pooling for production applications +- Validate that users only have access to intended tables + +## Error Handling + +The database source implementation includes robust error handling: + +- Validates table existence during creation +- Handles database connection issues gracefully +- Provides informative error messages for invalid queries +- Falls back gracefully when statistical queries fail + +## Performance Tips + +- Use appropriate database indexes for columns commonly used in queries +- Consider limiting row counts for very large tables +- Database connections are reused for better performance +- Schema information is cached to avoid repeated metadata queries \ No newline at end of file diff --git a/pkg-r/inst/prompt/prompt.md b/pkg-r/inst/prompt/prompt.md index 9ed80f43e..3ffce7648 100644 --- a/pkg-r/inst/prompt/prompt.md +++ b/pkg-r/inst/prompt/prompt.md @@ -4,7 +4,7 @@ It's important that you get clear, unambiguous instructions from the user, so if The user interface in which this conversation is being shown is a narrow sidebar of a dashboard, so keep your answers concise and don't include unnecessary patter, nor additional prompts or offers for further assistance. -You have at your disposal a DuckDB database containing this schema: +You have at your disposal a {{db_type}} SQL database containing this schema: {{schema}} @@ -25,7 +25,7 @@ There are several tasks you may be asked to do: The user may ask you to perform filtering and sorting operations on the dashboard; if so, your job is to write the appropriate SQL query for this database. 
Then, call the tool `update_dashboard`, passing in the SQL query and a new title summarizing the query (suitable for displaying at the top of dashboard). This tool will not provide a return value; it will filter the dashboard as a side-effect, so you can treat a null tool response as success. * **Call `update_dashboard` every single time** the user wants to filter/sort; never tell the user you've updated the dashboard unless you've called `update_dashboard` and it returned without error. -* The SQL query must be a **DuckDB SQL** SELECT query. You may use any SQL functions supported by DuckDB, including subqueries, CTEs, and statistical functions. +* The SQL query must be a **{{db_type}} SQL** SELECT query. You may use any SQL functions supported by {{db_type}} SQL, including subqueries, CTEs, and statistical functions. * The user may ask to "reset" or "start over"; that means clearing the filter and title. Do this by calling `update_dashboard({"query": "", "title": ""})`. * Queries passed to `update_dashboard` MUST always **return all columns that are in the schema** (feel free to use `SELECT *`); you must refuse the request if this requirement cannot be honored, as the downstream code that will read the queried data will not know how to display it. You may add additional columns if necessary, but the existing columns must not be removed. * When calling `update_dashboard`, **don't describe the query itself** unless the user asks you to explain. Don't pretend you have access to the resulting data set, as you don't. diff --git a/pkg-r/man/cleanup_source.Rd b/pkg-r/man/cleanup_source.Rd new file mode 100644 index 000000000..25f3f31e8 --- /dev/null +++ b/pkg-r/man/cleanup_source.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{cleanup_source} +\alias{cleanup_source} +\title{Clean up a data source (close connections, etc.)} +\usage{ +cleanup_source(source, ...) 
+} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{...}{Additional arguments passed to methods} +} +\value{ +NULL (invisibly) +} +\description{ +Clean up a data source (close connections, etc.) +} diff --git a/pkg-r/man/create_system_prompt.Rd b/pkg-r/man/create_system_prompt.Rd new file mode 100644 index 000000000..342690180 --- /dev/null +++ b/pkg-r/man/create_system_prompt.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{create_system_prompt} +\alias{create_system_prompt} +\title{Create a system prompt for the data source} +\usage{ +create_system_prompt( + source, + data_description = NULL, + extra_instructions = NULL, + ... +) +} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{data_description}{Optional description of the data} + +\item{extra_instructions}{Optional additional instructions} + +\item{...}{Additional arguments passed to methods} +} +\value{ +A string with the system prompt +} +\description{ +Create a system prompt for the data source +} diff --git a/pkg-r/man/df_to_schema.Rd b/pkg-r/man/df_to_schema.Rd deleted file mode 100644 index d6060c4cb..000000000 --- a/pkg-r/man/df_to_schema.Rd +++ /dev/null @@ -1,29 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/prompt.R -\name{df_to_schema} -\alias{df_to_schema} -\title{Generate a schema description from a data frame} -\usage{ -df_to_schema( - df, - table_name = deparse(substitute(df)), - categorical_threshold = 10 -) -} -\arguments{ -\item{df}{A data frame to generate schema information from.} - -\item{table_name}{A string containing the name of the table in SQL queries.} - -\item{categorical_threshold}{The maximum number of unique values for a text column to be considered categorical.} -} -\value{ -A string containing the schema description for the data frame. 
-The schema includes the table name, column names, their types, and additional -information such as ranges for numeric columns and unique values for text columns. -} -\description{ -This function generates a schema description for a data frame, including -the column names, their types, and additional information such as ranges for -numeric columns and unique values for text columns. -} diff --git a/pkg-r/man/execute_query.Rd b/pkg-r/man/execute_query.Rd new file mode 100644 index 000000000..00bc34fbd --- /dev/null +++ b/pkg-r/man/execute_query.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{execute_query} +\alias{execute_query} +\title{Execute a SQL query on a data source} +\usage{ +execute_query(source, query, ...) +} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{query}{SQL query string} + +\item{...}{Additional arguments passed to methods} +} +\value{ +Result of the query as a data frame +} +\description{ +Execute a SQL query on a data source +} diff --git a/pkg-r/man/get_db_type.Rd b/pkg-r/man/get_db_type.Rd new file mode 100644 index 000000000..e3fd6429b --- /dev/null +++ b/pkg-r/man/get_db_type.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{get_db_type} +\alias{get_db_type} +\title{Get type information for a data source} +\usage{ +get_db_type(source, ...) 
+} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{...}{Additional arguments passed to methods} +} +\value{ +A character string containing the type information +} +\description{ +Get type information for a data source +} diff --git a/pkg-r/man/get_schema.Rd b/pkg-r/man/get_schema.Rd new file mode 100644 index 000000000..22d24ff12 --- /dev/null +++ b/pkg-r/man/get_schema.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{get_schema} +\alias{get_schema} +\title{Get schema for a data source} +\usage{ +get_schema(source, ...) +} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{...}{Additional arguments passed to methods} +} +\value{ +A character string describing the schema +} +\description{ +Get schema for a data source +} diff --git a/pkg-r/man/querychat_data_source.Rd b/pkg-r/man/querychat_data_source.Rd new file mode 100644 index 000000000..7d99ac5a1 --- /dev/null +++ b/pkg-r/man/querychat_data_source.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{querychat_data_source} +\alias{querychat_data_source} +\alias{querychat_data_source.data.frame} +\alias{querychat_data_source.DBIConnection} +\title{Create a data source for querychat} +\usage{ +querychat_data_source(x, ...) + +\method{querychat_data_source}{data.frame}(x, table_name = NULL, categorical_threshold = 20, ...) + +\method{querychat_data_source}{DBIConnection}(x, table_name, categorical_threshold = 20, ...) +} +\arguments{ +\item{x}{A data frame or DBI connection} + +\item{...}{Additional arguments passed to specific methods} + +\item{table_name}{The name to use for the table in the data source. 
Can be: +\itemize{ +\item A character string (e.g., "table_name") +\item Or, for tables contained within catalogs or schemas, a \code{\link[DBI:Id]{DBI::Id()}} object (e.g., \code{DBI::Id(schema = "schema_name", table = "table_name")}) +}} + +\item{categorical_threshold}{For text columns, the maximum number of unique values to consider as a categorical variable} +} +\value{ +A querychat_data_source object +} +\description{ +Generic function to create a data source for querychat. This function +dispatches to appropriate methods based on input. +} diff --git a/pkg-r/man/querychat_init.Rd b/pkg-r/man/querychat_init.Rd index b2e355df7..618d85323 100644 --- a/pkg-r/man/querychat_init.Rd +++ b/pkg-r/man/querychat_init.Rd @@ -5,58 +5,41 @@ \title{Call this once outside of any server function} \usage{ querychat_init( - df, - ..., - table_name = deparse(substitute(df)), + data_source, greeting = NULL, data_description = NULL, extra_instructions = NULL, - prompt_template = NULL, - system_prompt = querychat_system_prompt(df, table_name, ..., data_description = - data_description, extra_instructions = extra_instructions, prompt_template = - prompt_template), - create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o") + create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), + system_prompt = NULL, + auto_close_data_source = TRUE ) } \arguments{ -\item{df}{A data frame.} - -\item{...}{Additional arguments passed to the \code{querychat_system_prompt()} -function, such as \code{categorical_threshold}. If a -\code{system_prompt} argument is provided, the \code{...} arguments will be silently -ignored.} - -\item{table_name}{A string containing a valid table name for the data frame, -that will appear in SQL queries. Ensure that it begins with a letter, and -contains only letters, numbers, and underscores. 
By default, querychat will
-try to infer a table name using the name of the \code{df} argument.}
+\item{data_source}{A querychat_data_source object created by \code{querychat_data_source()}.
+To create a data source:
+\itemize{
+\item For data frame: \code{querychat_data_source(df, table_name = "my_table")}
+\item For database: \code{querychat_data_source(conn, "table_name")}
+}}
 
 \item{greeting}{A string in Markdown format, containing the initial message
 to display to the user upon first loading the chatbot. If not provided, the
 LLM will be invoked at the start of the conversation to generate one.}
 
-\item{data_description}{Optional string or existing file path. The contents
-should be in plain text or Markdown format, containing a description of the
-data frame or any additional context that might be helpful in understanding
-the data. This will be included in the system prompt for the chat model.}
+\item{data_description}{A string containing a data description for the chat model. We have found
+that formatting the data description as a markdown bulleted list works best.}
 
-\item{extra_instructions}{Optional string or existing file path. The contents
-should be in plain text or Markdown format, containing any additional
-instructions for the chat model. These will be appended at the end of the
-system prompt.}
+\item{extra_instructions}{A string containing extra instructions for the chat model.}
 
-\item{prompt_template}{Optional string or existing file path. If \code{NULL}, the
-default prompt file in the package will be used. The contents should
-contain a whisker template for the system prompt, with placeholders for
-\code{{{schema}}}, \code{{{data_description}}}, and \code{{{extra_instructions}}}.}
+\item{create_chat_func}{A function that takes a system prompt and returns a
+chat object. The default uses \code{ellmer::chat_openai()}.}
 
 \item{system_prompt}{A string containing the system prompt for the chat model.
-The default uses \code{querychat_system_prompt()} to generate a generic prompt, -which you can enhance via the \code{data_description} and \code{extra_instructions} -arguments.} +The default generates a generic prompt, which you can enhance via the \code{data_description} and +\code{extra_instructions} arguments.} -\item{create_chat_func}{A function that takes a system prompt and returns a -chat object. The default uses \code{ellmer::chat_openai()}.} +\item{auto_close_data_source}{Should the data source connection be automatically +closed when the shiny app stops? Defaults to TRUE.} } \value{ An object that can be passed to \code{querychat_server()} as the diff --git a/pkg-r/man/querychat_server.Rd b/pkg-r/man/querychat_server.Rd index f6daa5c7d..eec8f8926 100644 --- a/pkg-r/man/querychat_server.Rd +++ b/pkg-r/man/querychat_server.Rd @@ -18,12 +18,9 @@ elements: \itemize{ \item \code{sql}: A reactive that returns the current SQL query. \item \code{title}: A reactive that returns the current title. -\item \code{df}: A reactive that returns the data frame, filtered and sorted by the -current SQL query. +\item \code{df}: A reactive that returns the filtered data as a data.frame. \item \code{chat}: The \link[ellmer:Chat]{ellmer::Chat} object that powers the chat interface. } - -By convention, this object should be named \code{querychat_config}. 
} \description{ Initalize the querychat server diff --git a/pkg-r/man/querychat_system_prompt.Rd b/pkg-r/man/querychat_system_prompt.Rd deleted file mode 100644 index 9c5a0e955..000000000 --- a/pkg-r/man/querychat_system_prompt.Rd +++ /dev/null @@ -1,48 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/prompt.R -\name{querychat_system_prompt} -\alias{querychat_system_prompt} -\title{Create a system prompt for the chat model} -\usage{ -querychat_system_prompt( - df, - table_name, - ..., - data_description = NULL, - extra_instructions = NULL, - prompt_template = NULL, - categorical_threshold = 10 -) -} -\arguments{ -\item{df}{A data frame to generate schema information from.} - -\item{table_name}{A string containing the name of the table in SQL queries.} - -\item{...}{Ignored. Used to allow for future parameters.} - -\item{data_description}{Optional string or existing file path. The contents -should be in plain text or Markdown format, containing a description of the -data frame or any additional context that might be helpful in understanding -the data. This will be included in the system prompt for the chat model.} - -\item{extra_instructions}{Optional string or existing file path. The contents -should be in plain text or Markdown format, containing any additional -instructions for the chat model. These will be appended at the end of the -system prompt.} - -\item{prompt_template}{Optional string or existing file path. If \code{NULL}, the -default prompt file in the package will be used. The contents should -contain a whisker template for the system prompt, with placeholders for -\code{{{schema}}}, \code{{{data_description}}}, and \code{{{extra_instructions}}}.} - -\item{categorical_threshold}{The maximum number of unique values for a text -column to be considered categorical.} -} -\value{ -A string containing the system prompt for the chat model. 
-} -\description{ -This function generates a system prompt for the chat model based on a data frame's -schema and optional additional context and instructions. -} diff --git a/pkg-r/man/test_query.Rd b/pkg-r/man/test_query.Rd new file mode 100644 index 000000000..ec3411de7 --- /dev/null +++ b/pkg-r/man/test_query.Rd @@ -0,0 +1,21 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/data_source.R +\name{test_query} +\alias{test_query} +\title{Test a SQL query on a data source.} +\usage{ +test_query(source, query, ...) +} +\arguments{ +\item{source}{A querychat_data_source object} + +\item{query}{SQL query string} + +\item{...}{Additional arguments passed to methods} +} +\value{ +Result of the query, limited to one row of data. +} +\description{ +Test a SQL query on a data source. +} diff --git a/pkg-r/tests/testthat.R b/pkg-r/tests/testthat.R new file mode 100644 index 000000000..23f8c8185 --- /dev/null +++ b/pkg-r/tests/testthat.R @@ -0,0 +1,12 @@ +# This file is part of the standard setup for testthat. +# It is recommended that you do not modify it. +# +# Where should you do additional test configuration? 
+# Learn more about the roles of various files in: +# * https://r-pkgs.org/testing-design.html#sec-tests-files-overview +# * https://testthat.r-lib.org/articles/special-files.html + +library(testthat) +library(querychat) + +test_check("querychat") diff --git a/pkg-r/tests/testthat/test-data-source.R b/pkg-r/tests/testthat/test-data-source.R new file mode 100644 index 000000000..a957aae9b --- /dev/null +++ b/pkg-r/tests/testthat/test-data-source.R @@ -0,0 +1,271 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(querychat) + +test_that("querychat_data_source.data.frame creates proper S3 object", { + # Create a simple data frame + test_df <- data.frame( + id = 1:5, + name = c("A", "B", "C", "D", "E"), + value = c(10.5, 20.3, 15.7, 30.1, 25.9), + stringsAsFactors = FALSE + ) + + # Test with explicit table name + source <- querychat_data_source(test_df, table_name = "test_table") + expect_s3_class(source, "data_frame_source") + expect_s3_class(source, "querychat_data_source") + expect_equal(source$table_name, "test_table") + expect_true(inherits(source$conn, "DBIConnection")) + + # Clean up + cleanup_source(source) +}) + +test_that("querychat_data_source.DBIConnection creates proper S3 object", { + # Create temporary SQLite database + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create test table + test_data <- data.frame( + id = 1:5, + name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), + age = c(25, 30, 35, 28, 32), + stringsAsFactors = FALSE + ) + + dbWriteTable(conn, "users", test_data, overwrite = TRUE) + + # Test DBI source creation + db_source <- querychat_data_source(conn, "users") + expect_s3_class(db_source, "dbi_source") + expect_s3_class(db_source, "querychat_data_source") + expect_equal(db_source$table_name, "users") + expect_equal(db_source$categorical_threshold, 20) + + # Clean up + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("get_schema methods return proper schema", { + # Test with 
data frame source + test_df <- data.frame( + id = 1:5, + name = c("A", "B", "C", "D", "E"), + active = c(TRUE, FALSE, TRUE, TRUE, FALSE), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + schema <- get_schema(df_source) + expect_type(schema, "character") + expect_match(schema, "Table: test_table") + expect_match(schema, "id \\(INTEGER\\)") + expect_match(schema, "name \\(TEXT\\)") + expect_match(schema, "active \\(BOOLEAN\\)") + expect_match(schema, "Categorical values") # Should list categorical values + + # Test min/max values in schema - specifically for the id column + expect_match(schema, "- id \\(INTEGER\\)\\n Range: 1 to 5") + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + schema <- get_schema(dbi_source) + expect_type(schema, "character") + expect_match(schema, "Table: `test_table`") + expect_match(schema, "id \\(INTEGER\\)") + expect_match(schema, "name \\(TEXT\\)") + + # Test min/max values in DBI source schema - specifically for the id column + expect_match(schema, "- id \\(INTEGER\\)\\n Range: 1 to 5") + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("execute_query works for both source types", { + # Test with data frame source + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + result <- execute_query( + df_source, + "SELECT * FROM test_table WHERE value > 25" + ) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = 
TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + result <- execute_query( + dbi_source, + "SELECT * FROM test_table WHERE value > 25" + ) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("execute_query works with empty/null queries", { + # Test with data frame source + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with NULL query + result_null <- execute_query(df_source, NULL) + expect_s3_class(result_null, "data.frame") + expect_equal(nrow(result_null), 5) # Should return all rows + expect_equal(ncol(result_null), 2) # Should return all columns + + # Test with empty string query + result_empty <- execute_query(df_source, "") + expect_s3_class(result_empty, "data.frame") + expect_equal(nrow(result_empty), 5) # Should return all rows + expect_equal(ncol(result_empty), 2) # Should return all columns + + # Test with DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + + # Test with NULL query + result_null <- execute_query(dbi_source, NULL) + expect_s3_class(result_null, "data.frame") + expect_equal(nrow(result_null), 5) # Should return all rows + expect_equal(ncol(result_null), 2) # Should return all columns + + # Test with empty string query + result_empty <- execute_query(dbi_source, "") + expect_s3_class(result_empty, "data.frame") + expect_equal(nrow(result_empty), 5) # Should return all rows + expect_equal(ncol(result_empty), 2) # Should return all columns + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) + + +test_that("get_schema correctly reports 
min/max values for numeric columns", { + # Create a dataframe with multiple numeric columns + test_df <- data.frame( + id = 1:5, + score = c(10.5, 20.3, 15.7, 30.1, 25.9), + count = c(100, 200, 150, 50, 75), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_metrics") + schema <- get_schema(df_source) + + # Check that each numeric column has the correct min/max values + expect_match(schema, "- id \\(INTEGER\\)\\n Range: 1 to 5") + expect_match(schema, "- score \\(FLOAT\\)\\n Range: 10\\.5 to 30\\.1") + # Note: In the test output, count was detected as FLOAT rather than INTEGER + expect_match(schema, "- count \\(FLOAT\\)\\n Range: 50 to 200") + + # Clean up + cleanup_source(df_source) +}) + +test_that("create_system_prompt generates appropriate system prompt", { + test_df <- data.frame( + id = 1:3, + name = c("A", "B", "C"), + stringsAsFactors = FALSE + ) + + df_source <- querychat_data_source(test_df, table_name = "test_table") + prompt <- create_system_prompt( + df_source, + data_description = "A test dataframe" + ) + expect_type(prompt, "character") + expect_true(nchar(prompt) > 0) + expect_match(prompt, "A test dataframe") + expect_match(prompt, "Table: test_table") + + # Clean up + cleanup_source(df_source) +}) + +test_that("querychat_init automatically handles data.frame inputs", { + # Test that querychat_init accepts data frames directly + test_df <- data.frame(id = 1:3, name = c("A", "B", "C")) + + # Should work with data frame and auto-convert it + config <- querychat_init(data_source = test_df, greeting = "Test greeting") + expect_s3_class(config, "querychat_config") + expect_s3_class(config$data_source, "querychat_data_source") + expect_s3_class(config$data_source, "data_frame_source") + + # Should work with proper data source too + df_source <- querychat_data_source(test_df, table_name = "test_table") + config <- querychat_init(data_source = df_source, greeting = "Test greeting") + expect_s3_class(config, 
"querychat_config") + + # Clean up + cleanup_source(df_source) + cleanup_source(config$data_source) +}) + +test_that("querychat_init works with both source types", { + # Test with data frame + test_df <- data.frame( + id = 1:3, + name = c("A", "B", "C"), + stringsAsFactors = FALSE + ) + + # Create data source and test with querychat_init + df_source <- querychat_data_source(test_df, table_name = "test_source") + config <- querychat_init(data_source = df_source, greeting = "Test greeting") + expect_s3_class(config, "querychat_config") + expect_s3_class(config$data_source, "data_frame_source") + expect_equal(config$data_source$table_name, "test_source") + + # Test with database connection + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + config <- querychat_init(data_source = dbi_source, greeting = "Test greeting") + expect_s3_class(config, "querychat_config") + expect_s3_class(config$data_source, "dbi_source") + expect_equal(config$data_source$table_name, "test_table") + + # Clean up + cleanup_source(df_source) + dbDisconnect(conn) + unlink(temp_db) +}) diff --git a/pkg-r/tests/testthat/test-db-type.R b/pkg-r/tests/testthat/test-db-type.R new file mode 100644 index 000000000..e10967d8f --- /dev/null +++ b/pkg-r/tests/testthat/test-db-type.R @@ -0,0 +1,57 @@ +library(testthat) + +test_that("get_db_type returns correct type for data_frame_source", { + # Create a simple data frame source + df <- data.frame(x = 1:5, y = letters[1:5]) + df_source <- querychat_data_source(df, "test_table") + + # Test that get_db_type returns "DuckDB" + expect_equal(get_db_type(df_source), "DuckDB") +}) + +test_that("get_db_type returns correct type for dbi_source with SQLite", { + skip_if_not_installed("RSQLite") + + # Create a SQLite database source + temp_db <- tempfile(fileext = ".db") + conn <- 
DBI::dbConnect(RSQLite::SQLite(), temp_db) + DBI::dbWriteTable(conn, "test_table", data.frame(x = 1:5, y = letters[1:5])) + db_source <- querychat_data_source(conn, "test_table") + + # Test that get_db_type returns the correct database type + expect_equal(get_db_type(db_source), "SQLite") + + # Clean up + DBI::dbDisconnect(conn) + unlink(temp_db) +}) + +test_that("get_db_type is correctly used in create_system_prompt", { + # Create a simple data frame source + df <- data.frame(x = 1:5, y = letters[1:5]) + df_source <- querychat_data_source(df, "test_table") + + # Generate system prompt + sys_prompt <- create_system_prompt(df_source) + + # Check that "DuckDB" appears in the prompt content + expect_true(grepl("DuckDB SQL", sys_prompt, fixed = TRUE)) +}) + +test_that("get_db_type is used to customize prompt template", { + # Create a simple data frame source + df <- data.frame(x = 1:5, y = letters[1:5]) + df_source <- querychat_data_source(df, "test_table") + + # Get the db_type + db_type <- get_db_type(df_source) + + # Check that the db_type is correctly returned + expect_equal(db_type, "DuckDB") + + # Verify the value is used in the system prompt + # This is an indirect test that doesn't need mocking + # We just check that the string appears somewhere in the system prompt + prompt <- create_system_prompt(df_source) + expect_true(grepl(db_type, prompt, fixed = TRUE)) +}) diff --git a/pkg-r/tests/testthat/test-querychat-server.R b/pkg-r/tests/testthat/test-querychat-server.R new file mode 100644 index 000000000..a44cfb088 --- /dev/null +++ b/pkg-r/tests/testthat/test-querychat-server.R @@ -0,0 +1,46 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(querychat) + +test_that("database source query functionality", { + # Create temporary SQLite database + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create test table + test_data <- data.frame( + id = 1:5, + name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), + age 
= c(25, 30, 35, 28, 32), + stringsAsFactors = FALSE + ) + + dbWriteTable(conn, "users", test_data, overwrite = TRUE) + + # Create database source + db_source <- querychat_data_source(conn, "users") + + # Test that we can execute queries + result <- execute_query(db_source, "SELECT * FROM users WHERE age > 30") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 2) # Charlie and Eve + expect_equal(result$name, c("Charlie", "Eve")) + + # Test that we can get all data + all_data <- execute_query(db_source, NULL) + expect_s3_class(all_data, "data.frame") + expect_equal(nrow(all_data), 5) + expect_equal(ncol(all_data), 3) + + # Test ordering works + ordered_result <- execute_query( + db_source, + "SELECT * FROM users ORDER BY age DESC" + ) + expect_equal(ordered_result$name[1], "Charlie") # Oldest first + + # Clean up + dbDisconnect(conn) + unlink(temp_db) +}) diff --git a/pkg-r/tests/testthat/test-shiny-app.R b/pkg-r/tests/testthat/test-shiny-app.R new file mode 100644 index 000000000..8f7628407 --- /dev/null +++ b/pkg-r/tests/testthat/test-shiny-app.R @@ -0,0 +1,142 @@ +library(testthat) + +test_that("app database example loads without errors", { + skip_if_not_installed("DT") + skip_if_not_installed("RSQLite") + skip_if_not_installed("shinytest2") + + # Create a simplified test app with mocked ellmer + test_app_file <- tempfile(fileext = ".R") + + test_app_content <- ' +library(shiny) +library(bslib) +library(querychat) +library(DBI) +library(RSQLite) + +# Mock chat function to avoid LLM API calls +mock_chat_func <- function(system_prompt) { + list( + register_tool = function(tool) invisible(NULL), + stream_async = function(message) { + "Welcome! This is a mock response for testing." 
+ } + ) +} + +# Create test database +temp_db <- tempfile(fileext = ".db") +conn <- dbConnect(RSQLite::SQLite(), temp_db) +dbWriteTable(conn, "iris", iris, overwrite = TRUE) +dbDisconnect(conn) + +# Setup database source +db_conn <- dbConnect(RSQLite::SQLite(), temp_db) +iris_source <- querychat_data_source(db_conn, "iris") + +# Configure querychat with mock +querychat_config <- querychat_init( + data_source = iris_source, + greeting = "Welcome to the test app!", + create_chat_func = mock_chat_func +) + +ui <- page_sidebar( + title = "Test Database App", + sidebar = querychat_sidebar("chat"), + h2("Data"), + DT::DTOutput("data_table"), + h3("SQL Query"), + verbatimTextOutput("sql_query") +) + +server <- function(input, output, session) { + chat <- querychat_server("chat", querychat_config) + + output$data_table <- DT::renderDT({ + chat$df() + }, options = list(pageLength = 5)) + + output$sql_query <- renderText({ + query <- chat$sql() + if (query == "") "No filter applied" else query + }) + + session$onSessionEnded(function() { + if (DBI::dbIsValid(db_conn)) { + DBI::dbDisconnect(db_conn) + } + unlink(temp_db) + }) +} + +shinyApp(ui = ui, server = server) +' + + writeLines(test_app_content, test_app_file) + + # Test that the app can be loaded without immediate errors + expect_no_error({ + # Try to parse and evaluate the app code + source(test_app_file, local = TRUE) + }) + + # Clean up + unlink(test_app_file) +}) + +test_that("database reactive functionality works correctly", { + skip_if_not_installed("RSQLite") + + library(DBI) + library(RSQLite) + + # Create test database + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "iris", iris, overwrite = TRUE) + dbDisconnect(conn) + + # Test database source creation + db_conn <- dbConnect(RSQLite::SQLite(), temp_db) + iris_source <- querychat_data_source(db_conn, "iris") + + # Mock chat function + mock_chat_func <- function(system_prompt) { + list( + register_tool 
= function(tool) invisible(NULL), + stream_async = function(message) "Mock response" + ) + } + + # Test querychat_init with database source + config <- querychat_init( + data_source = iris_source, + greeting = "Test greeting", + create_chat_func = mock_chat_func + ) + + expect_s3_class(config$data_source, "dbi_source") + expect_s3_class(config$data_source, "querychat_data_source") + + # Test that we can get all data + result_data <- execute_query(config$data_source, NULL) + expect_s3_class(result_data, "data.frame") + expect_equal(nrow(result_data), 150) + expect_equal(ncol(result_data), 5) + + # Test with a specific query + query_result <- execute_query( + config$data_source, + "SELECT \"Sepal.Length\", \"Sepal.Width\" FROM iris WHERE \"Species\" = 'setosa'" + ) + expect_s3_class(query_result, "data.frame") + expect_equal(nrow(query_result), 50) + expect_equal(ncol(query_result), 2) + expect_true(all(c("Sepal.Length", "Sepal.Width") %in% names(query_result))) + + # Clean up + dbDisconnect(db_conn) + unlink(temp_db) +}) diff --git a/pkg-r/tests/testthat/test-sql-comments.R b/pkg-r/tests/testthat/test-sql-comments.R new file mode 100644 index 000000000..e7553ad1d --- /dev/null +++ b/pkg-r/tests/testthat/test-sql-comments.R @@ -0,0 +1,211 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(querychat) + +test_that("execute_query handles SQL with inline comments", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with inline comments + inline_comment_query <- " + SELECT id, value -- This is a comment + FROM test_table + WHERE value > 25 -- Filter for higher values + " + + result <- execute_query(df_source, inline_comment_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) # Should return 3 rows (30, 40, 50) + expect_equal(ncol(result), 2) 
+ + # Test with multiple inline comments + multiple_comments_query <- " + SELECT -- Get only these columns + id, -- ID column + value -- Value column + FROM test_table -- Our test table + WHERE value > 25 -- Only higher values + " + + result <- execute_query(df_source, multiple_comments_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) + +test_that("execute_query handles SQL with multiline comments", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with multiline comments + multiline_comment_query <- " + /* + * This is a multiline comment + * that spans multiple lines + */ + SELECT id, value + FROM test_table + WHERE value > 25 + " + + result <- execute_query(df_source, multiline_comment_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with embedded multiline comments + embedded_multiline_query <- " + SELECT id, /* comment between columns */ value + FROM /* this is + * a multiline + * comment + */ test_table + WHERE value /* another comment */ > 25 + " + + result <- execute_query(df_source, embedded_multiline_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) + +test_that("execute_query handles SQL with trailing semicolons", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with trailing semicolon + query_with_semicolon <- " + SELECT id, value + FROM test_table + WHERE value > 25; + " + + result 
<- execute_query(df_source, query_with_semicolon) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with multiple semicolons (which could happen with LLM-generated SQL) + query_with_multiple_semicolons <- " + SELECT id, value + FROM test_table + WHERE value > 25;;;; + " + + result <- execute_query(df_source, query_with_multiple_semicolons) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) + +test_that("execute_query handles SQL with mixed comments and semicolons", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with a mix of comment styles and semicolons + complex_query <- " + /* + * This is a complex query with different comment styles + */ + SELECT + id, -- This is the ID column + value /* Value column */ + FROM + test_table -- Our test table + WHERE + /* Only get higher values */ + value > 25; -- End of query + " + + result <- execute_query(df_source, complex_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Test with comments that contain SQL-like syntax + tricky_comment_query <- " + SELECT id, value + FROM test_table + /* Comment with SQL-like syntax: + * SELECT * FROM another_table; + */ + WHERE value > 25 -- WHERE id = 'value; DROP TABLE test;' + " + + result <- execute_query(df_source, tricky_comment_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) + +test_that("execute_query handles SQL with unusual whitespace patterns", { + # Create a simple test dataframe + test_df <- data.frame( + id = 1:5, + value = c(10, 20, 30, 40, 50), 
+ stringsAsFactors = FALSE + ) + + # Create data source + df_source <- querychat_data_source(test_df, table_name = "test_table") + + # Test with unusual whitespace patterns (which LLMs might generate) + unusual_whitespace_query <- " + + SELECT id, value + + FROM test_table + + WHERE value>25 + + " + + result <- execute_query(df_source, unusual_whitespace_query) + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 3) + expect_equal(ncol(result), 2) + + # Clean up + cleanup_source(df_source) +}) diff --git a/pkg-r/tests/testthat/test-test-query.R b/pkg-r/tests/testthat/test-test-query.R new file mode 100644 index 000000000..ceac04e5e --- /dev/null +++ b/pkg-r/tests/testthat/test-test-query.R @@ -0,0 +1,115 @@ +library(testthat) +library(DBI) +library(RSQLite) +library(querychat) + +test_that("test_query.dbi_source correctly retrieves one row of data", { + # Create a simple data frame + test_df <- data.frame( + id = 1:5, + name = c("Alice", "Bob", "Charlie", "Diana", "Eve"), + value = c(10, 20, 30, 40, 50), + stringsAsFactors = FALSE + ) + + # Setup DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + + # Test basic query - should only return one row + result <- test_query(dbi_source, "SELECT * FROM test_table") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 1) # Should only return 1 row + expect_equal(result$id, 1) # Should be first row + + # Test with WHERE clause + result <- test_query(dbi_source, "SELECT * FROM test_table WHERE value > 25") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 1) # Should only return 1 row + expect_equal(result$value, 30) # Should return first row with value > 25 + + # Test with ORDER BY - should get the highest value + result <- test_query( + dbi_source, + "SELECT * FROM test_table ORDER BY value DESC" + ) + 
expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 1) + expect_equal(result$value, 50) # Should be the highest value + + # Test with query returning no results + result <- test_query(dbi_source, "SELECT * FROM test_table WHERE value > 100") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 0) # Should return empty data frame + + # Clean up + cleanup_source(dbi_source) + unlink(temp_db) +}) + +test_that("test_query.dbi_source handles errors correctly", { + # Setup DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + + # Create a test table + test_df <- data.frame( + id = 1:3, + value = c(10, 20, 30), + stringsAsFactors = FALSE + ) + dbWriteTable(conn, "test_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "test_table") + + # Test with invalid SQL + expect_error(test_query(dbi_source, "SELECT * WRONG SYNTAX")) + + # Test with non-existent table + expect_error(test_query(dbi_source, "SELECT * FROM non_existent_table")) + + # Test with non-existent column + expect_error(test_query( + dbi_source, + "SELECT non_existent_column FROM test_table" + )) + + # Clean up + cleanup_source(dbi_source) + unlink(temp_db) +}) + +test_that("test_query.dbi_source works with different data types", { + # Create a data frame with different data types + test_df <- data.frame( + id = 1:3, + text_col = c("text1", "text2", "text3"), + num_col = c(1.1, 2.2, 3.3), + int_col = c(10L, 20L, 30L), + bool_col = c(TRUE, FALSE, TRUE), + stringsAsFactors = FALSE + ) + + # Setup DBI source + temp_db <- tempfile(fileext = ".db") + conn <- dbConnect(RSQLite::SQLite(), temp_db) + dbWriteTable(conn, "types_table", test_df, overwrite = TRUE) + + dbi_source <- querychat_data_source(conn, "types_table") + + # Test query with different column types + result <- test_query(dbi_source, "SELECT * FROM types_table") + expect_s3_class(result, "data.frame") + expect_equal(nrow(result), 1) + 
expect_type(result$text_col, "character") + expect_type(result$num_col, "double") + expect_type(result$int_col, "integer") + expect_type(result$bool_col, "integer") # SQLite stores booleans as integers + + # Clean up + cleanup_source(dbi_source) + unlink(temp_db) +}) diff --git a/pyproject.toml b/pyproject.toml index c5a1787e3..4f5302285 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,7 +44,7 @@ packages = ["pkg-py/src/querychat"] include = ["pkg-py/src/querychat", "pkg-py/LICENSE", "pkg-py/README.md"] [dependency-groups] -dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4"] +dev = ["ruff>=0.6.5", "pyright>=1.1.401", "tox-uv>=1.11.4", "pytest>=8.4.0"] docs = ["quartodoc>=0.11.1"] examples = ["seaborn", "openai"] @@ -77,6 +77,8 @@ exclude = [ "node_modules", "site-packages", "venv", + "app-*.py", # ignore example apps for now + "app.py", "examples", # ignore example apps for now ] @@ -160,6 +162,10 @@ unfixable = [] # Allow unused variables when underscore-prefixed. dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" +# disable S101 (flagging asserts) for tests +[tool.ruff.lint.per-file-ignores] +"pkg-py/tests/*.py" = ["S101"] + [tool.ruff.format] quote-style = "double" indent-style = "space"