Skip to content

Commit 7ea1639

Browse files
authored
add generated and hand-tweaked tests for test_chat.py (#162)
1 parent fc643ca commit 7ea1639

File tree

1 file changed

+185
-0
lines changed

1 file changed

+185
-0
lines changed

backend/tests/test_chat.py

Lines changed: 185 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
import pytest
2+
from tenantfirstaid.chat import (
3+
ChatManager,
4+
DEFAULT_INSTRUCTIONS,
5+
OREGON_LAW_CENTER_PHONE_NUMBER,
6+
)
7+
import os
8+
from unittest import mock
9+
from flask import Flask
10+
from tenantfirstaid.chat import ChatView
11+
from tenantfirstaid.session import TenantSession, TenantSessionData, InitSessionView
12+
from openai import OpenAI
13+
from openai.types.responses import ResponseTextDeltaEvent
14+
from typing import Dict
15+
16+
17+
@pytest.fixture
18+
def mock_openai(mocker):
19+
mock_openai_client = mocker.Mock(spec=OpenAI)
20+
mocker.patch("tenantfirstaid.chat.OpenAI", return_value=mock_openai_client)
21+
return mock_openai_client
22+
23+
24+
@pytest.fixture
25+
def chat_manager(mocker, mock_openai):
26+
return ChatManager()
27+
28+
29+
def test_prepare_developer_instructions_includes_city_state(chat_manager):
30+
city = "Portland"
31+
state = "or"
32+
instructions = chat_manager.prepare_developer_instructions(city, state)
33+
assert f"The user is in {city} {state.upper()}." in instructions
34+
35+
36+
def test_prepare_openai_tools_returns_none_if_no_vector_store_id(chat_manager):
37+
with mock.patch.dict(os.environ, {}, clear=True):
38+
assert chat_manager.prepare_openai_tools("Portland", "or") is None
39+
40+
41+
def test_prepare_openai_tools_returns_tools_with_vector_store_id(chat_manager):
42+
with mock.patch.dict(os.environ, {"VECTOR_STORE_ID": "abc123"}):
43+
tools = chat_manager.prepare_openai_tools("Portland", "or")
44+
assert tools is not None
45+
assert type(tools) is list
46+
assert len(tools) == 1
47+
assert tools[0].get("vector_store_ids") == ["abc123"]
48+
49+
50+
def test_prepare_openai_tools_city_null(chat_manager):
51+
with mock.patch.dict(os.environ, {"VECTOR_STORE_ID": "abc123"}):
52+
tools = chat_manager.prepare_openai_tools("null", "or")
53+
assert tools is not None
54+
assert tools[0].get("filters").get("type") == "and"
55+
56+
57+
def test_generate_chat_response_streaming(chat_manager):
58+
with mock.patch.object(chat_manager.client.responses, "create") as mock_create:
59+
mock_create.return_value = iter([])
60+
result = chat_manager.generate_chat_response([], "Portland", "or", stream=True)
61+
assert hasattr(result, "__iter__")
62+
63+
64+
def test_generate_chat_response_non_streaming(chat_manager):
65+
with mock.patch.object(chat_manager.client.responses, "create") as mock_create:
66+
mock_create.return_value = "response"
67+
result = chat_manager.generate_chat_response([], "Portland", "or", stream=False)
68+
assert result == "response"
69+
70+
71+
def test_default_instructions_contains_oregon_law_center_phone():
72+
assert OREGON_LAW_CENTER_PHONE_NUMBER in DEFAULT_INSTRUCTIONS
73+
74+
75+
def test_default_instructions_contains_citation_links():
76+
assert "https://oregon.public.law/statutes" in DEFAULT_INSTRUCTIONS
77+
assert 'target="_blank"' in DEFAULT_INSTRUCTIONS
78+
79+
80+
@pytest.fixture
81+
def mock_valkey_ping_nop(mocker, monkeypatch):
82+
"""Mock the Valkey class with the db_con.ping() method."""
83+
84+
monkeypatch.setenv("DB_HOST", "8.8.8.8")
85+
monkeypatch.setenv("DB_PORT", "8888")
86+
monkeypatch.setenv("DB_PASSWORD", "test_password")
87+
monkeypatch.setenv("DB_USE_SSL", "false")
88+
89+
mock_valkey_client = mocker.Mock()
90+
mocker.patch("tenantfirstaid.session.Valkey", return_value=mock_valkey_client)
91+
mock_valkey_client.ping = ()
92+
return mock_valkey_client
93+
94+
95+
@pytest.fixture
96+
def mock_valkey(mock_valkey_ping_nop, mocker):
97+
_data: Dict[str, str] = {}
98+
99+
mock_valkey_ping_nop.set = mocker.Mock(
100+
side_effect=lambda key, value: _data.update({key: value})
101+
)
102+
103+
mock_valkey_ping_nop.get = mocker.Mock(side_effect=lambda key: _data[key])
104+
105+
return mock_valkey_ping_nop
106+
107+
108+
@pytest.fixture
109+
def app(mock_valkey):
110+
app = Flask(__name__)
111+
app.testing = True # propagate exceptions to the test client
112+
app.secret_key = "test_secret_key" # Set a secret key for session management
113+
114+
return app
115+
116+
117+
def test_chat_view_dispatch_request_streams_response(app, mocker, mock_openai):
118+
tenant_session = TenantSession()
119+
120+
app.add_url_rule(
121+
"/api/init",
122+
view_func=InitSessionView.as_view("init", tenant_session),
123+
methods=["POST"],
124+
)
125+
126+
app.add_url_rule(
127+
"/api/query",
128+
view_func=ChatView.as_view("chat", tenant_session),
129+
methods=["POST"],
130+
)
131+
132+
test_data_obj = TenantSessionData(
133+
city="Test City",
134+
state="Test State",
135+
messages=[],
136+
)
137+
138+
# Initialize the session with session_id and test data
139+
with app.test_request_context(
140+
"/api/init", method="POST", json=test_data_obj
141+
) as init_ctx:
142+
init_response = app.full_dispatch_request()
143+
assert init_response.status_code == 200 # Ensure the response is successful
144+
145+
tenant_session.set(test_data_obj)
146+
session_id = init_response.json["session_id"]
147+
assert session_id is not None # Ensure session_id is set
148+
assert isinstance(session_id, str) # Ensure session_id is a string
149+
assert tenant_session.get() == test_data_obj
150+
151+
# each test-request is a new context (nesting does not do what you think)
152+
# so we need to set the session_id in the request context manually
153+
with app.test_request_context(
154+
"/api/query", method="POST", json={"message": "Salutations mock openai api"}
155+
) as chat_ctx:
156+
chat_ctx.session["session_id"] = (
157+
session_id # Simulate session ID in request context
158+
)
159+
chat_response = init_ctx.app.full_dispatch_request()
160+
assert chat_response.status_code == 200 # Ensure the response is successful
161+
assert chat_response.mimetype == "text/plain"
162+
163+
mock_openai.responses.create = mocker.Mock(
164+
side_effect=lambda model,
165+
input,
166+
instructions,
167+
reasoning,
168+
stream,
169+
include,
170+
tools: iter(
171+
[
172+
ResponseTextDeltaEvent(
173+
type="response.output_text.delta",
174+
delta="Greetings, test prompt!",
175+
content_index=-1,
176+
item_id="item-id",
177+
output_index=0,
178+
sequence_number=-1,
179+
)
180+
]
181+
)
182+
)
183+
184+
response_chunks = "".join(chat_response.response)
185+
assert "Greetings, test prompt!" in response_chunks

0 commit comments

Comments
 (0)