diff --git a/pkg-py/src/querychat/__init__.py b/pkg-py/src/querychat/__init__.py index 985d24f5..660a202f 100644 --- a/pkg-py/src/querychat/__init__.py +++ b/pkg-py/src/querychat/__init__.py @@ -1,3 +1,5 @@ -from querychat.querychat import init, mod_server as server, sidebar, system_prompt, mod_ui as ui +from querychat.querychat import init, sidebar, system_prompt +from querychat.querychat import mod_server as server +from querychat.querychat import mod_ui as ui -__all__ = ["init", "server", "sidebar", "ui", "system_prompt"] +__all__ = ["init", "server", "sidebar", "system_prompt", "ui"] diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index d52fcb7c..c15db100 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -9,9 +9,7 @@ import chatlas import chevron import narwhals as nw -import pandas as pd import sqlalchemy -from narwhals.typing import IntoFrame from shiny import Inputs, Outputs, Session, module, reactive, ui if TYPE_CHECKING: @@ -145,6 +143,7 @@ def system_prompt( data_description: Optional[str] = None, extra_instructions: Optional[str] = None, categorical_threshold: int = 10, + prompt_path: Optional[Path] = None, ) -> str: """ Create a system prompt for the chat model based on a data source's schema @@ -162,6 +161,9 @@ def system_prompt( categorical_threshold : int, default=10 Threshold for determining if a column is categorical based on number of unique values + prompt_path + Optional `Path` to a custom prompt file. If not provided, the default + querychat template will be used. Returns ------- @@ -170,7 +172,11 @@ def system_prompt( """ # Read the prompt file - prompt_path = Path(__file__).parent / "prompt" / "prompt.md" + if prompt_path is None: + # Default to the prompt file in the same directory as this module + # This allows for easy customization by placing a different prompt.md file there + prompt_path = Path(__file__).parent / "prompt" / "prompt.md" + prompt_text = prompt_path.read_text() return chevron.render( @@ -226,11 +232,14 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: def init( data_source: IntoFrame | sqlalchemy.Engine, table_name: str, + /, + *, greeting: Optional[str] = None, data_description: Optional[str] = None, extra_instructions: Optional[str] = None, - create_chat_callback: Optional[CreateChatCallback] = None, + prompt_path: Optional[Path] = None, system_prompt_override: Optional[str] = None, + create_chat_callback: Optional[CreateChatCallback] = None, ) -> QueryChatConfig: """ Initialize querychat with any compliant data source. @@ -251,10 +260,22 @@ def init( Description of the data in plain text or Markdown extra_instructions : str, optional Additional instructions for the chat model + prompt_path : Path, optional + Path to a custom prompt file. If not provided, the default querychat + template will be used. This should be a Markdown file that contains the + system prompt template. The mustache template can use the following + variables: + - `{{db_engine}}`: The database engine used (e.g., "DuckDB") + - `{{schema}}`: The schema of the data source, generated by + `data_source.get_schema()` + - `{{data_description}}`: The optional data description provided + - `{{extra_instructions}}`: Any additional instructions provided + system_prompt_override : str, optional + A custom system prompt to use instead of the default. If provided, + `data_description`, `extra_instructions`, and `prompt_path` will be + silently ignored. create_chat_callback : CreateChatCallback, optional A function that creates a chat object - system_prompt_override : str, optional - A custom system prompt to use instead of the default Returns ------- @@ -289,6 +310,7 @@ def init( data_source_obj, data_description, extra_instructions, + prompt_path=prompt_path, ) # Default chat function if none provided diff --git a/pkg-r/NAMESPACE b/pkg-r/NAMESPACE index d1e39fd8..077a6ed0 100644 --- a/pkg-r/NAMESPACE +++ b/pkg-r/NAMESPACE @@ -1,5 +1,6 @@ # Generated by roxygen2: do not edit by hand +export(df_to_schema) export(querychat_init) export(querychat_server) export(querychat_sidebar) diff --git a/pkg-r/NEWS.md b/pkg-r/NEWS.md new file mode 100644 index 00000000..811fe69b --- /dev/null +++ b/pkg-r/NEWS.md @@ -0,0 +1,5 @@ +# querychat (development version) + +* Initial CRAN submission. + +* Added `prompt_path` support for `querychat_system_prompt()`. (Thank you, @oacar! #37) diff --git a/pkg-r/R/prompt.R b/pkg-r/R/prompt.R index 75ac68b6..10bfc977 100644 --- a/pkg-r/R/prompt.R +++ b/pkg-r/R/prompt.R @@ -4,22 +4,32 @@ #' schema and optional additional context and instructions. #' #' @param df A data frame to generate schema information from. -#' @param name A string containing the name of the table in SQL queries. -#' @param data_description Optional description of the data, in plain text or Markdown format. -#' @param extra_instructions Optional additional instructions for the chat model, in plain text or Markdown format. +#' @param table_name A string containing the name of the table in SQL queries. +#' @param data_description Optional string in plain text or Markdown format, containing +#' a description of the data frame or any additional context that might be +#' helpful in understanding the data. This will be included in the system +#' prompt for the chat model. +#' @param extra_instructions Optional string in plain text or Markdown format, containing +#' any additional instructions for the chat model. These will be appended at +#' the end of the system prompt. #' @param categorical_threshold The maximum number of unique values for a text column to be considered categorical. +#' @param prompt_path Optional string containing the path to a custom prompt file. If +#' `NULL`, the default prompt file in the package will be used. This file should +#' contain a whisker template for the system prompt, with placeholders for `{{schema}}`, +#' `{{data_description}}`, and `{{extra_instructions}}`. #' #' @return A string containing the system prompt for the chat model. #' #' @export querychat_system_prompt <- function( df, - name, + table_name, data_description = NULL, extra_instructions = NULL, - categorical_threshold = 10 + categorical_threshold = 10, + prompt_path = system.file("prompt", "prompt.md", package = "querychat") ) { - schema <- df_to_schema(df, name, categorical_threshold) + schema <- df_to_schema(df, table_name, categorical_threshold) if (!is.null(data_description)) { data_description <- paste(data_description, collapse = "\n") @@ -29,26 +39,50 @@ querychat_system_prompt <- function( } # Read the prompt file - prompt_path <- system.file("prompt", "prompt.md", package = "querychat") + if (is.null(prompt_path)) { + prompt_path <- system.file("prompt", "prompt.md", package = "querychat") + } + if (!file.exists(prompt_path)) { + stop("Prompt file not found at: ", prompt_path) + } prompt_content <- readLines(prompt_path, warn = FALSE) prompt_text <- paste(prompt_content, collapse = "\n") - whisker::whisker.render( - prompt_text, - list( - schema = schema, - data_description = data_description, - extra_instructions = extra_instructions + processed_template <- + whisker::whisker.render( + prompt_text, + list( + schema = schema, + data_description = data_description, + extra_instructions = extra_instructions + ) ) - ) + + attr(processed_template, "table_name") <- table_name + + processed_template } +#' Generate a schema description from a data frame +#' +#' This function generates a schema description for a data frame, including +#' the column names, their types, and additional information such as ranges for +#' numeric columns and unique values for text columns. +#' +#' @param df A data frame to generate schema information from. +#' @param table_name A string containing the name of the table in SQL queries. +#' @param categorical_threshold The maximum number of unique values for a text column to be considered categorical. +#' +#' @return A string containing the schema description for the data frame. +#' The schema includes the table name, column names, their types, and additional +#' information such as ranges for numeric columns and unique values for text columns. +#' @export df_to_schema <- function( df, - name = deparse(substitute(df)), - categorical_threshold + table_name = deparse(substitute(df)), + categorical_threshold = 10 ) { - schema <- c(paste("Table:", name), "Columns:") + schema <- c(paste("Table:", table_name), "Columns:") column_info <- lapply(names(df), function(column) { # Map R classes to SQL-like types diff --git a/pkg-r/R/querychat.R b/pkg-r/R/querychat.R index f97a867c..eca6a0b8 100644 --- a/pkg-r/R/querychat.R +++ b/pkg-r/R/querychat.R @@ -4,29 +4,24 @@ #' Shiny sessions in the R process. #' #' @param df A data frame. -#' @param tbl_name A string containing a valid table name for the data frame, +#' @param table_name A string containing a valid table name for the data frame, #' that will appear in SQL queries. Ensure that it begins with a letter, and #' contains only letters, numbers, and underscores. By default, querychat will #' try to infer a table name using the name of the `df` argument. #' @param greeting A string in Markdown format, containing the initial message #' to display to the user upon first loading the chatbot. If not provided, the #' LLM will be invoked at the start of the conversation to generate one. -#' @param data_description A string in plain text or Markdown format, containing -#' a description of the data frame or any additional context that might be -#' helpful in understanding the data. This will be included in the system -#' prompt for the chat model. If a `system_prompt` argument is provided, the -#' `data_description` argument will be ignored. -#' @param extra_instructions A string in plain text or Markdown format, containing -#' any additional instructions for the chat model. These will be appended at -#' the end of the system prompt. If a `system_prompt` argument is provided, -#' the `extra_instructions` argument will be ignored. -#' @param create_chat_func A function that takes a system prompt and returns a -#' chat object. The default uses `ellmer::chat_openai()`. +#' @param ... Additional arguments passed to the `querychat_system_prompt()` +#' function, such as `categorical_threshold`, and `prompt_path`. If a +#' `system_prompt` argument is provided, the `...` arguments will be silently +#' ignored. +#' @inheritParams querychat_system_prompt #' @param system_prompt A string containing the system prompt for the chat model. #' The default uses `querychat_system_prompt()` to generate a generic prompt, #' which you can enhance via the `data_description` and `extra_instructions` #' arguments. -#' +#' @param create_chat_func A function that takes a system prompt and returns a +#' chat object. The default uses `ellmer::chat_openai()`. #' @returns An object that can be passed to `querychat_server()` as the #' `querychat_config` argument. By convention, this object should be named #' `querychat_config`. @@ -34,45 +29,57 @@ #' @export querychat_init <- function( df, - tbl_name = deparse(substitute(df)), + ..., + table_name = deparse(substitute(df)), greeting = NULL, data_description = NULL, extra_instructions = NULL, - create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), system_prompt = querychat_system_prompt( df, - tbl_name, + table_name, + # By default, pass through any params supplied to querychat_init() + ..., data_description = data_description, extra_instructions = extra_instructions - ) + ), + create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o") ) { - is_tbl_name_ok <- is.character(tbl_name) && - length(tbl_name) == 1 && - grepl("^[a-zA-Z][a-zA-Z0-9_]*$", tbl_name, perl = TRUE) - if (!is_tbl_name_ok) { - if (missing(tbl_name)) { + is_table_name_ok <- is.character(table_name) && + length(table_name) == 1 && + grepl("^[a-zA-Z][a-zA-Z0-9_]*$", table_name, perl = TRUE) + if (!is_table_name_ok) { + if (missing(table_name)) { rlang::abort( - "Unable to infer table name from `df` argument. Please specify `tbl_name` argument explicitly." + "Unable to infer table name from `df` argument. Please specify `table_name` argument explicitly." ) } else { rlang::abort( - "`tbl_name` argument must be a string containing a valid table name." + "`table_name` argument must be a string containing a valid table name." ) } } force(df) - force(system_prompt) + force(system_prompt) # Have default `...` params evaluated force(create_chat_func) # TODO: Provide nicer looking errors here stopifnot( "df must be a data frame" = is.data.frame(df), - "tbl_name must be a string" = is.character(tbl_name), + "table_name must be a string" = is.character(table_name), "system_prompt must be a string" = is.character(system_prompt), "create_chat_func must be a function" = is.function(create_chat_func) ) + if ("table_name" %in% names(attributes(system_prompt))) { + # If available, be sure to use the `table_name` argument to `querychat_init()` + # matches the one supplied to the system prompt + if (table_name != attr(system_prompt, "table_name")) { + rlang::abort( + "`querychat_init(table_name=)` must match system prompt `table_name` supplied to `querychat_system_prompt()`." + ) + } + } if (!is.null(greeting)) { greeting <- paste(collapse = "\n", greeting) } else { @@ -83,7 +90,7 @@ querychat_init <- function( } conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:") - duckdb::duckdb_register(conn, tbl_name, df, experimental = FALSE) + duckdb::duckdb_register(conn, table_name, df, experimental = FALSE) shiny::onStop(function() DBI::dbDisconnect(conn)) structure( diff --git a/pkg-r/man/df_to_schema.Rd b/pkg-r/man/df_to_schema.Rd new file mode 100644 index 00000000..d6060c4c --- /dev/null +++ b/pkg-r/man/df_to_schema.Rd @@ -0,0 +1,29 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/prompt.R +\name{df_to_schema} +\alias{df_to_schema} +\title{Generate a schema description from a data frame} +\usage{ +df_to_schema( + df, + table_name = deparse(substitute(df)), + categorical_threshold = 10 +) +} +\arguments{ +\item{df}{A data frame to generate schema information from.} + +\item{table_name}{A string containing the name of the table in SQL queries.} + +\item{categorical_threshold}{The maximum number of unique values for a text column to be considered categorical.} +} +\value{ +A string containing the schema description for the data frame. +The schema includes the table name, column names, their types, and additional +information such as ranges for numeric columns and unique values for text columns. +} +\description{ +This function generates a schema description for a data frame, including +the column names, their types, and additional information such as ranges for +numeric columns and unique values for text columns. +} diff --git a/pkg-r/man/querychat_init.Rd b/pkg-r/man/querychat_init.Rd index 260261ae..5a0b0c84 100644 --- a/pkg-r/man/querychat_init.Rd +++ b/pkg-r/man/querychat_init.Rd @@ -6,19 +6,25 @@ \usage{ querychat_init( df, - tbl_name = deparse(substitute(df)), + ..., + table_name = deparse(substitute(df)), greeting = NULL, data_description = NULL, extra_instructions = NULL, - create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"), - system_prompt = querychat_system_prompt(df, tbl_name, data_description = - data_description, extra_instructions = extra_instructions) + system_prompt = querychat_system_prompt(df, table_name, ..., data_description = + data_description, extra_instructions = extra_instructions), + create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o") ) } \arguments{ \item{df}{A data frame.} -\item{tbl_name}{A string containing a valid table name for the data frame, +\item{...}{Additional arguments passed to the \code{querychat_system_prompt()} +function, such as \code{categorical_threshold}, and \code{prompt_path}. If a +\code{system_prompt} argument is provided, the \code{...} arguments will be silently +ignored.} + +\item{table_name}{A string containing a valid table name for the data frame, that will appear in SQL queries. Ensure that it begins with a letter, and contains only letters, numbers, and underscores. By default, querychat will try to infer a table name using the name of the \code{df} argument.} @@ -27,24 +33,22 @@ try to infer a table name using the name of the \code{df} argument.} to display to the user upon first loading the chatbot. If not provided, the LLM will be invoked at the start of the conversation to generate one.} -\item{data_description}{A string in plain text or Markdown format, containing +\item{data_description}{Optional string in plain text or Markdown format, containing a description of the data frame or any additional context that might be helpful in understanding the data. This will be included in the system -prompt for the chat model. If a \code{system_prompt} argument is provided, the -\code{data_description} argument will be ignored.} +prompt for the chat model.} -\item{extra_instructions}{A string in plain text or Markdown format, containing +\item{extra_instructions}{Optional string in plain text or Markdown format, containing any additional instructions for the chat model. These will be appended at -the end of the system prompt. If a \code{system_prompt} argument is provided, -the \code{extra_instructions} argument will be ignored.} - -\item{create_chat_func}{A function that takes a system prompt and returns a -chat object. The default uses \code{ellmer::chat_openai()}.} +the end of the system prompt.} \item{system_prompt}{A string containing the system prompt for the chat model. The default uses \code{querychat_system_prompt()} to generate a generic prompt, which you can enhance via the \code{data_description} and \code{extra_instructions} arguments.} + +\item{create_chat_func}{A function that takes a system prompt and returns a +chat object. The default uses \code{ellmer::chat_openai()}.} } \value{ An object that can be passed to \code{querychat_server()} as the diff --git a/pkg-r/man/querychat_system_prompt.Rd b/pkg-r/man/querychat_system_prompt.Rd index 31dae21f..a62b0ac3 100644 --- a/pkg-r/man/querychat_system_prompt.Rd +++ b/pkg-r/man/querychat_system_prompt.Rd @@ -6,22 +6,33 @@ \usage{ querychat_system_prompt( df, - name, + table_name, data_description = NULL, extra_instructions = NULL, - categorical_threshold = 10 + categorical_threshold = 10, + prompt_path = system.file("prompt", "prompt.md", package = "querychat") ) } \arguments{ \item{df}{A data frame to generate schema information from.} -\item{name}{A string containing the name of the table in SQL queries.} +\item{table_name}{A string containing the name of the table in SQL queries.} -\item{data_description}{Optional description of the data, in plain text or Markdown format.} +\item{data_description}{Optional string in plain text or Markdown format, containing +a description of the data frame or any additional context that might be +helpful in understanding the data. This will be included in the system +prompt for the chat model.} -\item{extra_instructions}{Optional additional instructions for the chat model, in plain text or Markdown format.} +\item{extra_instructions}{Optional string in plain text or Markdown format, containing +any additional instructions for the chat model. These will be appended at +the end of the system prompt.} \item{categorical_threshold}{The maximum number of unique values for a text column to be considered categorical.} + +\item{prompt_path}{Optional string containing the path to a custom prompt file. If +\code{NULL}, the default prompt file in the package will be used. This file should +contain a whisker template for the system prompt, with placeholders for \code{{{schema}}}, +\code{{{data_description}}}, and \code{{{extra_instructions}}}.} } \value{ A string containing the system prompt for the chat model. diff --git a/pyproject.toml b/pyproject.toml index daec9962..c5a1787e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -77,7 +77,7 @@ exclude = [ "node_modules", "site-packages", "venv", - "app-*.py", # ignore example apps for now + "examples", # ignore example apps for now ] line-length = 88 @@ -110,6 +110,7 @@ extend-ignore = [ "D104", # Missing docstring in public package "D107", # Missing docstring in __init__ "D205", # 1 blank line required between summary line and description + "UP045", # Use `X | NULL` for type annotations, not `Optional[X]` ] extend-select = [ # "C90", # C90; mccabe: https://docs.astral.sh/ruff/rules/complex-structure/