Skip to content

Commit 1d87120

Browse files
cpsievertgadenbuie
andauthored
feat!(pkg-r): new data_source plugin API (#123)
* feat!(pkg-r): new data_source plugin API * `air format` (GitHub Actions) * Address feedback * `air format` (GitHub Actions) * `devtools::document()` (GitHub Actions) * Apply suggestions from code review Co-authored-by: Garrick Aden-Buie <[email protected]> * Remove file committed by mistake * get_system_prompt -> assemble_system_prompt * `devtools::document()` (GitHub Actions) * Missed a renaming --------- Co-authored-by: cpsievert <[email protected]> Co-authored-by: Garrick Aden-Buie <[email protected]>
1 parent 4029536 commit 1d87120

24 files changed

+1151
-562
lines changed

pkg-py/src/querychat/_querychat.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def __init__(
4949

5050
self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting
5151

52-
prompt = get_system_prompt(
52+
prompt = assemble_system_prompt(
5353
self._data_source,
5454
data_description=data_description,
5555
extra_instructions=extra_instructions,
@@ -682,7 +682,7 @@ def as_querychat_client(client: str | chatlas.Chat | None) -> chatlas.Chat:
682682
return chatlas.ChatAuto(provider_model=client)
683683

684684

685-
def get_system_prompt(
685+
def assemble_system_prompt(
686686
data_source: DataSource,
687687
*,
688688
data_description: Optional[str | Path] = None,

pkg-r/NAMESPACE

Lines changed: 3 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,9 @@
11
# Generated by roxygen2: do not edit by hand
22

3-
S3method(as_querychat_data_source,DBIConnection)
4-
S3method(as_querychat_data_source,data.frame)
5-
S3method(cleanup_source,dbi_source)
6-
S3method(create_system_prompt,querychat_data_source)
7-
S3method(execute_query,dbi_source)
8-
S3method(get_db_type,data_frame_source)
9-
S3method(get_db_type,dbi_source)
10-
S3method(get_db_type,default)
11-
S3method(get_schema,dbi_source)
12-
S3method(test_query,dbi_source)
3+
export(DBISource)
4+
export(DataFrameSource)
5+
export(DataSource)
136
export(QueryChat)
14-
export(as_querychat_data_source)
15-
export(cleanup_source)
16-
export(create_system_prompt)
17-
export(execute_query)
18-
export(get_db_type)
19-
export(get_schema)
207
export(querychat)
218
export(querychat_app)
229
export(querychat_data_source)
@@ -25,7 +12,6 @@ export(querychat_init)
2512
export(querychat_server)
2613
export(querychat_sidebar)
2714
export(querychat_ui)
28-
export(test_query)
2915
importFrom(R6,R6Class)
3016
importFrom(bslib,sidebar)
3117
importFrom(lifecycle,deprecated)

pkg-r/R/QueryChat.R

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ QueryChat <- R6::R6Class(
171171
}
172172
self$greeting <- greeting
173173

174-
prompt <- create_system_prompt(
174+
prompt <- assemble_system_prompt(
175175
private$.data_source,
176176
data_description = data_description,
177177
categorical_threshold = categorical_threshold,
@@ -522,7 +522,10 @@ QueryChat <- R6::R6Class(
522522
#'
523523
#' @return Invisibly returns `NULL`. Resources are cleaned up internally.
524524
cleanup = function() {
525-
cleanup_source(private$.data_source)
525+
if (!is.null(private$.data_source)) {
526+
private$.data_source$cleanup()
527+
}
528+
invisible(NULL)
526529
}
527530
),
528531
active = list(
@@ -686,8 +689,22 @@ querychat_app <- function(
686689

687690
normalize_data_source <- function(data_source, table_name) {
688691
if (is_data_source(data_source)) {
689-
data_source
690-
} else {
691-
as_querychat_data_source(data_source, table_name)
692+
return(data_source)
693+
}
694+
695+
if (is.data.frame(data_source)) {
696+
return(DataFrameSource$new(data_source, table_name))
692697
}
698+
699+
if (inherits(data_source, "DBIConnection")) {
700+
return(DBISource$new(data_source, table_name))
701+
}
702+
703+
cli::cli_abort(
704+
paste0(
705+
"`data_source` must be a DataSource, data.frame, or DBIConnection. ",
706+
"Got: ",
707+
class(data_source)[1]
708+
)
709+
)
693710
}

0 commit comments

Comments
 (0)