diff --git a/pkg-py/CHANGELOG.md b/pkg-py/CHANGELOG.md index 418e24f5..780e5b42 100644 --- a/pkg-py/CHANGELOG.md +++ b/pkg-py/CHANGELOG.md @@ -7,6 +7,12 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [UNRELEASED] +### Changes + +* The entire functional API (i.e., `init()`, `sidebar()`, `server()`, etc) has been hard deprecated in favor of a simpler OOP-based API. Namely, the new `QueryChat()` class is now the main entry point (instead of `init()`) and has methods to replace old functions (e.g., `.sidebar()`, `.server()`, etc). (#101) + +## [UNRELEASED] + ### New features * The `.sql` query and `.title` returned from `querychat.server()` are now reactive values, meaning you can now `.set()` their value, and `.df()` will update accordingly. (#98) diff --git a/pkg-py/src/querychat/__init__.py b/pkg-py/src/querychat/__init__.py index b5e4279d..ea402e9f 100644 --- a/pkg-py/src/querychat/__init__.py +++ b/pkg-py/src/querychat/__init__.py @@ -1,14 +1,15 @@ -from querychat._greeting import greeting -from querychat.querychat import ( - init, - sidebar, - system_prompt, -) -from querychat.querychat import ( - mod_server as server, -) -from querychat.querychat import ( - mod_ui as ui, -) +from ._deprecated import greeting, init, sidebar, system_prompt +from ._deprecated import mod_server as server +from ._deprecated import mod_ui as ui +from ._querychat import QueryChat -__all__ = ["greeting", "init", "server", "sidebar", "system_prompt", "ui"] +__all__ = ( + "QueryChat", + # TODO(lifecycle): Remove these deprecated functions when we reach v1.0 + "greeting", + "init", + "server", + "sidebar", + "system_prompt", + "ui", +) diff --git a/pkg-py/src/querychat/_deprecated.py b/pkg-py/src/querychat/_deprecated.py new file mode 100644 index 00000000..2e2ea19d --- /dev/null +++ b/pkg-py/src/querychat/_deprecated.py @@ -0,0 +1,108 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Optional, Union + +from shiny import Inputs, Outputs, Session, module, ui + +if TYPE_CHECKING: + from pathlib import Path + + import chatlas + import sqlalchemy + from narwhals.stable.v1.typing import IntoFrame + + from .datasource import DataSource + + +def init( + data_source: IntoFrame | sqlalchemy.Engine, + table_name: str, + *, + greeting: Optional[str | Path] = None, + data_description: Optional[str | Path] = None, + extra_instructions: Optional[str | Path] = None, + prompt_template: Optional[str | Path] = None, + system_prompt_override: Optional[str] = None, + client: Optional[Union[chatlas.Chat, str]] = None, +): + """ + Initialize querychat with any compliant data source. + + **Deprecated.** Use `QueryChat()` instead. + """ + raise RuntimeError("init() is deprecated. Use QueryChat() instead.") + + +@module.ui +def mod_ui(**kwargs) -> ui.TagList: + """ + Create the UI for the querychat component. + + **Deprecated.** Use `QueryChat.ui()` instead. + """ + raise RuntimeError("mod_ui() is deprecated. Use QueryChat.ui() instead.") + + +@module.server +def mod_server( + input: Inputs, + output: Outputs, + session: Session, + querychat_config: Any, +): + """ + Initialize the querychat server. + + **Deprecated.** Use `QueryChat.server()` instead. + """ + raise RuntimeError("mod_server() is deprecated. Use QueryChat.server() instead.") + + +def sidebar( + id: str, + width: int = 400, + height: str = "100%", + **kwargs, +) -> ui.Sidebar: + """ + Create a sidebar containing the querychat UI. + + **Deprecated.** Use `QueryChat.sidebar()` instead. 
+ """ + raise RuntimeError("sidebar() is deprecated. Use QueryChat.sidebar() instead.") + + +def system_prompt( + data_source: DataSource, + *, + data_description: Optional[str | Path] = None, + extra_instructions: Optional[str | Path] = None, + categorical_threshold: int = 10, + prompt_template: Optional[str | Path] = None, +) -> str: + """ + Create a system prompt for the chat model based on a data source's schema + and optional additional context and instructions. + + **Deprecated.** Use `QueryChat.set_system_prompt()` instead. + """ + raise RuntimeError( + "system_prompt() is deprecated. Use QueryChat.set_system_prompt() instead." + ) + + +def greeting( + querychat_config, + *, + generate: bool = True, + stream: bool = False, + **kwargs, +) -> str | None: + """ + Generate or retrieve a greeting message. + + **Deprecated.** Use `QueryChat.generate_greeting()` instead. + """ + raise RuntimeError( + "greeting() is deprecated. Use QueryChat.generate_greeting() instead." + ) diff --git a/pkg-py/src/querychat/_greeting.py b/pkg-py/src/querychat/_greeting.py deleted file mode 100644 index bf0a6472..00000000 --- a/pkg-py/src/querychat/_greeting.py +++ /dev/null @@ -1,95 +0,0 @@ -from __future__ import annotations - -from copy import deepcopy - - -def greeting( - querychat_config, - *, - generate: bool = True, - stream: bool = False, - **kwargs, -) -> str | None: - """ - Generate or retrieve a greeting message. - - Use this function to generate a friendly greeting message using the chat - client and data source specified in the `querychat_config` object. You can - pass this greeting to `init()` to set an initial greeting for users for - faster startup times and lower costs. If you don't provide a greeting in - `init()`, one will be generated at the start of every new conversation. - - Parameters - ---------- - querychat_config - A QueryChatConfig object from `init()`. - generate - If `True` and if `querychat_config` does not include a `greeting`, a new - greeting is generated. If `False`, returns the existing greeting from - the configuration (if any). - stream - If `True`, returns a streaming response suitable for use in a Shiny app - with `chat_ui.append_message_stream()`. If `False` (default), returns - the full greeting at once. Only relevant when `generate = True`. - **kwargs - Additional arguments passed to the chat client's `chat()` or `stream_async()` method. - - Returns - ------- - str | None - - When `generate = False`: Returns the existing greeting as a string or - `None` if no greeting exists. - - When `generate = True`: Returns the chat response containing a greeting and - sample prompts. - - Examples - -------- - ```python - import pandas as pd - from querychat import init, greeting - - # Create config with mtcars dataset - mtcars = pd.read_csv( - "https://gist.githubusercontent.com/seankross/a412dfbd88b3db70b74b/raw/5f23f993cd87c283ce766e7ac6b329ee7cc2e1d1/mtcars.csv" - ) - mtcars_config = init(mtcars, "mtcars") - - # Generate a new greeting - greeting_text = greeting(mtcars_config) - - # Update the config with the generated greeting - mtcars_config = init( - mtcars, - "mtcars", - greeting="Hello! 
I'm here to help you explore and analyze the mtcars...", - ) - ``` - - """ - not_querychat_config = ( - not hasattr(querychat_config, "client") - and not hasattr(querychat_config, "greeting") - and not hasattr(querychat_config, "system_prompt") - ) - - if not_querychat_config: - raise TypeError("`querychat_config` must be a QueryChatConfig object.") - - greeting_text = querychat_config.greeting - has_greeting = greeting_text is not None and len(greeting_text.strip()) > 0 - - if has_greeting: - return greeting_text - - if not generate: - return None - - chat = deepcopy(querychat_config.client) - chat.system_prompt = querychat_config.system_prompt - - prompt = "Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list." - - if stream: - return chat.stream_async(prompt, **kwargs) - else: - return chat.chat(prompt, **kwargs) diff --git a/pkg-py/src/querychat/_querychat.py b/pkg-py/src/querychat/_querychat.py new file mode 100644 index 00000000..7cc27462 --- /dev/null +++ b/pkg-py/src/querychat/_querychat.py @@ -0,0 +1,599 @@ +from __future__ import annotations + +import copy +import os +import re +import sys +from pathlib import Path +from typing import TYPE_CHECKING, Literal, Optional, overload + +import chatlas +import chevron +import sqlalchemy +from shiny import ui +from shiny.session import get_current_session + +from ._querychat_module import ModServerResult, mod_server, mod_ui +from .datasource import DataFrameSource, DataSource, SQLAlchemySource + +if TYPE_CHECKING: + import pandas as pd + from narwhals.stable.v1.typing import IntoFrame + + +class QueryChatBase: + """ + Create a QueryChat instance. + + This is the main entry point for using querychat. + """ + + def __init__( + self, + data_source: IntoFrame | sqlalchemy.Engine, + table_name: str, + *, + id: Optional[str] = None, + greeting: Optional[str | Path] = None, + client: Optional[str | chatlas.Chat] = None, + data_description: Optional[str | Path] = None, + extra_instructions: Optional[str | Path] = None, + prompt_template: Optional[str | Path] = None, + ): + """ + Initialize QueryChat. + + Parameters + ---------- + data_source + Either a Narwhals-compatible data frame (e.g., Polars or Pandas) or a + SQLAlchemy engine containing the table to query against. + table_name + If a data_source is a data frame, a name to use to refer to the table in + SQL queries (usually the variable name of the data frame, but it doesn't + have to be). If a data_source is a SQLAlchemy engine, the table_name is + the name of the table in the database to query against. + id + An optional ID for the QueryChat module. If not provided, an ID will be + generated based on the table_name. + greeting + A string in Markdown format, containing the initial message. If a + pathlib.Path object is passed, querychat will read the contents of the + path into a string with `.read_text()`. You can use + `querychat.greeting()` to help generate a greeting from a querychat + configuration. If no greeting is provided, one will be generated at the + start of every new conversation. + client + A `chatlas.Chat` object or a string to be passed to + `chatlas.ChatAuto()`'s `provider_model` parameter, describing the + provider and model combination to use (e.g. `"openai/gpt-4.1"`, + "anthropic/claude-sonnet-4-5", "google/gemini-2.5-flash". etc). + + If `client` is not provided, querychat consults the + `QUERYCHAT_CLIENT` environment variable. If that is not set, it + defaults to `"openai"`. 
+ data_description + Description of the data in plain text or Markdown. If a pathlib.Path + object is passed, querychat will read the contents of the path into a + string with `.read_text()`. + extra_instructions + Additional instructions for the chat model. If a pathlib.Path object is + passed, querychat will read the contents of the path into a string with + `.read_text()`. + prompt_template + Path to or a string of a custom prompt file. If not provided, the default querychat + template will be used. This should be a Markdown file that contains the + system prompt template. The mustache template can use the following + variables: + - `{{db_engine}}`: The database engine used (e.g., "DuckDB") + - `{{schema}}`: The schema of the data source, generated by + `data_source.get_schema()` + - `{{data_description}}`: The optional data description provided + - `{{extra_instructions}}`: Any additional instructions provided + + Examples + -------- + ```python + from querychat import QueryChat + + qc = QueryChat(my_dataframe, "my_data") + qc.app() + ``` + + """ + self.data_source = normalize_data_source(data_source, table_name) + + # Validate table name (must begin with letter, contain only letters, numbers, underscores) + if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name): + raise ValueError( + "Table name must begin with a letter and contain only letters, numbers, and underscores", + ) + + self.id = id or table_name + + self.client = normalize_client(client) + + if greeting is None: + print( + "Warning: No greeting provided; the LLM will be invoked at conversation start to generate one. " + "For faster startup, lower cost, and determinism, please save a greeting and pass it to init().", + "You can also use `querychat.greeting()` to help generate a greeting.", + file=sys.stderr, + ) + + self.greeting = greeting.read_text() if isinstance(greeting, Path) else greeting + + self.system_prompt = get_system_prompt( + self.data_source, + data_description=data_description, + extra_instructions=extra_instructions, + prompt_template=prompt_template, + ) + + # Populated when ._server() gets called (in an active session) + self._server_values: ModServerResult | None = None + + def sidebar( + self, + *, + width: int = 400, + height: str = "100%", + **kwargs, + ) -> ui.Sidebar: + """ + Create a sidebar containing the querychat UI. + + Parameters + ---------- + width + Width of the sidebar in pixels. + height + Height of the sidebar. + **kwargs + Additional arguments passed to `shiny.ui.sidebar()`. + + Returns + ------- + : + A sidebar UI component. + + """ + return ui.sidebar( + self.ui(), + width=width, + height=height, + class_="querychat-sidebar", + **kwargs, + ) + + def ui(self, **kwargs): + """ + Create the UI for the querychat component. + + Parameters + ---------- + **kwargs + Additional arguments to pass to `shinychat.chat_ui()`. + + Returns + ------- + : + A UI component. + + """ + return mod_ui(self.id, **kwargs) + + def _server(self): + """ + Initialize the server module. + + Note: + ---- + This is a private method since it is called automatically in Express mode. + + """ + # Must be called within an active Shiny session + session = get_current_session() + if session is None: + raise RuntimeError( + "A Shiny session must be active in order to initialize QueryChat's server logic. " + "If you're using Shiny Core, make sure to call .server() within your server function." 
+ ) + + # No-op for Express' stub session (i.e., it's 1st run) + if session.is_stub_session(): + return + + # Call the server module + self._server_values = mod_server( + self.id, + data_source=self.data_source, + system_prompt=self.system_prompt, + greeting=self.greeting, + client=self.client, + ) + + return + + def df(self) -> pd.DataFrame: + """ + Reactively read the current filtered data frame that is in effect. + + Returns + ------- + : + The current filtered data frame as a pandas DataFrame. If no query + has been set, this will return the unfiltered data frame from the + data source. + + Raises + ------ + RuntimeError + If `.server()` has not been called yet. + + """ + vals = self._server_values + if vals is None: + raise RuntimeError("Must call .server() before accessing .df()") + + return vals.df() + + @overload + def sql(self, query: None = None) -> str: ... + + @overload + def sql(self, query: str) -> bool: ... + + def sql(self, query: Optional[str] = None) -> str | bool: + """ + Reactively read (or set) the current SQL query that is in effect. + + Parameters + ---------- + query + If provided, sets the current SQL query to this value. + + Returns + ------- + : + If no `query` is provided, returns the current SQL query as a string + (possibly `""` if no query has been set). If a `query` is provided, + returns `True` if the query was changed to a new value, or `False` + if it was the same as the current value. + + Raises + ------ + RuntimeError + If `.server()` has not been called yet. + + """ + vals = self._server_values + if vals is None: + raise RuntimeError("Must call .server() before accessing .sql()") + + if query is None: + return vals.sql() + else: + return vals.sql.set(query) + + @overload + def title(self, value: None = None) -> str | None: ... + + @overload + def title(self, value: str) -> bool: ... + + def title(self, value: Optional[str] = None) -> str | None | bool: + """ + Reactively read (or set) the current title that is in effect. + + The title is a short description of the current query that the LLM + provides to us whenever it generates a new SQL query. It can be used as + a status string for the data dashboard. + + Parameters + ---------- + value + If provided, sets the current title to this value. + + Returns + ------- + : + If no `value` is provided, returns the current title as a string, or + `None` if no title has been set due to no SQL query being set. If a + `value` is provided, sets the current title to this value and + returns `True` if the title was changed to a new value, or `False` + if it was the same as the current value. + + Raises + ------ + RuntimeError + If `.server()` has not been called yet. + + """ + vals = self._server_values + if vals is None: + raise RuntimeError("Must call .server() before accessing .title()") + + if value is None: + return vals.title() + else: + return vals.title.set(value) + + def generate_greeting(self, *, echo: Literal["none", "text"] = "none"): + """ + Generate a welcome greeting for the chat. + + By default, `QueryChat()` generates a greeting at the start of every new + conversation, which is convenient for getting started and development, + but also might add unnecessary latency and cost. Use this method to + generate a greeting once and save it for reuse. + + Parameters + ---------- + echo + If `echo = "text"`, prints the greeting to standard output. If + `echo = "none"` (default), does not print anything. + + Returns + ------- + : + The greeting string (in Markdown format). 
+ + """ + client = copy.deepcopy(self.client) + client.system_prompt = self.system_prompt + client.set_turns([]) + prompt = "Please give me a friendly greeting. Include a few sample prompts in a two-level bulleted list." + return str(client.chat(prompt, echo=echo)) + + def set_system_prompt( + self, + data_source: DataSource, + *, + data_description: Optional[str | Path] = None, + extra_instructions: Optional[str | Path] = None, + categorical_threshold: int = 10, + prompt_template: Optional[str | Path] = None, + ) -> None: + """ + Customize the system prompt. + + Control the logic behind how the system prompt is generated based on the + data source's schema and optional additional context and instructions. + + Note + ---- + This method is for parametrized system prompt generation only. To set a + fully custom system prompt string, set the `system_prompt` attribute + directly. + + Parameters + ---------- + data_source + A data source to generate schema information from + data_description + Optional description of the data, in plain text or Markdown format + extra_instructions + Optional additional instructions for the chat model, in plain text or + Markdown format + categorical_threshold + Threshold for determining if a column is categorical based on number of + unique values + prompt_template + Optional `Path` to or string of a custom prompt template. If not provided, the default + querychat template will be used. + + """ + self.system_prompt = get_system_prompt( + data_source, + data_description=data_description, + extra_instructions=extra_instructions, + categorical_threshold=categorical_threshold, + prompt_template=prompt_template, + ) + + def set_data_source( + self, data_source: IntoFrame | sqlalchemy.Engine | DataSource, table_name: str + ) -> None: + """ + Set a new data source for the QueryChat object. + + Parameters + ---------- + data_source + The new data source to use. + table_name + If a data_source is a data frame, a name to use to refer to the table + + Returns + ------- + : + None + + """ + self.data_source = normalize_data_source(data_source, table_name) + + def set_client(self, client: str | chatlas.Chat) -> None: + """ + Set a new chat client for the QueryChat object. + + Parameters + ---------- + client + A `chatlas.Chat` object or a string to be passed to + `chatlas.ChatAuto()` describing the model to use (e.g. + `"openai/gpt-4.1"`). + + Returns + ------- + : + None + + """ + self.client = normalize_client(client) + + +class QueryChat(QueryChatBase): + def server(self): + """ + Initialize Shiny server logic. + + This method is intended for use in Shiny Code mode, where the user must + explicitly call `.server()` within the Shiny server function. In Shiny + Express mode, you can use `querychat.express.QueryChat` instead + of `querychat.QueryChat`, which calls `.server()` automatically. + + Examples + -------- + ```python + from shiny import App, render, ui + from querychat import QueryChat + + qc = QueryChat(my_dataframe, "my_data") + + app_ui = ui.page_fluid( + qc.sidebar(), + ui.output_data_frame("data_table"), + ) + + + def server(input, output, session): + qc.server() + + @render.data_frame + def data_table(): + return qc.df() + + + app = App(app_ui, server) + ``` + + Returns + ------- + : + None + + """ + return self._server() + + +class QueryChatExpress(QueryChatBase): + """ + Use QueryChat with Shiny Express mode. 
+ + This class makes it easy to use querychat within Shiny Express apps -- + it automatically calls `.server()` during initialization, so you don't + have to do it manually. + + Examples + -------- + ```python + from shiny.express import render, ui + from querychat.express import QueryChat + + qc = QueryChat(my_dataframe, "my_data") + + qc.sidebar() + + + @render.data_frame + def data_table(): + return qc.df() + ``` + + """ + + def __init__( + self, + data_source: IntoFrame | sqlalchemy.Engine, + table_name: str, + *, + id: Optional[str] = None, + greeting: Optional[str | Path] = None, + client: Optional[str | chatlas.Chat] = None, + data_description: Optional[str | Path] = None, + extra_instructions: Optional[str | Path] = None, + prompt_template: Optional[str | Path] = None, + ): + super().__init__( + data_source, + table_name, + id=id, + greeting=greeting, + client=client, + data_description=data_description, + extra_instructions=extra_instructions, + prompt_template=prompt_template, + ) + self._server() + + +def normalize_data_source( + data_source: IntoFrame | sqlalchemy.Engine | DataSource, + table_name: str, +) -> DataSource: + if isinstance(data_source, DataSource): + return data_source + if isinstance(data_source, sqlalchemy.Engine): + return SQLAlchemySource(data_source, table_name) + return DataFrameSource(data_source, table_name) + + +def normalize_client(client: str | chatlas.Chat | None) -> chatlas.Chat: + if client is None: + client = os.getenv("QUERYCHAT_CLIENT", None) + + if client is None: + client = "openai" + + if isinstance(client, chatlas.Chat): + return client + + return chatlas.ChatAuto(provider_model=client) + + +def get_system_prompt( + data_source: DataSource, + *, + data_description: Optional[str | Path] = None, + extra_instructions: Optional[str | Path] = None, + categorical_threshold: int = 10, + prompt_template: Optional[str | Path] = None, +) -> str: + # Read the prompt file + if prompt_template is None: + # Default to the prompt file in the same directory as this module + # This allows for easy customization by placing a different prompt.md file there + prompt_template = Path(__file__).parent / "prompts" / "prompt.md" + prompt_str = ( + prompt_template.read_text() + if isinstance(prompt_template, Path) + else prompt_template + ) + + data_description_str = ( + data_description.read_text() + if isinstance(data_description, Path) + else data_description + ) + + extra_instructions_str = ( + extra_instructions.read_text() + if isinstance(extra_instructions, Path) + else extra_instructions + ) + + is_duck_db = data_source.get_db_type().lower() == "duckdb" + + return chevron.render( + prompt_str, + { + "db_type": data_source.get_db_type(), + "is_duck_db": is_duck_db, + "schema": data_source.get_schema( + categorical_threshold=categorical_threshold, + ), + "data_description": data_description_str, + "extra_instructions": extra_instructions_str, + }, + ) diff --git a/pkg-py/src/querychat/_querychat_module.py b/pkg-py/src/querychat/_querychat_module.py new file mode 100644 index 00000000..2fe0f3cd --- /dev/null +++ b/pkg-py/src/querychat/_querychat_module.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import copy +from dataclasses import dataclass +from pathlib import Path +from typing import TYPE_CHECKING, Callable, Union + +import shinychat +from shiny import module, reactive, ui + +from .tools import tool_query, tool_reset_dashboard, tool_update_dashboard + +if TYPE_CHECKING: + import chatlas + import pandas as pd + from shiny import Inputs, Outputs, 
Session + + from .datasource import DataSource + +ReactiveString = reactive.Value[str] +"""A reactive string value.""" +ReactiveStringOrNone = reactive.Value[Union[str, None]] +"""A reactive string (or None) value.""" + +CHAT_ID = "chat" + + +@module.ui +def mod_ui(**kwargs): + css_path = Path(__file__).parent / "static" / "css" / "styles.css" + js_path = Path(__file__).parent / "static" / "js" / "querychat.js" + + tag = shinychat.chat_ui(CHAT_ID, **kwargs) + tag.add_class("querychat") + + return ui.TagList( + ui.head_content( + ui.include_css(css_path), + ui.include_js(js_path), + ), + tag, + ) + + +@dataclass +class ModServerResult: + df: Callable[[], pd.DataFrame] + sql: ReactiveString + title: ReactiveStringOrNone + client: chatlas.Chat + + +@module.server +def mod_server( + input: Inputs, + output: Outputs, + session: Session, + *, + data_source: DataSource, + system_prompt: str, + greeting: str | None, + client: chatlas.Chat, +): + # Reactive values to store state + sql = ReactiveString("") + title = ReactiveStringOrNone(None) + + # Set up the chat object for this session + chat = copy.deepcopy(client) + chat.set_turns([]) + chat.system_prompt = system_prompt + + # Create the tool functions + update_dashboard_tool = tool_update_dashboard(data_source, sql, title) + reset_dashboard_tool = tool_reset_dashboard(sql, title) + query_tool = tool_query(data_source) + + # Register tools with annotations for the UI + chat.register_tool(update_dashboard_tool) + chat.register_tool(query_tool) + chat.register_tool(reset_dashboard_tool) + + # Execute query when SQL changes + @reactive.calc + def filtered_df(): + if sql.get() == "": + return data_source.get_data() + else: + return data_source.execute_query(sql.get()) + + # Chat UI logic + chat_ui = shinychat.Chat(CHAT_ID) + + # Handle user input + @chat_ui.on_user_submit + async def _(user_input: str): + stream = await chat.stream_async(user_input, echo="none", content="all") + await chat_ui.append_message_stream(stream) + + @reactive.effect + async def greet_on_startup(): + if greeting: + await chat_ui.append_message(greeting) + elif greeting is None: + stream = await chat.stream_async( + "Please give me a friendly greeting. 
Include a few sample prompts in a two-level bulleted list.", + echo="none", + ) + await chat_ui.append_message_stream(stream) + + # Handle update button clicks + @reactive.effect + @reactive.event(input.chat_update) + def _(): + update = input.chat_update() + if update is None: + return + if not isinstance(update, dict): + return + + new_query = update.get("query") + new_title = update.get("title") + if new_query is not None: + sql.set(new_query) + if new_title is not None: + title.set(new_title) + + return ModServerResult(df=filtered_df, sql=sql, title=title, client=chat) diff --git a/pkg-py/src/querychat/express/__init__.py b/pkg-py/src/querychat/express/__init__.py new file mode 100644 index 00000000..fee098df --- /dev/null +++ b/pkg-py/src/querychat/express/__init__.py @@ -0,0 +1,3 @@ +from .._querychat import QueryChatExpress as QueryChat + +__all__ = ["QueryChat"] diff --git a/pkg-py/src/querychat/querychat.py b/pkg-py/src/querychat/querychat.py deleted file mode 100644 index 95219f8a..00000000 --- a/pkg-py/src/querychat/querychat.py +++ /dev/null @@ -1,616 +0,0 @@ -from __future__ import annotations - -import copy -import os -import re -import sys -import warnings -from dataclasses import dataclass -from pathlib import Path -from typing import TYPE_CHECKING, Any, Callable, Optional, Protocol, Union, overload - -import chatlas -import chevron -import shinychat -import sqlalchemy -from shiny import Inputs, Outputs, Session, module, reactive, ui - -from ._utils import temp_env_vars -from .tools import tool_query, tool_reset_dashboard, tool_update_dashboard - -if TYPE_CHECKING: - import pandas as pd - from narwhals.stable.v1.typing import IntoFrame - -from .datasource import DataFrameSource, DataSource, SQLAlchemySource - - -class CreateChatCallback(Protocol): - def __call__(self, system_prompt: str) -> chatlas.Chat: ... - - -@dataclass -class QueryChatConfig: - """ - Configuration class for querychat. - """ - - data_source: DataSource - system_prompt: str - greeting: Optional[str] - client: chatlas.Chat - - -ReactiveString = reactive.Value[str] -"""A reactive string value.""" -ReactiveStringOrNone = reactive.Value[Union[str, None]] -"""A reactive string (or None) value.""" - - -class QueryChat: - """ - An object representing a query chat session. This is created within a Shiny - server function or Shiny module server function by using - `querychat.server()`. Use this object to bridge the chat interface with the - rest of the Shiny app, for example, by displaying the filtered data. - """ - - def __init__( - self, - chat: chatlas.Chat, - sql: ReactiveString, - title: ReactiveStringOrNone, - df: Callable[[], pd.DataFrame], - ): - """ - Initialize a QueryChat object. - - Parameters - ---------- - chat - The chat object for the session - sql - Reactively read (or set) the current SQL query - title - Reactively read (or set) the current title - df - Reactively read the current filtered data frame - - """ - self._chat = chat - self._sql = sql - self._title = title - self._df = df - - def chat(self) -> chatlas.Chat: - """ - Get the chat object for this session. - - Returns - ------- - : - The chat object - - """ - return self._chat - - @overload - def sql(self, query: None = None) -> str: ... - - @overload - def sql(self, query: str) -> bool: ... - - def sql(self, query: Optional[str] = None) -> str | bool: - """ - Reactively read (or set) the current SQL query that is in effect. - - Parameters - ---------- - query - If provided, sets the current SQL query to this value. 
- - Returns - ------- - : - If no `query` is provided, returns the current SQL query as a string - (possibly `""` if no query has been set). If a `query` is provided, - returns `True` if the query was changed to a new value, or `False` - if it was the same as the current value. - - """ - if query is None: - return self._sql() - else: - return self._sql.set(query) - - @overload - def title(self, value: None = None) -> str | None: ... - - @overload - def title(self, value: str) -> bool: ... - - def title(self, value: Optional[str] = None) -> str | None | bool: - """ - Reactively read (or set) the current title that is in effect. - - The title is a short description of the current query that the LLM - provides to us whenever it generates a new SQL query. It can be used as - a status string for the data dashboard. - - Returns - ------- - : - If no `value` is provided, returns the current title as a string, or - `None` if no title has been set due to no SQL query being set. If a - `value` is provided, sets the current title to this value and - returns `True` if the title was changed to a new value, or `False` - if it was the same as the current value. - - """ - if value is None: - return self._title() - else: - return self._title.set(value) - - def df(self) -> pd.DataFrame: - """ - Reactively read the current filtered data frame that is in effect. - - Returns - ------- - : - The current filtered data frame as a pandas DataFrame. If no query - has been set, this will return the unfiltered data frame from the - data source. - - """ - return self._df() - - def __getitem__(self, key: str) -> Any: - """ - Allow access to configuration parameters like a dictionary. For - backwards compatibility only; new code should use the attributes - directly instead. - """ - return { - "chat": self.chat, - "sql": self.sql, - "title": self.title, - "df": self.df, - }.get(key) - - -def system_prompt( - data_source: DataSource, - *, - data_description: Optional[str | Path] = None, - extra_instructions: Optional[str | Path] = None, - categorical_threshold: int = 10, - prompt_template: Optional[str | Path] = None, -) -> str: - """ - Create a system prompt for the chat model based on a data source's schema - and optional additional context and instructions. - - Parameters - ---------- - data_source - A data source to generate schema information from - data_description - Optional description of the data, in plain text or Markdown format - extra_instructions - Optional additional instructions for the chat model, in plain text or - Markdown format - categorical_threshold - Threshold for determining if a column is categorical based on number of - unique values - prompt_template - Optional `Path` to or string of a custom prompt template. If not provided, the default - querychat template will be used. - - Returns - ------- - : - The system prompt for the chat model. 
- - """ - # Read the prompt file - if prompt_template is None: - # Default to the prompt file in the same directory as this module - # This allows for easy customization by placing a different prompt.md file there - prompt_template = Path(__file__).parent / "prompts" / "prompt.md" - prompt_str = ( - prompt_template.read_text() - if isinstance(prompt_template, Path) - else prompt_template - ) - - data_description_str = ( - data_description.read_text() - if isinstance(data_description, Path) - else data_description - ) - - extra_instructions_str = ( - extra_instructions.read_text() - if isinstance(extra_instructions, Path) - else extra_instructions - ) - - is_duck_db = data_source.get_db_type().lower() == "duckdb" - - return chevron.render( - prompt_str, - { - "db_type": data_source.get_db_type(), - "is_duck_db": is_duck_db, - "schema": data_source.get_schema( - categorical_threshold=categorical_threshold, - ), - "data_description": data_description_str, - "extra_instructions": extra_instructions_str, - }, - ) - - -def _get_client_from_env() -> Optional[str]: - """Get client configuration from environment variable.""" - env_client = os.getenv("QUERYCHAT_CLIENT", "") - if not env_client: - return None - return env_client - - -def _create_client_from_string(client_str: str) -> chatlas.Chat: - """Create a chatlas.Chat client from a provider-model string.""" - provider, model = ( - client_str.split("/", 1) if "/" in client_str else (client_str, None) - ) - # We unset chatlas's envvars so we can listen to querychat's envvars instead - with temp_env_vars( - { - "CHATLAS_CHAT_PROVIDER": provider, - "CHATLAS_CHAT_MODEL": model, - "CHATLAS_CHAT_ARGS": os.environ.get("QUERYCHAT_CLIENT_ARGS"), - }, - ): - return chatlas.ChatAuto(provider="openai") - - -def _resolve_querychat_client( - client: Optional[Union[chatlas.Chat, CreateChatCallback, str]] = None, -) -> chatlas.Chat: - """ - Resolve the client argument into a chatlas.Chat object. - - Parameters - ---------- - client - The client to resolve. Can be: - - A chatlas.Chat object (returned as-is) - - A function that returns a chatlas.Chat object - - A provider-model string (e.g., "openai/gpt-4.1") - - None (fall back to environment variable or default) - - Returns - ------- - : - A resolved chatlas.Chat object - - """ - if client is None: - client = _get_client_from_env() - - if client is None: - # Default to OpenAI with using chatlas's default model - return chatlas.ChatOpenAI() - - if callable(client) and not isinstance(client, chatlas.Chat): - # Backcompat: support the old create_chat_callback style, using an empty - # system prompt as a placeholder. - client = client(system_prompt="") - - if isinstance(client, str): - client = _create_client_from_string(client) - - if not isinstance(client, chatlas.Chat): - raise TypeError( - "client must be a chatlas.Chat object or function that returns one", - ) - - return client - - -def init( - data_source: IntoFrame | sqlalchemy.Engine, - table_name: str, - *, - greeting: Optional[str | Path] = None, - data_description: Optional[str | Path] = None, - extra_instructions: Optional[str | Path] = None, - prompt_template: Optional[str | Path] = None, - system_prompt_override: Optional[str] = None, - client: Optional[Union[chatlas.Chat, CreateChatCallback, str]] = None, - create_chat_callback: Optional[CreateChatCallback] = None, -) -> QueryChatConfig: - """ - Initialize querychat with any compliant data source. 
- - Parameters - ---------- - data_source - Either a Narwhals-compatible data frame (e.g., Polars or Pandas) or a - SQLAlchemy engine containing the table to query against. - table_name - If a data_source is a data frame, a name to use to refer to the table in - SQL queries (usually the variable name of the data frame, but it doesn't - have to be). If a data_source is a SQLAlchemy engine, the table_name is - the name of the table in the database to query against. - greeting - A string in Markdown format, containing the initial message. If a - pathlib.Path object is passed, querychat will read the contents of the - path into a string with `.read_text()`. You can use - `querychat.greeting()` to help generate a greeting from a querychat - configuration. If no greeting is provided, one will be generated at the - start of every new conversation. - data_description - Description of the data in plain text or Markdown. - If a pathlib.Path object is passed, - querychat will read the contents of the path into a string with `.read_text()`. - extra_instructions - Additional instructions for the chat model. - If a pathlib.Path object is passed, - querychat will read the contents of the path into a string with `.read_text()`. - prompt_template - Path to or a string of a custom prompt file. If not provided, the default querychat - template will be used. This should be a Markdown file that contains the - system prompt template. The mustache template can use the following - variables: - - `{{db_engine}}`: The database engine used (e.g., "DuckDB") - - `{{schema}}`: The schema of the data source, generated by - `data_source.get_schema()` - - `{{data_description}}`: The optional data description provided - - `{{extra_instructions}}`: Any additional instructions provided - system_prompt_override - A custom system prompt to use instead of the default. If provided, - `data_description`, `extra_instructions`, and `prompt_template` will be - silently ignored. - client - A `chatlas.Chat` object, a string to be passed to `chatlas.ChatAuto()` - describing the model to use (e.g. `"openai/gpt-4.1"`), or a function - that creates a chat client. If using a function, the function should - accept a `system_prompt` argument and return a `chatlas.Chat` object. - - If `client` is not provided, querychat consults the `QUERYCHAT_CLIENT` - environment variable, which can be set to a provider-model string. If no - option is provided, querychat defaults to using - `chatlas.ChatOpenAI(model="gpt-4.1")`. - create_chat_callback - **Deprecated.** Use the `client` argument instead. - - Returns - ------- - : - A QueryChatConfig object that can be passed to server() - - """ - # Handle deprecated create_chat_callback argument - if create_chat_callback is not None: - warnings.warn( - "The 'create_chat_callback' parameter is deprecated. 
Use 'client' instead.", - DeprecationWarning, - stacklevel=2, - ) - if client is not None: - raise ValueError( - "You cannot pass both `create_chat_callback` and `client` to `init()`.", - ) - client = create_chat_callback - - # Resolve the client - resolved_client = _resolve_querychat_client(client) - - # Validate table name (must begin with letter, contain only letters, numbers, underscores) - if not re.match(r"^[a-zA-Z][a-zA-Z0-9_]*$", table_name): - raise ValueError( - "Table name must begin with a letter and contain only letters, numbers, and underscores", - ) - - data_source_obj: DataSource - if isinstance(data_source, sqlalchemy.Engine): - data_source_obj = SQLAlchemySource(data_source, table_name) - else: - data_source_obj = DataFrameSource( - data_source, - table_name, - ) - - # Process greeting - if greeting is None: - print( - "Warning: No greeting provided; the LLM will be invoked at conversation start to generate one. " - "For faster startup, lower cost, and determinism, please save a greeting and pass it to init().", - "You can also use `querychat.greeting()` to help generate a greeting.", - file=sys.stderr, - ) - - # quality of life improvement to do the Path.read_text() for user or pass along the string - greeting_str: str | None = ( - greeting.read_text() if isinstance(greeting, Path) else greeting - ) - - # Create the system prompt, or use the override - if isinstance(system_prompt_override, Path): - system_prompt_ = system_prompt_override.read_text() - else: - system_prompt_ = system_prompt_override or system_prompt( - data_source_obj, - data_description=data_description, - extra_instructions=extra_instructions, - prompt_template=prompt_template, - ) - - return QueryChatConfig( - data_source=data_source_obj, - system_prompt=system_prompt_, - greeting=greeting_str, - client=resolved_client, - ) - - -@module.ui -def mod_ui() -> ui.TagList: - """ - Create the UI for the querychat component. - - Returns - ------- - : - A UI component. - - """ - # Include CSS and JS - css_path = Path(__file__).parent / "static" / "css" / "styles.css" - js_path = Path(__file__).parent / "static" / "js" / "querychat.js" - - return ui.TagList( - ui.include_css(css_path), - ui.include_js(js_path), - shinychat.chat_ui("chat", class_="querychat"), - ) - - -def sidebar( - id: str, - width: int = 400, - height: str = "100%", - **kwargs, -) -> ui.Sidebar: - """ - Create a sidebar containing the querychat UI. - - Parameters - ---------- - id - The module ID. - width - Width of the sidebar in pixels. - height - Height of the sidebar. - **kwargs - Additional arguments to pass to the sidebar component. - - Returns - ------- - : - A sidebar UI component. - - """ - return ui.sidebar( - mod_ui(id), - width=width, - height=height, - class_="querychat-sidebar", - **kwargs, - ) - - -@module.server -def mod_server( # noqa: D417 - input: Inputs, - output: Outputs, - session: Session, - querychat_config: QueryChatConfig, -) -> QueryChat: - """ - Initialize the querychat server. - - Parameters - ---------- - querychat_config - Configuration object from init(). - - Returns - ------- - : - A QueryChat object representing the chat session. This can be used to - access the chat, current SQL query, title, and filtered data frame. 
- - """ - # Extract config parameters - data_source = querychat_config.data_source - system_prompt = querychat_config.system_prompt - greeting = querychat_config.greeting - client = querychat_config.client - - # Reactive values to store state - current_title = ReactiveStringOrNone(None) - current_query = ReactiveString("") - - @reactive.calc - def filtered_df(): - if current_query.get() == "": - return data_source.get_data() - else: - return data_source.execute_query(current_query.get()) - - # Create the tool functions - update_dashboard_tool = tool_update_dashboard( - data_source, - current_query, - current_title, - ) - reset_dashboard_tool = tool_reset_dashboard( - current_query, - current_title, - ) - query_tool = tool_query(data_source) - - chat_ui = shinychat.Chat("chat") - - # Set up the chat object for this session - chat = copy.deepcopy(client) - chat.set_turns([]) - chat.system_prompt = system_prompt - - # Register tools with annotations for the UI - chat.register_tool(update_dashboard_tool) - chat.register_tool(query_tool) - chat.register_tool(reset_dashboard_tool) - - # Handle user input - @chat_ui.on_user_submit - async def _(user_input: str): - stream = await chat.stream_async(user_input, echo="none", content="all") - await chat_ui.append_message_stream(stream) - - # Handle update button clicks - @reactive.effect - @reactive.event(input.chat_update) - def _(): - update = input.chat_update() - if update is None: - return - if not isinstance(update, dict): - return - - query = update.get("query") - title = update.get("title") - if query is not None: - current_query.set(query) - if title is not None: - current_title.set(title) - - @reactive.effect - async def greet_on_startup(): - if querychat_config.greeting: - await chat_ui.append_message(greeting) - elif querychat_config.greeting is None: - stream = await chat.stream_async( - "Please give me a friendly greeting. 
Include a few sample prompts in a two-level bulleted list.", - echo="none", - ) - await chat_ui.append_message_stream(stream) - - # Return the interface for other components to use - return QueryChat(chat, current_query, current_title, filtered_df) diff --git a/pkg-py/src/querychat/tools.py b/pkg-py/src/querychat/tools.py index 0e5fd964..8e88bc0d 100644 --- a/pkg-py/src/querychat/tools.py +++ b/pkg-py/src/querychat/tools.py @@ -11,8 +11,8 @@ from ._utils import df_to_html if TYPE_CHECKING: + from ._querychat_module import ReactiveString, ReactiveStringOrNone from .datasource import DataSource - from .querychat import ReactiveString, ReactiveStringOrNone def _read_prompt_template(filename: str, **kwargs) -> str: diff --git a/pkg-py/tests/test_greeting.py b/pkg-py/tests/test_greeting.py deleted file mode 100644 index d72383c6..00000000 --- a/pkg-py/tests/test_greeting.py +++ /dev/null @@ -1,92 +0,0 @@ -import os - -import pandas as pd -import pytest - -from querychat import greeting, init - - -@pytest.fixture(autouse=True) -def set_dummy_api_key(): - """Set a dummy OpenAI API key for testing.""" - old_api_key = os.environ.get("OPENAI_API_KEY") - os.environ["OPENAI_API_KEY"] = "sk-dummy-api-key-for-testing" - yield - if old_api_key is not None: - os.environ["OPENAI_API_KEY"] = old_api_key - else: - del os.environ["OPENAI_API_KEY"] - - -@pytest.fixture -def querychat_config(): - """Create a test querychat configuration.""" - # Create a simple pandas DataFrame - df = pd.DataFrame( - { - "id": [1, 2, 3], - "name": ["Alice", "Bob", "Charlie"], - "age": [25, 30, 35], - }, - ) - - # Create a config with a greeting - return init( - data_source=df, - table_name="test_table", - greeting="Hello! This is a test greeting.", - ) - - -@pytest.fixture -def querychat_config_no_greeting(): - """Create a test querychat configuration without a greeting.""" - # Create a simple pandas DataFrame - df = pd.DataFrame( - { - "id": [1, 2, 3], - "name": ["Alice", "Bob", "Charlie"], - "age": [25, 30, 35], - }, - ) - - # Create a config without a greeting - return init( - data_source=df, - table_name="test_table", - greeting=None, - ) - - -def test_greeting_retrieval(querychat_config): - """ - Test that greeting() returns the existing greeting when generate=False. - """ - result = greeting(querychat_config, generate=False) - assert result == "Hello! This is a test greeting." - - -def test_greeting_retrieval_none(querychat_config_no_greeting): - """ - Test that greeting() returns None when there's no existing greeting and - generate=False. - """ - result = greeting(querychat_config_no_greeting, generate=False) - assert result is None - - -def test_greeting_retrieval_empty(querychat_config): - """ - Test that greeting() returns None when the existing greeting is empty and - generate=False. 
- """ - querychat_config.greeting = "" - - result = greeting(querychat_config, generate=False) - assert result is None - - -def test_greeting_invalid_config(): - """Test that greeting() raises TypeError when given an invalid config.""" - with pytest.raises(TypeError): - greeting("not a config") diff --git a/pkg-py/tests/test_init_with_pandas.py b/pkg-py/tests/test_init_with_pandas.py index 4654a0b0..3f94b639 100644 --- a/pkg-py/tests/test_init_with_pandas.py +++ b/pkg-py/tests/test_init_with_pandas.py @@ -3,7 +3,7 @@ import narwhals.stable.v1 as nw import pandas as pd import pytest -from querychat.querychat import init +from querychat import QueryChat @pytest.fixture(autouse=True) @@ -19,7 +19,7 @@ def set_dummy_api_key(): def test_init_with_pandas_dataframe(): - """Test that init() can accept a pandas DataFrame.""" + """Test that QueryChat() can accept a pandas DataFrame.""" # Create a simple pandas DataFrame df = pd.DataFrame( { @@ -29,24 +29,24 @@ def test_init_with_pandas_dataframe(): }, ) - # Call init with the pandas DataFrame - it should not raise errors + # Call QueryChat with the pandas DataFrame - it should not raise errors # The function should accept a pandas DataFrame even with the narwhals import change - result = init( + qc = QueryChat( data_source=df, table_name="test_table", greeting="hello!", ) - # Verify the result is an instance of QueryChatConfig - assert result is not None - assert hasattr(result, "data_source") - assert hasattr(result, "system_prompt") - assert hasattr(result, "greeting") - assert hasattr(result, "client") + # Verify the result is properly configured + assert qc is not None + assert hasattr(qc, "data_source") + assert hasattr(qc, "system_prompt") + assert hasattr(qc, "greeting") + assert hasattr(qc, "client") def test_init_with_narwhals_dataframe(): - """Test that init() can accept a narwhals DataFrame.""" + """Test that QueryChat() can accept a narwhals DataFrame.""" # Create a pandas DataFrame and convert to narwhals pdf = pd.DataFrame( { @@ -57,21 +57,21 @@ def test_init_with_narwhals_dataframe(): ) nw_df = nw.from_native(pdf) - # Call init with the narwhals DataFrame - it should not raise errors - result = init( + # Call QueryChat with the narwhals DataFrame - it should not raise errors + qc = QueryChat( data_source=nw_df, table_name="test_table", greeting="hello!", ) # Verify the result is correctly configured - assert result is not None - assert hasattr(result, "data_source") - assert hasattr(result, "system_prompt") + assert qc is not None + assert hasattr(qc, "data_source") + assert hasattr(qc, "system_prompt") def test_init_with_narwhals_lazyframe_direct_query(): - """Test that init() can accept a narwhals LazyFrame and execute queries.""" + """Test that QueryChat() can accept a narwhals LazyFrame and execute queries.""" # Create a pandas DataFrame and convert to narwhals LazyFrame pdf = pd.DataFrame( { @@ -82,19 +82,19 @@ def test_init_with_narwhals_lazyframe_direct_query(): ) nw_lazy = nw.from_native(pdf).lazy() - # Call init with the narwhals LazyFrame - result = init( + # Call QueryChat with the narwhals LazyFrame + qc = QueryChat( data_source=nw_lazy, # TODO(@gadebuie): Fix this type error table_name="test_table", greeting="hello!", ) # Verify the result is correctly configured - assert result is not None - assert hasattr(result, "data_source") + assert qc is not None + assert hasattr(qc, "data_source") # Test that we can run a query on the data source - query_result = result.data_source.execute_query( + query_result = 
qc.data_source.execute_query( "SELECT * FROM test_table WHERE id = 2", ) assert len(query_result) == 1 diff --git a/pkg-py/tests/test_querychat.py b/pkg-py/tests/test_querychat.py new file mode 100644 index 00000000..b8267b46 --- /dev/null +++ b/pkg-py/tests/test_querychat.py @@ -0,0 +1,110 @@ +import os + +import pandas as pd +import pytest +from querychat import QueryChat + + +@pytest.fixture(autouse=True) +def set_dummy_api_key(): + """Set a dummy OpenAI API key for testing.""" + old_api_key = os.environ.get("OPENAI_API_KEY") + os.environ["OPENAI_API_KEY"] = "sk-dummy-api-key-for-testing" + yield + if old_api_key is not None: + os.environ["OPENAI_API_KEY"] = old_api_key + else: + del os.environ["OPENAI_API_KEY"] + + +@pytest.fixture +def sample_df(): + """Create a sample pandas DataFrame for testing.""" + return pd.DataFrame( + { + "id": [1, 2, 3], + "name": ["Alice", "Bob", "Charlie"], + "age": [25, 30, 35], + }, + ) + + +def test_querychat_init(sample_df): + """Test that QueryChat (Express mode) initializes correctly.""" + qc = QueryChat( + data_source=sample_df, + table_name="test_table", + greeting="Hello!", + ) + + # Verify basic attributes are set + assert qc is not None + assert hasattr(qc, "data_source") + assert hasattr(qc, "system_prompt") + assert hasattr(qc, "greeting") + assert hasattr(qc, "client") + assert qc.id == "test_table" + + # Even without server initialization, we should be able to query the data source + result = qc.data_source.execute_query( + "SELECT * FROM test_table WHERE id = 2", + ) + + assert len(result) == 1 + assert result.iloc[0]["name"] == "Bob" + + +def test_querychat_custom_id(sample_df): + """Test that QueryChat accepts custom ID.""" + qc = QueryChat( + data_source=sample_df, + table_name="test_table", + id="custom_id", + greeting="Hello!", + ) + + assert qc.id == "custom_id" + + +def test_querychat_set_methods(sample_df): + """Test that setter methods work.""" + qc = QueryChat( + data_source=sample_df, + table_name="test_table", + greeting="Hello!", + ) + + # Test set_system_prompt + qc.set_system_prompt( + qc.data_source, + data_description="A test dataset", + ) + assert "test dataset" in qc.system_prompt.lower() + + # Test set_data_source + new_df = pd.DataFrame({"x": [1, 2, 3]}) + qc.set_data_source(new_df, "new_table") + assert qc.data_source is not None + + # Test set_client + qc.set_client("openai/gpt-4o-mini") + assert qc.client is not None + + +def test_querychat_core_reactive_access_before_server_raises(sample_df): + """Test that accessing reactive properties before .server() raises error.""" + qc = QueryChat( + data_source=sample_df, + table_name="test_table", + greeting="Hello!", + ) + + # Accessing reactive properties before .server() should raise + with pytest.raises(RuntimeError, match="Must call \\.server\\(\\)"): + qc.title() + + with pytest.raises(RuntimeError, match="Must call \\.server\\(\\)"): + qc.sql() + + with pytest.raises(RuntimeError, match="Must call \\.server\\(\\)"): + qc.df() diff --git a/pyproject.toml b/pyproject.toml index c90b99d6..69fcd83e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "shiny", "shinywidgets", "htmltools", - "chatlas>=0.12.0", + "chatlas>=0.13.2", "narwhals", "chevron", "sqlalchemy>=2.0.0", # Using 2.0+ for improved type hints and API
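
A minimal before/after sketch of the API change this diff introduces: the commented-out lines show the now hard-deprecated functional API, and the live code uses only the `QueryChat` methods added in `_querychat.py`. The `titanic.csv` path and dataset are illustrative, not part of this change.

```python
# Before (deprecated; these calls now raise RuntimeError):
#   config = querychat.init(titanic, "titanic", greeting="Hi!")
#   app_ui = ui.page_sidebar(querychat.sidebar("chat"), ...)
#   def server(input, output, session):
#       chat = querychat.server("chat", config)

# After (new OOP API), Shiny Core mode:
import pandas as pd
from shiny import App, render, ui
from querychat import QueryChat

titanic = pd.read_csv("titanic.csv")  # illustrative dataset

qc = QueryChat(titanic, "titanic", greeting="Hi! Ask me about the Titanic data.")

app_ui = ui.page_sidebar(
    qc.sidebar(),  # was: querychat.sidebar("chat")
    ui.output_data_frame("table"),
)


def server(input, output, session):
    qc.server()  # was: querychat.server("chat", config)

    @render.data_frame
    def table():
        return qc.df()  # was: chat.df()


app = App(app_ui, server)
```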
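
A sketch of driving other server logic from the reactive state exposed by `.sql()`, `.title()`, and `.df()` (the read/set behavior described in the method docstrings); the `reset` button and `status` output are assumptions added for illustration.

```python
import pandas as pd
from shiny import App, reactive, render, ui
from querychat import QueryChat

sales = pd.DataFrame({"region": ["east", "west"], "revenue": [100, 200]})
qc = QueryChat(sales, "sales", greeting="Hi! Ask me about sales.")

app_ui = ui.page_sidebar(
    qc.sidebar(),
    ui.output_text("status"),
    ui.input_action_button("reset", "Reset filter"),
)


def server(input, output, session):
    qc.server()

    @render.text
    def status():
        # .title() is None (and .sql() is "") until the LLM sets a query
        return qc.title() or "Showing all rows"

    @reactive.effect
    @reactive.event(input.reset)
    def _():
        qc.sql("")  # an empty query makes .df() fall back to the unfiltered source


app = App(app_ui, server)
```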
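
A sketch of the greeting workflow the startup warning recommends: generate a greeting once with the new `.generate_greeting()` method, save it, and pass it back in so no LLM call happens at conversation start. The `greeting.md` and `mtcars.csv` file names are assumptions.

```python
from pathlib import Path

import pandas as pd
from querychat import QueryChat

mtcars = pd.read_csv("mtcars.csv")  # illustrative dataset

# One-off (e.g., in a REPL): generate a greeting and save it to disk.
qc = QueryChat(mtcars, "mtcars")
Path("greeting.md").write_text(qc.generate_greeting())

# In the app: reuse the saved greeting; a Path is read with .read_text(),
# so startup is faster, cheaper, and deterministic.
qc = QueryChat(mtcars, "mtcars", greeting=Path("greeting.md"))
```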
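
A sketch of the three ways `normalize_client()` resolves the `client` argument, using the provider/model strings from the docstring; these assume the corresponding API keys are configured and may need adjusting to models you have access to.

```python
import os

import chatlas
import pandas as pd
from querychat import QueryChat

cars = pd.DataFrame({"mpg": [21.0, 22.8], "cyl": [6, 4]})

# 1. A provider/model string is passed to chatlas.ChatAuto(provider_model=...)
qc = QueryChat(cars, "cars", client="anthropic/claude-sonnet-4-5")

# 2. A pre-configured chatlas.Chat object is used as-is
qc = QueryChat(cars, "cars", client=chatlas.ChatOpenAI(model="gpt-4.1"))

# 3. Nothing passed: the QUERYCHAT_CLIENT env var is consulted, else "openai"
os.environ["QUERYCHAT_CLIENT"] = "google/gemini-2.5-flash"
qc = QueryChat(cars, "cars")
```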