Skip to content

Commit 755ac76

Browse files
committed
feat: interstitial commit
1 parent 6263715 commit 755ac76

File tree

5 files changed

+73
-77
lines changed

5 files changed

+73
-77
lines changed

backend/pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ dependencies = [
1414
"simplejson",
1515
"python-dotenv",
1616
"pandas>=2.3.0",
17+
"vertexai",
1718
]
1819

1920
[tool.setuptools.packages.find]

backend/scripts/generate_conversation/chat.py

Lines changed: 19 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
import argparse
1313
from pathlib import Path
1414
import pandas as pd
15-
from google.oauth2 import service_account
1615

1716
from tenantfirstaid.chat import DEFAULT_INSTRUCTIONS, ChatManager
1817

@@ -46,9 +45,7 @@ def __init__(self, starting_message, user_facts, city, state):
4645
self.city = city
4746
self.state = state
4847

49-
self.input_messages = [
50-
dict(role="user", content=starting_message)
51-
]
48+
self.input_messages = [dict(role="user", content=starting_message)]
5249
self.starting_message = starting_message # Store the starting message
5350

5451
self.openai_tools = []
@@ -63,13 +60,9 @@ def _reverse_message_roles(self, messages):
6360
reversed_messages = []
6461
for message in messages:
6562
if message["role"] == "user":
66-
reversed_messages.append(
67-
dict(role="model", content=message["content"])
68-
)
63+
reversed_messages.append(dict(role="model", content=message["content"]))
6964
elif message["role"] == "model":
70-
reversed_messages.append(
71-
dict(role="user", content=message["content"])
72-
)
65+
reversed_messages.append(dict(role="user", content=message["content"]))
7366
else:
7467
reversed_messages.append(message)
7568
return reversed_messages
@@ -78,26 +71,22 @@ def bot_response(self):
7871
"""Generates a response from the bot using the OpenAI API."""
7972
tries = 0
8073
while tries < 3:
81-
# Use the BOT_INSTRUCTIONS for bot responses
82-
start = time()
83-
response = self.chat_manager.generate_gemini_chat_response(
84-
self.input_messages,
85-
city=self.city,
86-
state=self.state,
87-
stream=False,
88-
model_name="gemini-2.5-pro",
89-
)
90-
end = time()
91-
self.input_messages.append(
92-
dict(role="model", content=response.text)
93-
)
94-
self.input_messages = self._reverse_message_roles(self.input_messages)
95-
return response.text, end - start
74+
# Use the BOT_INSTRUCTIONS for bot responses
75+
start = time()
76+
response = self.chat_manager.generate_gemini_chat_response(
77+
self.input_messages,
78+
city=self.city,
79+
state=self.state,
80+
stream=False,
81+
model_name="gemini-2.5-pro",
82+
)
83+
end = time()
84+
self.input_messages.append(dict(role="model", content=response.text))
85+
self.input_messages = self._reverse_message_roles(self.input_messages)
86+
return response.text, end - start
9687
# If all attempts fail, return a failure message
9788
failure_message = "I'm sorry, I am unable to generate a response at this time. Please try again later."
98-
self.input_messages.append(
99-
dict(role="model", content=failure_message)
100-
)
89+
self.input_messages.append(dict(role="model", content=failure_message))
10190
return failure_message, None
10291

10392
def user_response(self):
@@ -116,18 +105,14 @@ def user_response(self):
116105
use_tools=False,
117106
model_name="gemini-2.0-flash-lite",
118107
)
119-
self.input_messages.append(
120-
dict(role="user", content=response.text)
121-
)
108+
self.input_messages.append(dict(role="user", content=response.text))
122109
return response.text
123110
except Exception as e:
124111
print(f"Error generating user response: {e}")
125112
tries += 1
126113
# If all attempts fail, return a failure message
127114
failure_message = "I'm sorry, I am unable to generate a user response at this time. Please try again later."
128-
self.input_messages.append(
129-
dict(role="user", content=failure_message)
130-
)
115+
self.input_messages.append(dict(role="user", content=failure_message))
131116
return failure_message
132117

133118
def generate_conversation(self, num_turns=5):

backend/tenantfirstaid/chat.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,19 @@
4242
class ChatManager:
4343
def __init__(self):
4444
creds = service_account.Credentials.from_service_account_file(
45-
os.getenv("GOOGLE_SERVICE_ACCOUNT_CREDENTIALS_FILE", "google-service-account.json")
45+
os.getenv(
46+
"GOOGLE_SERVICE_ACCOUNT_CREDENTIALS_FILE", "google-service-account.json"
47+
)
4648
)
4749
vertexai.init(
4850
project="tenantfirstaid",
4951
location="us-west1",
5052
credentials=creds,
5153
)
54+
self.model = GenerativeModel(
55+
model_name=MODEL,
56+
system_instruction=DEFAULT_INSTRUCTIONS,
57+
)
5258

