Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 44 additions & 8 deletions pkg-py/src/querychat/_querychat_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,12 +81,18 @@ def __init__(
self._extra_instructions = extra_instructions
self._categorical_threshold = categorical_threshold

# Normalize and initialize client (doesn't need data_source)
client = normalize_client(client)
self._client = copy.deepcopy(client)
self._client.set_turns([])

# Initialize client
# When data_source is None (deferred pattern), also defer client initialization
# unless an explicit client is provided
self._client_console = None
if data_source is None and client is None:
# Deferred pattern: don't try to create a default client
self._client: chatlas.Chat | None = None
else:
# Normalize and initialize client
normalized_client = normalize_client(client)
self._client = copy.deepcopy(normalized_client)
self._client.set_turns([])

# Initialize data source (may be None for deferred pattern)
if data_source is not None:
Expand Down Expand Up @@ -114,7 +120,9 @@ def _build_system_prompt(self) -> None:
extra_instructions=self._extra_instructions,
categorical_threshold=self._categorical_threshold,
)
self._client.system_prompt = self._system_prompt.render(self.tools)
# Only set system_prompt on client if client is available
if self._client is not None:
self._client.system_prompt = self._system_prompt.render(self.tools)

def _require_data_source(self, method_name: str) -> DataSource[IntoFrameT]:
"""Raise if data_source is not set, otherwise return it for type narrowing."""
Expand All @@ -126,6 +134,16 @@ def _require_data_source(self, method_name: str) -> DataSource[IntoFrameT]:
)
return self._data_source

def _require_client(self, method_name: str) -> chatlas.Chat:
    """
    Return the configured chat client, raising if none has been set.

    Returning the client (instead of only checking it) gives callers a value
    typed as ``chatlas.Chat`` rather than ``chatlas.Chat | None``.

    Parameters
    ----------
    method_name
        Name of the calling method; interpolated into the error message.

    Returns
    -------
    chatlas.Chat
        The object currently stored in ``self._client``.

    Raises
    ------
    RuntimeError
        If ``self._client`` is ``None``.
    """
    if self._client is None:
        raise RuntimeError(
            f"client must be set before calling {method_name}(). "
            "Either pass client to __init__(), set the chat_client property, "
            "or pass client to server()."
        )
    return self._client

