Skip to content

Commit 100c300

Browse files
oacar and schloerke authored
feat(r+py): Add querychat_system_prompt(prompt_path=) support; Export R's df_to_schema() (#37)
Co-authored-by: Barret Schloerke <[email protected]>
1 parent a97b17b commit 100c300

File tree

10 files changed

+188
-72
lines changed

10 files changed

+188
-72
lines changed

pkg-py/src/querychat/__init__.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1-
from querychat.querychat import init, mod_server as server, sidebar, system_prompt, mod_ui as ui
1+
from querychat.querychat import init, sidebar, system_prompt
2+
from querychat.querychat import mod_server as server
3+
from querychat.querychat import mod_ui as ui
24

3-
__all__ = ["init", "server", "sidebar", "ui", "system_prompt"]
5+
__all__ = ["init", "server", "sidebar", "system_prompt", "ui"]

pkg-py/src/querychat/querychat.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,7 @@
99
import chatlas
1010
import chevron
1111
import narwhals as nw
12-
import pandas as pd
1312
import sqlalchemy
14-
from narwhals.typing import IntoFrame
1513
from shiny import Inputs, Outputs, Session, module, reactive, ui
1614

1715
if TYPE_CHECKING:
@@ -145,6 +143,7 @@ def system_prompt(
145143
data_description: Optional[str] = None,
146144
extra_instructions: Optional[str] = None,
147145
categorical_threshold: int = 10,
146+
prompt_path: Optional[Path] = None,
148147
) -> str:
149148
"""
150149
Create a system prompt for the chat model based on a data source's schema
@@ -162,6 +161,9 @@ def system_prompt(
162161
categorical_threshold : int, default=10
163162
Threshold for determining if a column is categorical based on number of
164163
unique values
164+
prompt_path
165+
Optional `Path` to a custom prompt file. If not provided, the default
166+
querychat template will be used.
165167
166168
Returns
167169
-------
@@ -170,7 +172,11 @@ def system_prompt(
170172
171173
"""
172174
# Read the prompt file
173-
prompt_path = Path(__file__).parent / "prompt" / "prompt.md"
175+
if prompt_path is None:
176+
# Default to the prompt file bundled in the `prompt/` directory next to this module
177+
# This allows for easy customization by placing a different prompt.md file there
178+
prompt_path = Path(__file__).parent / "prompt" / "prompt.md"
179+
174180
prompt_text = prompt_path.read_text()
175181

