Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg-py/docs/index.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ Currently, querychat uses DuckDB for its SQL engine. It's extremely fast and has

### Provide a greeting (recommended)

When the querychat UI first appears, you will usually want it to greet the user with some basic instructions. By default, these instructions are auto-generated every time a user arrives; this is slow, wasteful, and unpredictable. Instead, you should create a file called `greeting.md`, and when calling `querychat.init`, pass `greeting=Path("greeting.md").read_text()`.
When the querychat UI first appears, you will usually want it to greet the user with some basic instructions. By default, these instructions are auto-generated every time a user arrives; this is slow, wasteful, and unpredictable. Instead, you should create a file called `greeting.md`, and when calling `querychat.init`, pass `greeting=Path("greeting.md")`.

You can provide suggestions to the user by using the `<span class="suggestion"> </span>` tag.

Expand Down Expand Up @@ -141,7 +141,7 @@ which you can then pass via:
querychat_config = querychat.init(
titanic,
"titanic",
data_description=Path("data_description.md").read_text()
data_description=Path("data_description.md")
)
```

Expand All @@ -163,7 +163,7 @@ querychat_config = querychat.init(
)
```

You can also put these instructions in a separate file and use `Path("instructions.md").read_text()` to load them, as we did for `data_description` above.
You can also put these instructions in a separate file and use `Path("instructions.md")` to load them, as we did for `data_description` above.

**Warning:** It is not 100% guaranteed that the LLM will always—or in many cases, ever—obey your instructions, and it can be difficult to predict which instructions will be a problem. So be sure to test extensively each time you change your instructions, and especially, if you change the model you use.

Expand Down
7 changes: 3 additions & 4 deletions pkg-py/examples/app-database-sqlite.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from pathlib import Path

import chatlas
import querychat as qc
from seaborn import load_dataset
from shiny import App, render, ui
from sqlalchemy import create_engine

import querychat as qc

# Load titanic data and create SQLite database
db_path = Path(__file__).parent / "titanic.db"
engine = create_engine("sqlite:///" + str(db_path))
Expand All @@ -17,8 +16,8 @@
titanic = load_dataset("titanic")
titanic.to_sql("titanic", engine, if_exists="replace", index=False)

greeting = (Path(__file__).parent / "greeting.md").read_text()
data_desc = (Path(__file__).parent / "data_description.md").read_text()
greeting = Path(__file__).parent / "greeting.md"
data_desc = Path(__file__).parent / "data_description.md"

# 1. Configure querychat

Expand Down
7 changes: 3 additions & 4 deletions pkg-py/examples/app-dataframe-pandas.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
from pathlib import Path

import chatlas
import querychat as qc
from seaborn import load_dataset
from shiny import App, render, ui

import querychat as qc

titanic = load_dataset("titanic")

greeting = (Path(__file__).parent / "greeting.md").read_text()
data_desc = (Path(__file__).parent / "data_description.md").read_text()
greeting = Path(__file__).parent / "greeting.md"
data_desc = Path(__file__).parent / "data_description.md"

# 1. Configure querychat

Expand Down
94 changes: 60 additions & 34 deletions pkg-py/src/querychat/querychat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import re
import sys
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union
Expand All @@ -23,22 +24,16 @@ class CreateChatCallback(Protocol):
def __call__(self, system_prompt: str) -> chatlas.Chat: ...


@dataclass
class QueryChatConfig:
"""
Configuration class for querychat.
"""

def __init__(
self,
data_source: DataSource,
system_prompt: str,
greeting: Optional[str],
create_chat_callback: CreateChatCallback,
):
self.data_source = data_source
self.system_prompt = system_prompt
self.greeting = greeting
self.create_chat_callback = create_chat_callback
data_source: DataSource
system_prompt: str
greeting: Optional[str]
create_chat_callback: CreateChatCallback


class QueryChat:
Expand Down Expand Up @@ -140,8 +135,9 @@ def __getitem__(self, key: str) -> Any:

