Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 18 additions & 15 deletions pkg-py/src/querychat/querychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,7 @@ def system_prompt(
data_description: Optional[str | Path] = None,
extra_instructions: Optional[str | Path] = None,
categorical_threshold: int = 10,
prompt_path: Optional[Path] = None,
prompt_template: Optional[str | Path] = None,
) -> str:
"""
Create a system prompt for the chat model based on a data source's schema
Expand All @@ -157,8 +157,8 @@ def system_prompt(
categorical_threshold : int, default=10
Threshold for determining if a column is categorical based on number of
unique values
prompt_path
Optional `Path` to a custom prompt file. If not provided, the default
prompt_template
Optional `Path` to a custom prompt template file, or the template itself as a string. If not provided, the default
querychat template will be used.

Returns
Expand All @@ -168,27 +168,30 @@ def system_prompt(

"""
# Read the prompt file
if prompt_path is None:
if prompt_template is None:
# Default to the prompt file in the same directory as this module
# This allows for easy customization by placing a different prompt.md file there
prompt_path = Path(__file__).parent / "prompt" / "prompt.md"

prompt_text = prompt_path.read_text()
prompt_template = Path(__file__).parent / "prompt" / "prompt.md"
prompt_str = (
prompt_template.read_text()
if isinstance(prompt_template, Path)
else prompt_template
)

