diff --git a/pkg-py/docs/index.qmd b/pkg-py/docs/index.qmd index 72b69adf..462f6d91 100644 --- a/pkg-py/docs/index.qmd +++ b/pkg-py/docs/index.qmd @@ -59,7 +59,7 @@ Currently, querychat uses DuckDB for its SQL engine. It's extremely fast and has ### Provide a greeting (recommended) -When the querychat UI first appears, you will usually want it to greet the user with some basic instructions. By default, these instructions are auto-generated every time a user arrives; this is slow, wasteful, and unpredictable. Instead, you should create a file called `greeting.md`, and when calling `querychat.init`, pass `greeting=Path("greeting.md").read_text()`. +When the querychat UI first appears, you will usually want it to greet the user with some basic instructions. By default, these instructions are auto-generated every time a user arrives; this is slow, wasteful, and unpredictable. Instead, you should create a file called `greeting.md`, and when calling `querychat.init`, pass `greeting=Path("greeting.md")`. You can provide suggestions to the user by using the `<span class="suggestion"> </span>` tag. @@ -141,7 +141,7 @@ which you can then pass via: querychat_config = querychat.init( titanic, "titanic", - data_description=Path("data_description.md").read_text() + data_description=Path("data_description.md") ) ``` @@ -163,7 +163,7 @@ querychat_config = querychat.init( ) ``` -You can also put these instructions in a separate file and use `Path("instructions.md").read_text()` to load them, as we did for `data_description` above. +You can also put these instructions in a separate file and use `Path("instructions.md")` to load them, as we did for `data_description` above. **Warning:** It is not 100% guaranteed that the LLM will always—or in many cases, ever—obey your instructions, and it can be difficult to predict which instructions will be a problem. So be sure to test extensively each time you change your instructions, and especially, if you change the model you use. 
diff --git a/pkg-py/examples/app-database-sqlite.py b/pkg-py/examples/app-database-sqlite.py index d451a900..c5107fb7 100644 --- a/pkg-py/examples/app-database-sqlite.py +++ b/pkg-py/examples/app-database-sqlite.py @@ -1,12 +1,11 @@ from pathlib import Path import chatlas +import querychat as qc from seaborn import load_dataset from shiny import App, render, ui from sqlalchemy import create_engine -import querychat as qc - # Load titanic data and create SQLite database db_path = Path(__file__).parent / "titanic.db" engine = create_engine("sqlite:///" + str(db_path)) @@ -17,8 +16,8 @@ titanic = load_dataset("titanic") titanic.to_sql("titanic", engine, if_exists="replace", index=False) -greeting = (Path(__file__).parent / "greeting.md").read_text() -data_desc = (Path(__file__).parent / "data_description.md").read_text() +greeting = Path(__file__).parent / "greeting.md" +data_desc = Path(__file__).parent / "data_description.md" # 1. Configure querychat diff --git a/pkg-py/examples/app-dataframe-pandas.py b/pkg-py/examples/app-dataframe-pandas.py index e860e393..dac93c6e 100644 --- a/pkg-py/examples/app-dataframe-pandas.py +++ b/pkg-py/examples/app-dataframe-pandas.py @@ -1,15 +1,14 @@ from pathlib import Path import chatlas +import querychat as qc from seaborn import load_dataset from shiny import App, render, ui -import querychat as qc - titanic = load_dataset("titanic") -greeting = (Path(__file__).parent / "greeting.md").read_text() -data_desc = (Path(__file__).parent / "data_description.md").read_text() +greeting = Path(__file__).parent / "greeting.md" +data_desc = Path(__file__).parent / "data_description.md" # 1. 
Configure querychat diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py index c15db100..865facca 100644 --- a/pkg-py/src/querychat/querychat.py +++ b/pkg-py/src/querychat/querychat.py @@ -2,6 +2,7 @@ import re import sys +from dataclasses import dataclass from functools import partial from pathlib import Path from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union @@ -23,22 +24,16 @@ class CreateChatCallback(Protocol): def __call__(self, system_prompt: str) -> chatlas.Chat: ... +@dataclass class QueryChatConfig: """ Configuration class for querychat. """ - def __init__( - self, - data_source: DataSource, - system_prompt: str, - greeting: Optional[str], - create_chat_callback: CreateChatCallback, - ): - self.data_source = data_source - self.system_prompt = system_prompt - self.greeting = greeting - self.create_chat_callback = create_chat_callback + data_source: DataSource + system_prompt: str + greeting: Optional[str] + create_chat_callback: CreateChatCallback class QueryChat: @@ -140,8 +135,9 @@ def __getitem__(self, key: str) -> Any: def system_prompt( data_source: DataSource, - data_description: Optional[str] = None, - extra_instructions: Optional[str] = None, + *, + data_description: Optional[str | Path] = None, + extra_instructions: Optional[str | Path] = None, categorical_threshold: int = 10, prompt_path: Optional[Path] = None, ) -> str: @@ -179,6 +175,18 @@ def system_prompt( prompt_text = prompt_path.read_text() + data_description_str: str | None = ( + data_description.read_text() + if isinstance(data_description, Path) + else data_description + ) + + extra_instructions_str: str | None = ( + extra_instructions.read_text() + if isinstance(extra_instructions, Path) + else extra_instructions + ) + return chevron.render( prompt_text, { @@ -186,8 +194,8 @@ def system_prompt( "schema": data_source.get_schema( categorical_threshold=categorical_threshold, ), - "data_description": data_description, - 
"extra_instructions": extra_instructions, + "data_description": data_description_str, + "extra_instructions": extra_instructions_str, }, ) @@ -232,11 +240,10 @@ def df_to_html(df: IntoFrame, maxrows: int = 5) -> str: def init( data_source: IntoFrame | sqlalchemy.Engine, table_name: str, - /, *, - greeting: Optional[str] = None, - data_description: Optional[str] = None, - extra_instructions: Optional[str] = None, + greeting: Optional[str | Path] = None, + data_description: Optional[str | Path] = None, + extra_instructions: Optional[str | Path] = None, prompt_path: Optional[Path] = None, system_prompt_override: Optional[str] = None, create_chat_callback: Optional[CreateChatCallback] = None, @@ -254,12 +261,18 @@ def init( SQL queries (usually the variable name of the data frame, but it doesn't have to be). If a data_source is a SQLAlchemy engine, the table_name is the name of the table in the database to query against. - greeting : str, optional - A string in Markdown format, containing the initial message - data_description : str, optional - Description of the data in plain text or Markdown - extra_instructions : str, optional - Additional instructions for the chat model + greeting : str | Path, optional + A string in Markdown format, containing the initial message. + If a pathlib.Path object is passed, + querychat will read the contents of the path into a string with `.read_text()`. + data_description : str | Path, optional + Description of the data in plain text or Markdown. + If a pathlib.Path object is passed, + querychat will read the contents of the path into a string with `.read_text()`. + extra_instructions : str | Path, optional + Additional instructions for the chat model. + If a pathlib.Path object is passed, + querychat will read the contents of the path into a string with `.read_text()`. prompt_path : Path, optional Path to a custom prompt file. If not provided, the default querychat template will be used. 
This should be a Markdown file that contains the @@ -305,14 +318,22 @@ def init( file=sys.stderr, ) - # Create the system prompt, or use the override - _system_prompt = system_prompt_override or system_prompt( - data_source_obj, - data_description, - extra_instructions, - prompt_path=prompt_path, + # quality of life improvement to do the Path.read_text() for user or pass along the string + greeting_str: str | None = ( + greeting.read_text() if isinstance(greeting, Path) else greeting ) + # Create the system prompt, or use the override + if isinstance(system_prompt_override, Path): + system_prompt_ = system_prompt_override.read_text() + else: + system_prompt_ = system_prompt_override or system_prompt( + data_source_obj, + data_description=data_description, + extra_instructions=extra_instructions, + prompt_path=prompt_path, + ) + # Default chat function if none provided create_chat_callback = create_chat_callback or partial( chatlas.ChatOpenAI, @@ -321,8 +342,8 @@ def init( return QueryChatConfig( data_source=data_source_obj, - system_prompt=_system_prompt, - greeting=greeting, + system_prompt=system_prompt_, + greeting=greeting_str, create_chat_callback=create_chat_callback, ) @@ -354,7 +375,12 @@ def mod_ui() -> ui.TagList: ) -def sidebar(id: str, width: int = 400, height: str = "100%", **kwargs) -> ui.Sidebar: +def sidebar( + id: str, + width: int = 400, + height: str = "100%", + **kwargs, +) -> ui.Sidebar: """ Create a sidebar containing the querychat UI.