def system_prompt(
data_source: DataSource,
data_description: Optional[str] = None,
extra_instructions: Optional[str] = None,
*,
data_description: Optional[str | Path] = None,
extra_instructions: Optional[str | Path] = None,
categorical_threshold: int = 10,
prompt_path: Optional[Path] = None,
) -> str:
Expand Down Expand Up @@ -179,15 +175,27 @@ def system_prompt(

prompt_text = prompt_path.read_text()

data_description_str: str | None = (
data_description.read_text()
if isinstance(data_description, Path)
else data_description
)

extra_instructions_str: str | None = (
extra_instructions.read_text()
if isinstance(extra_instructions, Path)
else extra_instructions
)

return chevron.render(
prompt_text,
{
"db_engine": data_source.db_engine,
"schema": data_source.get_schema(
categorical_threshold=categorical_threshold,
),
"data_description": data_description,
"extra_instructions": extra_instructions,
"data_description": data_description_str,
"extra_instructions": extra_instructions_str,
},
)

Expand Down Expand Up @@ -232,11 +240,10 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str:
def init(
data_source: IntoFrame | sqlalchemy.Engine,
table_name: str,
/,
*,
greeting: Optional[str] = None,
data_description: Optional[str] = None,
extra_instructions: Optional[str] = None,
greeting: Optional[str | Path] = None,
data_description: Optional[str | Path] = None,
extra_instructions: Optional[str | Path] = None,
prompt_path: Optional[Path] = None,
system_prompt_override: Optional[str] = None,
create_chat_callback: Optional[CreateChatCallback] = None,
Expand All @@ -254,12 +261,18 @@ def init(
SQL queries (usually the variable name of the data frame, but it doesn't
have to be). If a data_source is a SQLAlchemy engine, the table_name is
the name of the table in the database to query against.
greeting : str, optional
A string in Markdown format, containing the initial message
data_description : str, optional
Description of the data in plain text or Markdown
extra_instructions : str, optional
Additional instructions for the chat model
greeting : str | Path, optional
A string in Markdown format, containing the initial message.
If a pathlib.Path object is passed,
querychat will read the contents of the path into a string with `.read_text()`.
data_description : str | Path, optional
Description of the data in plain text or Markdown.
If a pathlib.Path object is passed,
querychat will read the contents of the path into a string with `.read_text()`.
extra_instructions : str | Path, optional
Additional instructions for the chat model.
If a pathlib.Path object is passed,
querychat will read the contents of the path into a string with `.read_text()`.
prompt_path : Path, optional
Path to a custom prompt file. If not provided, the default querychat
template will be used. This should be a Markdown file that contains the
Expand Down Expand Up @@ -305,14 +318,22 @@ def init(
file=sys.stderr,
)

# Create the system prompt, or use the override
_system_prompt = system_prompt_override or system_prompt(
data_source_obj,
data_description,
extra_instructions,
prompt_path=prompt_path,
# quality of life improvement to do the Path.read_text() for user or pass along the string
greeting_str: str | None = (
greeting.read_text() if isinstance(greeting, Path) else greeting
)

# Create the system prompt, or use the override
if isinstance(system_prompt_override, Path):
system_prompt_ = system_prompt_override.read_text()
else:
system_prompt_ = system_prompt_override or system_prompt(
data_source_obj,
data_description=data_description,
extra_instructions=extra_instructions,
prompt_path=prompt_path,
)

# Default chat function if none provided
create_chat_callback = create_chat_callback or partial(
chatlas.ChatOpenAI,
Expand All @@ -321,8 +342,8 @@ def init(

return QueryChatConfig(
data_source=data_source_obj,
system_prompt=_system_prompt,
greeting=greeting,
system_prompt=system_prompt_,
greeting=greeting_str,
create_chat_callback=create_chat_callback,
)

Expand Down Expand Up @@ -354,7 +375,12 @@ def mod_ui() -> ui.TagList:
)


def sidebar(id: str, width: int = 400, height: str = "100%", **kwargs) -> ui.Sidebar:
def sidebar(
id: str,
width: int = 400,
height: str = "100%",
**kwargs,
) -> ui.Sidebar:
"""
Create a sidebar containing the querychat UI.

Expand Down