data_description_str: str | None = (
data_description_str = (
data_description.read_text()
if isinstance(data_description, Path)
else data_description
)

extra_instructions_str: str | None = (
extra_instructions_str = (
extra_instructions.read_text()
if isinstance(extra_instructions, Path)
else extra_instructions
)

return chevron.render(
prompt_text,
prompt_str,
{
"db_engine": data_source.db_engine,
"schema": data_source.get_schema(
Expand Down Expand Up @@ -244,7 +247,7 @@ def init(
greeting: Optional[str | Path] = None,
data_description: Optional[str | Path] = None,
extra_instructions: Optional[str | Path] = None,
prompt_path: Optional[Path] = None,
prompt_template: Optional[str | Path] = None,
system_prompt_override: Optional[str] = None,
create_chat_callback: Optional[CreateChatCallback] = None,
) -> QueryChatConfig:
Expand Down Expand Up @@ -273,8 +276,8 @@ def init(
Additional instructions for the chat model.
If a pathlib.Path object is passed,
querychat will read the contents of the path into a string with `.read_text()`.
prompt_path : Path, optional
Path to a custom prompt file. If not provided, the default querychat
prompt_template : str | Path, optional
A custom prompt template as a string, or a path to a custom prompt file. If not provided, the default querychat
template will be used. This should be a Markdown file that contains the
system prompt template. The mustache template can use the following
variables:
Expand All @@ -285,7 +288,7 @@ def init(
- `{{extra_instructions}}`: Any additional instructions provided
system_prompt_override : str, optional
A custom system prompt to use instead of the default. If provided,
`data_description`, `extra_instructions`, and `prompt_path` will be
`data_description`, `extra_instructions`, and `prompt_template` will be
silently ignored.
create_chat_callback : CreateChatCallback, optional
A function that creates a chat object
Expand Down Expand Up @@ -331,7 +334,7 @@ def init(
data_source_obj,
data_description=data_description,
extra_instructions=extra_instructions,
prompt_path=prompt_path,
prompt_template=prompt_template,
)

# Default chat function if none provided
Expand Down
2 changes: 1 addition & 1 deletion pkg-r/NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@

* Initial CRAN submission.

* Added `prompt_path` support for `querychat_system_prompt()`. (Thank you, @oacar! #37)
* Added `prompt_template` support for `querychat_system_prompt()`. (Thank you, @oacar! #37, #45)
71 changes: 42 additions & 29 deletions pkg-r/R/prompt.R
Original file line number Diff line number Diff line change
Expand Up @@ -5,48 +5,47 @@
#'
#' @param df A data frame to generate schema information from.
#' @param table_name A string containing the name of the table in SQL queries.
#' @param data_description Optional string in plain text or Markdown format, containing
#' a description of the data frame or any additional context that might be
#' helpful in understanding the data. This will be included in the system
#' prompt for the chat model.
#' @param extra_instructions Optional string in plain text or Markdown format, containing
#' any additional instructions for the chat model. These will be appended at
#' the end of the system prompt.
#' @param categorical_threshold The maximum number of unique values for a text column to be considered categorical.
#' @param prompt_path Optional string containing the path to a custom prompt file. If
#' `NULL`, the default prompt file in the package will be used. This file should
#' contain a whisker template for the system prompt, with placeholders for `{{schema}}`,
#' `{{data_description}}`, and `{{extra_instructions}}`.
#' @param data_description Optional string or existing file path. The contents
#' should be in plain text or Markdown format, containing a description of the
#' data frame or any additional context that might be helpful in understanding
#' the data. This will be included in the system prompt for the chat model.
#' @param extra_instructions Optional string or existing file path. The contents
#' should be in plain text or Markdown format, containing any additional
#' instructions for the chat model. These will be appended at the end of the
#' system prompt.
#' @param prompt_template Optional string or existing file path. If `NULL`, the
#' default prompt file in the package will be used. The contents should
#' contain a whisker template for the system prompt, with placeholders for
#' `{{schema}}`, `{{data_description}}`, and `{{extra_instructions}}`.
#' @param categorical_threshold The maximum number of unique values for a text
#' column to be considered categorical.
#' @param ... Ignored. Used to allow for future parameters.
#'
#' @return A string containing the system prompt for the chat model.
#'
#' @export
querychat_system_prompt <- function(
df,
table_name,
...,
data_description = NULL,
extra_instructions = NULL,
categorical_threshold = 10,
prompt_path = system.file("prompt", "prompt.md", package = "querychat")
prompt_template = NULL,
categorical_threshold = 10
) {
schema <- df_to_schema(df, table_name, categorical_threshold)
rlang::check_dots_empty()

if (!is.null(data_description)) {
data_description <- paste(data_description, collapse = "\n")
}
if (!is.null(extra_instructions)) {
extra_instructions <- paste(extra_instructions, collapse = "\n")
}
schema <- df_to_schema(df, table_name, categorical_threshold)

# Read the prompt file
if (is.null(prompt_path)) {
prompt_path <- system.file("prompt", "prompt.md", package = "querychat")
}
if (!file.exists(prompt_path)) {
stop("Prompt file not found at: ", prompt_path)
data_description <- read_path_or_string(data_description, "data_description")
extra_instructions <- read_path_or_string(
extra_instructions,
"extra_instructions"
)
if (is.null(prompt_template)) {
prompt_template <- system.file("prompt", "prompt.md", package = "querychat")
}
prompt_content <- readLines(prompt_path, warn = FALSE)
prompt_text <- paste(prompt_content, collapse = "\n")
prompt_text <- read_path_or_string(prompt_template, "prompt_template")

processed_template <-
whisker::whisker.render(
Expand All @@ -63,6 +62,20 @@ querychat_system_prompt <- function(
processed_template
}

# Resolve an argument that may be NULL, a string (or character vector),
# or a path to an existing file.
#
# Returns NULL when `x` is NULL. If `x` is a single string naming an
# existing file, the file's contents are read with readLines(); otherwise
# `x` is used as-is. Either way the result is collapsed into one
# newline-separated string. `name` is used only in the error message.
read_path_or_string <- function(x, name) {
  if (is.null(x)) {
    return(NULL)
  }
  if (!is.character(x)) {
    stop(sprintf("`%s=` must be a string or a path to a file.", name))
  }
  # Only a length-1 value can name a file. file.exists() is vectorized,
  # so testing a multi-element vector here would make `if` error
  # ("condition has length > 1", an error since R 4.2); multi-element
  # character vectors are valid input and are simply collapsed below.
  if (length(x) == 1 && file.exists(x)) {
    x <- readLines(x, warn = FALSE)
  }
  return(paste(x, collapse = "\n"))
}


#' Generate a schema description from a data frame
#'
#' This function generates a schema description for a data frame, including
Expand Down
6 changes: 4 additions & 2 deletions pkg-r/R/querychat.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
#' to display to the user upon first loading the chatbot. If not provided, the
#' LLM will be invoked at the start of the conversation to generate one.
#' @param ... Additional arguments passed to the `querychat_system_prompt()`
#' function, such as `categorical_threshold`, and `prompt_path`. If a
#' function, such as `categorical_threshold`. If a
#' `system_prompt` argument is provided, the `...` arguments will be silently
#' ignored.
#' @inheritParams querychat_system_prompt
Expand All @@ -34,13 +34,15 @@ querychat_init <- function(
greeting = NULL,
data_description = NULL,
extra_instructions = NULL,
prompt_template = NULL,
system_prompt = querychat_system_prompt(
df,
table_name,
# By default, pass through any params supplied to querychat_init()
...,
data_description = data_description,
extra_instructions = extra_instructions
extra_instructions = extra_instructions,
prompt_template = prompt_template
),
create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o")
) {
Expand Down
26 changes: 17 additions & 9 deletions pkg-r/man/querychat_init.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 19 additions & 14 deletions pkg-r/man/querychat_system_prompt.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.