5359
def prepare_developer_instructions(self, city: str, state: str):
5460
# Add city and state filters if they are set
@@ -74,7 +80,7 @@ def generate_gemini_chat_response(
7480
else self.prepare_developer_instructions(city, state)
7581
)
7682

77-
model = GenerativeModel(
83+
self.model = GenerativeModel(
7884
model_name=model_name,
7985
system_instruction=instructions,
8086
)
@@ -91,16 +97,11 @@ def generate_gemini_chat_response(
9197
}
9298
)
9399

94-
95100
GEMINI_RAG_CORPUS = os.getenv("GEMINI_RAG_CORPUS")
96101
rag_retrieval_tool = Tool.from_retrieval(
97102
retrieval=rag.Retrieval(
98103
source=rag.VertexRagStore(
99-
rag_resources=[
100-
rag.RagResource(
101-
rag_corpus=GEMINI_RAG_CORPUS
102-
)
103-
]
104+
rag_resources=[rag.RagResource(rag_corpus=GEMINI_RAG_CORPUS)]
104105
)
105106
)
106107
)
@@ -135,8 +136,6 @@ def generate():
135136
current_session["state"],
136137
stream=True,
137138
)
138-
139-
140139

141140
assistant_chunks = []
142141
for event in response_stream:

backend/tests/test_chat.py

Lines changed: 30 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
@pytest.fixture
2020
def mock_vertexai(mocker):
2121
mock_vertexai_init = mocker.Mock(spec=vertexai)
22-
mocker.patch("tenantfirstaid.chat.vertexai", return_value=mock_vertexai_init)
22+
mocker.patch("tenantfirstaid.chat.vertexai.init", return_value=mock_vertexai_init)
2323
return mock_vertexai_init
2424

2525

@@ -31,14 +31,7 @@ def mock_vertexai_generative_model(mocker):
3131

3232

3333
@pytest.fixture
34-
def mock_vertexai_generation_config(mocker):
35-
mock_gen_config = mocker.Mock(spec=GenerationConfig)
36-
mocker.patch("tenantfirstaid.chat.GenerationConfig", return_value=mock_gen_config)
37-
return mock_gen_config
38-
39-
40-
@pytest.fixture
41-
def chat_manager(mocker, mock_vertexai):
34+
def chat_manager(mocker, mock_vertexai, mock_vertexai_generative_model):
4235
return ChatManager()
4336

4437

@@ -136,31 +129,35 @@ def test_chat_view_dispatch_request_streams_response(
136129
with app.test_request_context(
137130
"/api/query", method="POST", json={"message": "Salutations mock openai api"}
138131
) as chat_ctx:
139-
chat_ctx.session["session_id"] = (
140-
session_id # Simulate session ID in request context
141-
)
142-
chat_response = init_ctx.app.full_dispatch_request()
143-
assert chat_response.status_code == 200 # Ensure the response is successful
144-
assert chat_response.mimetype == "text/plain"
145-
146-
mock_vertexai_generative_model.generate_content = mocker.Mock(
147-
return_value=iter(
148-
[
149-
GenerationResponse.from_dict(
150-
response_dict=dict(
151-
candidates=[
152-
dict(
153-
content=dict(
154-
role="model",
155-
parts=[dict(text="Greetings, test prompt!")],
132+
with mocker.patch("tenantfirstaid.chat.ChatManager.model") as mock_model:
133+
134+
chat_ctx.session["session_id"] = (
135+
session_id # Simulate session ID in request context
136+
)
137+
chat_response = init_ctx.app.full_dispatch_request()
138+
assert chat_response.status_code == 200 # Ensure the response is successful
139+
assert chat_response.mimetype == "text/plain"
140+
141+
mock_model.generate_content = mocker.Mock(
142+
return_value=iter(
143+
[
144+
GenerationResponse.from_dict(
145+
response_dict=dict(
146+
candidates=[
147+
dict(
148+
content=dict(
149+
role="model",
150+
parts=[dict(text="Greetings, test prompt!")],
151+
)
156152
)
157-
)
158-
]
153+
]
154+
)
159155
)
160-
)
161-
]
156+
]
157+
)
162158
)
163-
)
164159

165-
response_chunks = "".join(chat_response.response)
166-
assert "Greetings, test prompt!" in response_chunks
160+
print("CHAT RESPONSE", chat_response, chat_response.response)
161+
162+
response_chunks = "".join(chat_response.response)
163+
assert "Greetings, test prompt!" in response_chunks

backend/uv.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)