def client(
self,
*,
Expand All @@ -152,11 +170,12 @@ def client(

"""
data_source = self._require_data_source("client")
base_client = self._require_client("client")
if self._system_prompt is None:
raise RuntimeError("System prompt not initialized")
tools = normalize_tools(tools, default=self.tools)

chat = copy.deepcopy(self._client)
chat = copy.deepcopy(base_client)
chat.set_turns([])
chat.system_prompt = self._system_prompt.render(tools)

Expand All @@ -177,7 +196,8 @@ def client(
def generate_greeting(self, *, echo: Literal["none", "output"] = "none") -> str:
    """
    Generate a welcome greeting for the chat.

    The greeting is requested from a deep copy of the configured client whose
    turns are cleared first; the copy is discarded afterwards.

    Parameters
    ----------
    echo
        Forwarded unchanged to the client's ``chat()`` call.

    Returns
    -------
    str
        The chat response to ``GREETING_PROMPT``, converted with ``str()``.

    Raises
    ------
    RuntimeError
        If no chat client has been set. The data source is checked first via
        ``_require_data_source``, which raises when it is not set.
    """
    self._require_data_source("generate_greeting")
    # Validate the client before copying it, so a missing client surfaces as
    # the descriptive error from _require_client.
    base_client = self._require_client("generate_greeting")
    client = copy.deepcopy(base_client)
    client.set_turns([])
    return str(client.chat(GREETING_PROMPT, echo=echo))

Expand All @@ -190,6 +210,7 @@ def console(
) -> None:
"""Launch an interactive console chat with the data."""
self._require_data_source("console")
self._require_client("console")
tools = normalize_tools(tools, default=("query",))

if new or self._client_console is None:
Expand All @@ -216,6 +237,21 @@ def data_source(self, value: IntoFrame | sqlalchemy.Engine) -> None:
self._data_source = normalize_data_source(value, self._table_name)
self._build_system_prompt()

@property
def chat_client(self) -> chatlas.Chat | None:
    """Get the current chat client, or ``None`` if it has not been set."""
    return self._client

@chat_client.setter
def chat_client(self, value: str | chatlas.Chat) -> None:
    """
    Set the chat client.

    The value is passed through ``normalize_client``, deep-copied, and its
    turns are cleared before it is stored. If a data source and system prompt
    are already in place, the rendered system prompt is applied to the new
    client as well.

    Parameters
    ----------
    value
        A client specification string or a ``chatlas.Chat`` instance.
    """
    normalized_client = normalize_client(value)
    # Store a private copy so the caller's object is not touched by the
    # set_turns()/system_prompt assignments below.
    self._client = copy.deepcopy(normalized_client)
    self._client.set_turns([])
    # console() only rebuilds its cached client when this is None (or when
    # new=True is passed); reset it so the next console session does not keep
    # using a client derived from the one being replaced.
    self._client_console = None
    # Update system prompt on client if data_source is already set
    if self._data_source is not None and self._system_prompt is not None:
        self._client.system_prompt = self._system_prompt.render(self.tools)

def cleanup(self) -> None:
"""Clean up resources associated with the data source."""
if self._data_source is not None:
Expand Down
10 changes: 10 additions & 0 deletions pkg-py/src/querychat/_shiny.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,7 @@ def server(
self,
*,
data_source: Optional[IntoFrame | sqlalchemy.Engine | ibis.Table] = None,
client: Optional[str | chatlas.Chat] = None,
enable_bookmarking: bool = False,
id: Optional[str] = None,
) -> ServerValues[IntoFrameT]:
Expand All @@ -422,6 +423,11 @@ def server(
Optional data source to use. If provided, sets the data_source property
before initializing server logic. This is useful for the deferred pattern
where data_source is not known at initialization time.
client
Optional chat client to use. If provided, sets the chat_client property
before initializing server logic. This is useful for the deferred pattern
where the client cannot be created at initialization time (e.g., when
using Posit Connect managed OAuth credentials that require session access).
enable_bookmarking
Whether to enable bookmarking for the querychat module.
id
Expand Down Expand Up @@ -485,7 +491,11 @@ def title():
if data_source is not None:
self.data_source = data_source

if client is not None:
self.chat_client = client

resolved_data_source = self._require_data_source("server")
self._require_client("server")

return mod_server(
id or self.id,
Expand Down
190 changes: 190 additions & 0 deletions pkg-py/tests/test_deferred_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,190 @@
"""Tests for deferred chat client initialization."""

import pandas as pd
import pytest
from chatlas import ChatOpenAI
from querychat._querychat_base import QueryChatBase


@pytest.fixture
def sample_df() -> pd.DataFrame:
    """Create a sample pandas DataFrame (three rows: id, name, age) for testing."""
    return pd.DataFrame(
        {
            "id": [1, 2, 3],
            "name": ["Alice", "Bob", "Charlie"],
            "age": [25, 30, 35],
        },
    )


class TestDeferredClientInit:
    """Tests for initializing QueryChatBase with deferred client."""

    def test_init_with_none_data_source_defers_client(self):
        """When data_source is None and client is not provided, client should be None."""
        qc = QueryChatBase(None, "users")
        assert qc._client is None
        assert qc.chat_client is None

    def test_init_with_explicit_client_and_none_data_source(self, monkeypatch):
        """When data_source is None but client is provided, client should be initialized."""
        # A dummy key is set before a client is constructed; no request is made.
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")
        qc = QueryChatBase(None, "users", client="openai")
        assert qc._client is not None
        assert qc.chat_client is not None

    def test_init_with_data_source_initializes_client(self, sample_df, monkeypatch):
        """When data_source is provided, client should be initialized with default."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")
        qc = QueryChatBase(sample_df, "users")
        assert qc._client is not None
        assert qc.chat_client is not None


class TestChatClientProperty:
    """Tests for the chat_client property setter."""

    def test_chat_client_setter(self, monkeypatch):
        """Setting chat_client should normalize and store the client."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")
        qc = QueryChatBase(None, "users")
        assert qc.chat_client is None

        qc.chat_client = "openai"
        assert qc.chat_client is not None

    def test_chat_client_setter_with_chat_object(self, monkeypatch):
        """Setting chat_client with a Chat object should work."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")
        qc = QueryChatBase(None, "users")
        assert qc.chat_client is None

        chat = ChatOpenAI()
        qc.chat_client = chat
        assert qc.chat_client is not None

    def test_chat_client_setter_updates_system_prompt(self, sample_df, monkeypatch):
        """Setting chat_client should update system_prompt if data_source is set."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")
        # Start with both deferred, then attach the data source first
        qc = QueryChatBase(None, "users")
        qc.data_source = sample_df

        # Now set the client - it should get the system prompt
        qc.chat_client = "openai"
        assert qc._client is not None
        # The system prompt should have been set on the client
        assert qc._client.system_prompt is not None

    def test_chat_client_getter_returns_none_when_not_set(self):
        """chat_client property returns None when not set."""
        qc = QueryChatBase(None, "users")
        assert qc.chat_client is None


class TestClientMethodRequirements:
    """Tests that methods properly require client to be set."""

    def test_client_method_requires_client(self, sample_df, monkeypatch):
        """client() should raise if client not set."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")
        # Defer both, then attach only the data source so the client stays unset
        qc = QueryChatBase(None, "users")
        qc.data_source = sample_df

        with pytest.raises(RuntimeError, match="client must be set"):
            qc.client()

    def test_console_requires_client(self, sample_df, monkeypatch):
        """console() should raise if client not set."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")
        qc = QueryChatBase(None, "users")
        qc.data_source = sample_df

        with pytest.raises(RuntimeError, match="client must be set"):
            qc.console()

    def test_generate_greeting_requires_client(self, sample_df, monkeypatch):
        """generate_greeting() should raise if client not set."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")
        qc = QueryChatBase(None, "users")
        qc.data_source = sample_df

        with pytest.raises(RuntimeError, match="client must be set"):
            qc.generate_greeting()