176182
return chevron.render(
@@ -226,11 +232,14 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
226232
def init(
227233
data_source: IntoFrame | sqlalchemy.Engine,
228234
table_name: str,
235+
/,
236+
*,
229237
greeting: Optional[str] = None,
230238
data_description: Optional[str] = None,
231239
extra_instructions: Optional[str] = None,
232-
create_chat_callback: Optional[CreateChatCallback] = None,
240+
prompt_path: Optional[Path] = None,
233241
system_prompt_override: Optional[str] = None,
242+
create_chat_callback: Optional[CreateChatCallback] = None,
234243
) -> QueryChatConfig:
235244
"""
236245
Initialize querychat with any compliant data source.
@@ -251,10 +260,22 @@ def init(
251260
Description of the data in plain text or Markdown
252261
extra_instructions : str, optional
253262
Additional instructions for the chat model
263+
prompt_path : Path, optional
264+
Path to a custom prompt file. If not provided, the default querychat
265+
template will be used. This should be a Markdown file that contains the
266+
system prompt template. The mustache template can use the following
267+
variables:
268+
- `{{db_engine}}`: The database engine used (e.g., "DuckDB")
269+
- `{{schema}}`: The schema of the data source, generated by
270+
`data_source.get_schema()`
271+
- `{{data_description}}`: The optional data description provided
272+
- `{{extra_instructions}}`: Any additional instructions provided
273+
system_prompt_override : str, optional
274+
A custom system prompt to use instead of the default. If provided,
275+
`data_description`, `extra_instructions`, and `prompt_path` will be
276+
silently ignored.
254277
create_chat_callback : CreateChatCallback, optional
255278
A function that creates a chat object
256-
system_prompt_override : str, optional
257-
A custom system prompt to use instead of the default
258279
259280
Returns
260281
-------
@@ -289,6 +310,7 @@ def init(
289310
data_source_obj,
290311
data_description,
291312
extra_instructions,
313+
prompt_path=prompt_path,
292314
)
293315

294316
# Default chat function if none provided

pkg-r/NAMESPACE

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
export(df_to_schema)
34
export(querychat_init)
45
export(querychat_server)
56
export(querychat_sidebar)

pkg-r/NEWS.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
# querychat (development version)
2+
3+
* Initial CRAN submission.
4+
5+
* Added `prompt_path` support for `querychat_system_prompt()`. (Thank you, @oacar! #37)

pkg-r/R/prompt.R

Lines changed: 51 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,32 @@
44
#' schema and optional additional context and instructions.
55
#'
66
#' @param df A data frame to generate schema information from.
7-
#' @param name A string containing the name of the table in SQL queries.
8-
#' @param data_description Optional description of the data, in plain text or Markdown format.
9-
#' @param extra_instructions Optional additional instructions for the chat model, in plain text or Markdown format.
7+
#' @param table_name A string containing the name of the table in SQL queries.
8+
#' @param data_description Optional string in plain text or Markdown format, containing
9+
#' a description of the data frame or any additional context that might be
10+
#' helpful in understanding the data. This will be included in the system
11+
#' prompt for the chat model.
12+
#' @param extra_instructions Optional string in plain text or Markdown format, containing
13+
#' any additional instructions for the chat model. These will be appended at
14+
#' the end of the system prompt.
1015
#' @param categorical_threshold The maximum number of unique values for a text column to be considered categorical.
16+
#' @param prompt_path Optional string containing the path to a custom prompt file. If
17+
#' `NULL`, the default prompt file in the package will be used. This file should
18+
#' contain a whisker template for the system prompt, with placeholders for `{{schema}}`,
19+
#' `{{data_description}}`, and `{{extra_instructions}}`.
1120
#'
1221
#' @return A string containing the system prompt for the chat model.
1322
#'
1423
#' @export
1524
querychat_system_prompt <- function(
1625
df,
17-
name,
26+
table_name,
1827
data_description = NULL,
1928
extra_instructions = NULL,
20-
categorical_threshold = 10
29+
categorical_threshold = 10,
30+
prompt_path = system.file("prompt", "prompt.md", package = "querychat")
2131
) {
22-
schema <- df_to_schema(df, name, categorical_threshold)
32+
schema <- df_to_schema(df, table_name, categorical_threshold)
2333

2434
if (!is.null(data_description)) {
2535
data_description <- paste(data_description, collapse = "\n")
@@ -29,26 +39,50 @@ querychat_system_prompt <- function(
2939
}
3040

3141
# Read the prompt file
32-
prompt_path <- system.file("prompt", "prompt.md", package = "querychat")
42+
if (is.null(prompt_path)) {
43+
prompt_path <- system.file("prompt", "prompt.md", package = "querychat")
44+
}
45+
if (!file.exists(prompt_path)) {
46+
stop("Prompt file not found at: ", prompt_path)
47+
}
3348
prompt_content <- readLines(prompt_path, warn = FALSE)
3449
prompt_text <- paste(prompt_content, collapse = "\n")
3550

36-
whisker::whisker.render(
37-
prompt_text,
38-
list(
39-
schema = schema,
40-
data_description = data_description,
41-
extra_instructions = extra_instructions
51+
processed_template <-
52+
whisker::whisker.render(
53+
prompt_text,
54+
list(
55+
schema = schema,
56+
data_description = data_description,
57+
extra_instructions = extra_instructions
58+
)
4259
)
43-
)
60+
61+
attr(processed_template, "table_name") <- table_name
62+
63+
processed_template
4464
}
4565

66+
#' Generate a schema description from a data frame
67+
#'
68+
#' This function generates a schema description for a data frame, including
69+
#' the column names, their types, and additional information such as ranges for
70+
#' numeric columns and unique values for text columns.
71+
#'
72+
#' @param df A data frame to generate schema information from.
73+
#' @param table_name A string containing the name of the table in SQL queries.
74+
#' @param categorical_threshold The maximum number of unique values for a text column to be considered categorical.
75+
#'
76+
#' @return A string containing the schema description for the data frame.
77+
#' The schema includes the table name, column names, their types, and additional
78+
#' information such as ranges for numeric columns and unique values for text columns.
79+
#' @export
4680
df_to_schema <- function(
4781
df,
48-
name = deparse(substitute(df)),
49-
categorical_threshold
82+
table_name = deparse(substitute(df)),
83+
categorical_threshold = 10
5084
) {
51-
schema <- c(paste("Table:", name), "Columns:")
85+
schema <- c(paste("Table:", table_name), "Columns:")
5286

5387
column_info <- lapply(names(df), function(column) {
5488
# Map R classes to SQL-like types

pkg-r/R/querychat.R

Lines changed: 34 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -4,75 +4,82 @@
44
#' Shiny sessions in the R process.
55
#'
66
#' @param df A data frame.
7-
#' @param tbl_name A string containing a valid table name for the data frame,
7+
#' @param table_name A string containing a valid table name for the data frame,
88
#' that will appear in SQL queries. Ensure that it begins with a letter, and
99
#' contains only letters, numbers, and underscores. By default, querychat will
1010
#' try to infer a table name using the name of the `df` argument.
1111
#' @param greeting A string in Markdown format, containing the initial message
1212
#' to display to the user upon first loading the chatbot. If not provided, the
1313
#' LLM will be invoked at the start of the conversation to generate one.
14-
#' @param data_description A string in plain text or Markdown format, containing
15-
#' a description of the data frame or any additional context that might be
16-
#' helpful in understanding the data. This will be included in the system
17-
#' prompt for the chat model. If a `system_prompt` argument is provided, the
18-
#' `data_description` argument will be ignored.
19-
#' @param extra_instructions A string in plain text or Markdown format, containing
20-
#' any additional instructions for the chat model. These will be appended at
21-
#' the end of the system prompt. If a `system_prompt` argument is provided,
22-
#' the `extra_instructions` argument will be ignored.
23-
#' @param create_chat_func A function that takes a system prompt and returns a
24-
#' chat object. The default uses `ellmer::chat_openai()`.
14+
#' @param ... Additional arguments passed to the `querychat_system_prompt()`
15+
#' function, such as `categorical_threshold` and `prompt_path`. If a
16+
#' `system_prompt` argument is provided, the `...` arguments will be silently
17+
#' ignored.
18+
#' @inheritParams querychat_system_prompt
2519
#' @param system_prompt A string containing the system prompt for the chat model.
2620
#' The default uses `querychat_system_prompt()` to generate a generic prompt,
2721
#' which you can enhance via the `data_description` and `extra_instructions`
2822
#' arguments.
29-
#'
23+
#' @param create_chat_func A function that takes a system prompt and returns a
24+
#' chat object. The default uses `ellmer::chat_openai()`.
3025
#' @returns An object that can be passed to `querychat_server()` as the
3126
#' `querychat_config` argument. By convention, this object should be named
3227
#' `querychat_config`.
3328
#'
3429
#' @export
3530
querychat_init <- function(
3631
df,
37-
tbl_name = deparse(substitute(df)),
32+
...,
33+
table_name = deparse(substitute(df)),
3834
greeting = NULL,
3935
data_description = NULL,
4036
extra_instructions = NULL,
41-
create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o"),
4237
system_prompt = querychat_system_prompt(
4338
df,
44-
tbl_name,
39+
table_name,
40+
# By default, pass through any params supplied to querychat_init()
41+
...,
4542
data_description = data_description,
4643
extra_instructions = extra_instructions
47-
)
44+
),
45+
create_chat_func = purrr::partial(ellmer::chat_openai, model = "gpt-4o")
4846
) {
49-
is_tbl_name_ok <- is.character(tbl_name) &&
50-
length(tbl_name) == 1 &&
51-
grepl("^[a-zA-Z][a-zA-Z0-9_]*$", tbl_name, perl = TRUE)
52-
if (!is_tbl_name_ok) {
53-
if (missing(tbl_name)) {
47+
is_table_name_ok <- is.character(table_name) &&
48+
length(table_name) == 1 &&
49+
grepl("^[a-zA-Z][a-zA-Z0-9_]*$", table_name, perl = TRUE)
50+
if (!is_table_name_ok) {
51+
if (missing(table_name)) {
5452
rlang::abort(
55-
"Unable to infer table name from `df` argument. Please specify `tbl_name` argument explicitly."
53+
"Unable to infer table name from `df` argument. Please specify `table_name` argument explicitly."
5654
)
5755
} else {
5856
rlang::abort(
59-
"`tbl_name` argument must be a string containing a valid table name."
57+
"`table_name` argument must be a string containing a valid table name."
6058
)
6159
}
6260
}
6361

6462
force(df)
65-
force(system_prompt)
63+
force(system_prompt) # Have default `...` params evaluated
6664
force(create_chat_func)
6765

6866
# TODO: Provide nicer looking errors here
6967
stopifnot(
7068
"df must be a data frame" = is.data.frame(df),
71-
"tbl_name must be a string" = is.character(tbl_name),
69+
"table_name must be a string" = is.character(table_name),
7270
"system_prompt must be a string" = is.character(system_prompt),
7371
"create_chat_func must be a function" = is.function(create_chat_func)
7472
)
7573

74+
if ("table_name" %in% names(attributes(system_prompt))) {
75+
# If available, make sure the `table_name` argument to `querychat_init()`
76+
# matches the one supplied to the system prompt
77+
if (table_name != attr(system_prompt, "table_name")) {
78+
rlang::abort(
79+
"`querychat_init(table_name=)` must match system prompt `table_name` supplied to `querychat_system_prompt()`."
80+
)
81+
}
82+
}
7683
if (!is.null(greeting)) {
7784
greeting <- paste(collapse = "\n", greeting)
7885
} else {
@@ -83,7 +90,7 @@ querychat_init <- function(
8390
}
8491

8592
conn <- DBI::dbConnect(duckdb::duckdb(), dbdir = ":memory:")
86-
duckdb::duckdb_register(conn, tbl_name, df, experimental = FALSE)
93+
duckdb::duckdb_register(conn, table_name, df, experimental = FALSE)
8794
shiny::onStop(function() DBI::dbDisconnect(conn))
8895

8996
structure(

pkg-r/man/df_to_schema.Rd

Lines changed: 29 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)