class TestDeferredClientIntegration:
    """Integration tests for the full deferred client workflow."""

    def test_deferred_data_source_and_client(self, sample_df, monkeypatch):
        """Test setting both data_source and client after init."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")

        # Create with both deferred
        qc = QueryChatBase(None, "users")
        assert qc.data_source is None
        assert qc.chat_client is None

        # Set data_source first
        qc.data_source = sample_df
        assert qc.data_source is not None

        # Set client second
        qc.chat_client = "openai"
        assert qc.chat_client is not None

        # Now methods should work
        client = qc.client()
        assert client is not None
        assert "users" in qc.system_prompt

    def test_deferred_client_then_data_source(self, sample_df, monkeypatch):
        """Test setting client before data_source."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")

        # Create with both deferred
        qc = QueryChatBase(None, "users")

        # Set client first
        qc.chat_client = "openai"
        assert qc.chat_client is not None

        # Set data_source second
        qc.data_source = sample_df
        assert qc.data_source is not None

        # Now methods should work
        client = qc.client()
        assert client is not None

    def test_no_openai_key_error_when_deferred(self, monkeypatch):
        """When data_source is None, no OpenAI API key error should occur."""
        # Remove OpenAI API key if set
        monkeypatch.delenv("OPENAI_API_KEY", raising=False)
        monkeypatch.delenv("QUERYCHAT_CLIENT", raising=False)

        # This should NOT raise an error about missing API key
        qc = QueryChatBase(None, "users")
        assert qc._client is None
        assert qc.chat_client is None


class TestBackwardCompatibility:
    """Tests that existing patterns continue to work."""

    def test_immediate_pattern_unchanged(self, sample_df, monkeypatch):
        """Existing code with data_source continues to work."""
        monkeypatch.setenv("OPENAI_API_KEY", "sk-dummy-key-for-testing")
        qc = QueryChatBase(sample_df, "test_table")

        assert qc.data_source is not None
        assert qc.chat_client is not None

        # All methods should work immediately
        client = qc.client()
        assert client is not None

        prompt = qc.system_prompt
        assert "test_table" in prompt
9 changes: 7 additions & 2 deletions pkg-py/tests/test_deferred_datasource.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,14 +139,19 @@ class TestDeferredPatternIntegration:

def test_deferred_then_set_property(self, sample_df):
"""Test setting data_source via property after init."""
# Create with None
# Create with None - both data_source and client are deferred
qc = QueryChatBase(None, "users")
assert qc.data_source is None
assert qc.chat_client is None

# Set via property
# Set data_source via property
qc.data_source = sample_df
assert qc.data_source is not None

# Set client via property (required now that we defer both)
qc.chat_client = "openai"
assert qc.chat_client is not None

# Now methods should work
client = qc.client()
assert client is not None
Expand Down
Loading