Skip to content

Commit da4dacc

Browse files
committed
Restore files that only have linting changes
1 parent 2b5af3a commit da4dacc

File tree

4 files changed

+70
-96
lines changed

4 files changed

+70
-96
lines changed

server/api/services/tools/tools.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1+
from django.db import connection
12
from typing import Dict, Any, Callable, List
23
from dataclasses import dataclass
3-
4-
from django.db import connection
5-
64
from .database import ask_database, get_database_info
75

86
database_schema_dict = get_database_info(connection)
@@ -13,15 +11,13 @@
1311
]
1412
)
1513

16-
1714
@dataclass
1815
class ToolFunction:
1916
name: str
2017
func: Callable
2118
description: str
2219
parameters: Dict[str, Any]
2320

24-
2521
def create_tool_dict(tool: ToolFunction) -> Dict[str, Any]:
2622
return {
2723
"type": "function",
@@ -32,11 +28,10 @@ def create_tool_dict(tool: ToolFunction) -> Dict[str, Any]:
3228
"type": "object",
3329
"properties": tool.parameters,
3430
"required": list(tool.parameters.keys()),
35-
},
36-
},
31+
}
32+
}
3733
}
3834

39-
4035
TOOL_FUNCTIONS = [
4136
ToolFunction(
4237
name="ask_database",
@@ -61,58 +56,60 @@ def create_tool_dict(tool: ToolFunction) -> Dict[str, Any]:
6156
SQL should be written using this database schema:
6257
{database_schema_string}
6358
The query should be returned in plain text, not in JSON.
64-
""",
59+
"""
6560
}
66-
},
61+
}
6762
),
6863
]
6964

7065
# Automatically generate the tool_functions dictionary and tools list
71-
tool_functions: Dict[str, Callable] = {tool.name: tool.func for tool in TOOL_FUNCTIONS}
72-
73-
tools: List[Dict[str, Any]] = [create_tool_dict(tool) for tool in TOOL_FUNCTIONS]
66+
tool_functions: Dict[str, Callable] = {
67+
tool.name: tool.func for tool in TOOL_FUNCTIONS
68+
}
7469

70+
tools: List[Dict[str, Any]] = [
71+
create_tool_dict(tool) for tool in TOOL_FUNCTIONS
72+
]
7573

7674
def validate_tool_inputs(tool_function_name, tool_arguments):
7775
"""Validate the inputs for the execute_tool function."""
7876
if not isinstance(tool_function_name, str) or not tool_function_name:
7977
raise ValueError("Invalid tool function name")
80-
78+
8179
if not isinstance(tool_arguments, dict):
8280
raise ValueError("Tool arguments must be a dictionary")
83-
81+
8482
# Check if the tool_function_name exists in the tools
8583
tool = next((t for t in tools if t["function"]["name"] == tool_function_name), None)
8684
if not tool:
8785
raise ValueError(f"Tool function '{tool_function_name}' does not exist")
88-
86+
8987
# Validate the tool arguments based on the tool's parameters
9088
parameters = tool["function"].get("parameters", {})
9189
required_params = parameters.get("required", [])
9290
for param in required_params:
9391
if param not in tool_arguments:
9492
raise ValueError(f"Missing required parameter: {param}")
95-
93+
9694
# Check if the parameter types match the expected types
9795
properties = parameters.get("properties", {})
9896
for param, prop in properties.items():
99-
expected_type = prop.get("type")
97+
expected_type = prop.get('type')
10098
if param in tool_arguments:
101-
if expected_type == "string" and not isinstance(tool_arguments[param], str):
99+
if expected_type == 'string' and not isinstance(tool_arguments[param], str):
102100
raise ValueError(f"Parameter '{param}' must be of type string")
103-
104-
101+
105102
def execute_tool(function_name: str, arguments: Dict[str, Any]) -> str:
106103
"""
107104
Execute the appropriate function based on the function name.
108-
105+
109106
:param function_name: The name of the function to execute
110107
:param arguments: A dictionary of arguments to pass to the function
111108
:return: The result of the function execution
112109
"""
113110
# Validate tool inputs
114111
validate_tool_inputs(function_name, arguments)
115-
112+
116113
try:
117114
return tool_functions[function_name](**arguments)
118115
except Exception as e:
Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,13 @@
11
from django.urls import path, include
2-
from rest_framework.routers import DefaultRouter
3-
42
from api.views.conversations import views
3+
from rest_framework.routers import DefaultRouter
4+
# from views import ConversationViewSet
55

66
router = DefaultRouter()
7-
router.register(r"conversations", views.ConversationViewSet, basename="conversation")
7+
router.register(r'conversations', views.ConversationViewSet,
8+
basename='conversation')
89

910
urlpatterns = [
1011
path("chatgpt/extract_text/", views.extract_text, name="post_web_text"),
11-
path("chatgpt/", include(router.urls)),
12+
path("chatgpt/", include(router.urls))
1213
]

server/api/views/conversations/views.py

Lines changed: 39 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
import os
2-
import json
3-
import logging
4-
51
from rest_framework.response import Response
62
from rest_framework import viewsets, status
73
from rest_framework.decorators import action
@@ -13,8 +9,10 @@
139
import requests
1410
from openai import OpenAI
1511
import tiktoken
12+
import os
13+
import json
14+
import logging
1615
from django.views.decorators.csrf import csrf_exempt
17-
1816
from .models import Conversation, Message
1917
from .serializers import ConversationSerializer
2018
from ...services.tools.tools import tools, execute_tool
@@ -69,7 +67,6 @@ def get_tokens(string: str, encoding_name: str) -> str:
6967

7068
class OpenAIAPIException(APIException):
7169
"""Custom exception for OpenAI API errors."""
72-
7370
status_code = status.HTTP_500_INTERNAL_SERVER_ERROR
7471
default_detail = "An error occurred while communicating with the OpenAI API."
7572
default_code = "openai_api_error"
@@ -98,29 +95,26 @@ def destroy(self, request, *args, **kwargs):
9895
self.perform_destroy(instance)
9996
return Response(status=status.HTTP_204_NO_CONTENT)
10097

101-
@action(detail=True, methods=["post"])
98+
@action(detail=True, methods=['post'])
10299
def continue_conversation(self, request, pk=None):
103100
conversation = self.get_object()
104-
user_message = request.data.get("message")
105-
page_context = request.data.get("page_context")
101+
user_message = request.data.get('message')
102+
page_context = request.data.get('page_context')
106103

107104
if not user_message:
108105
return Response({"error": "Message is required"}, status=400)
109106

110107
# Save user message
111-
Message.objects.create(
112-
conversation=conversation, content=user_message, is_user=True
113-
)
108+
Message.objects.create(conversation=conversation,
109+
content=user_message, is_user=True)
114110

115111
# Get ChatGPT response
116112
chatgpt_response = self.get_chatgpt_response(
117-
conversation, user_message, page_context
118-
)
113+
conversation, user_message, page_context)
119114

120115
# Save ChatGPT response
121-
Message.objects.create(
122-
conversation=conversation, content=chatgpt_response, is_user=False
123-
)
116+
Message.objects.create(conversation=conversation,
117+
content=chatgpt_response, is_user=False)
124118

125119
# Generate or update title if it's the first message or empty
126120
if conversation.messages.count() <= 2 or not conversation.title:
@@ -129,31 +123,25 @@ def continue_conversation(self, request, pk=None):
129123

130124
return Response({"response": chatgpt_response, "title": conversation.title})
131125

132-
@action(detail=True, methods=["patch"])
126+
@action(detail=True, methods=['patch'])
133127
def update_title(self, request, pk=None):
134128
conversation = self.get_object()
135-
new_title = request.data.get("title")
129+
new_title = request.data.get('title')
136130

137131
if not new_title:
138-
return Response(
139-
{"error": "New title is required"}, status=status.HTTP_400_BAD_REQUEST
140-
)
132+
return Response({"error": "New title is required"}, status=status.HTTP_400_BAD_REQUEST)
141133

142134
conversation.title = new_title
143135
conversation.save()
144136

145-
return Response(
146-
{"status": "Title updated successfully", "title": conversation.title}
147-
)
137+
return Response({"status": "Title updated successfully", "title": conversation.title})
148138

149139
def get_chatgpt_response(self, conversation, user_message, page_context=None):
150140
client = OpenAI(api_key=os.environ["OPENAI_API_KEY"])
151-
messages = [
152-
{
153-
"role": "system",
154-
"content": "You are a knowledgeable assistant. Balancer is a powerful tool for selecting bipolar medication for patients. We are open-source and available for free use. Your primary role is to assist licensed clinical professionals with information related to Balancer and bipolar medication selection. If applicable, use the supplied tools to assist the professional.",
155-
}
156-
]
141+
messages = [{
142+
"role": "system",
143+
"content": "You are a knowledgeable assistant. Balancer is a powerful tool for selecting bipolar medication for patients. We are open-source and available for free use. Your primary role is to assist licensed clinical professionals with information related to Balancer and bipolar medication selection. If applicable, use the supplied tools to assist the professional."
144+
}]
157145

158146
if page_context:
159147
context_message = f"If applicable, please use the following content to ask questions. If not applicable, please answer to the best of your ability: {page_context}"
@@ -169,7 +157,7 @@ def get_chatgpt_response(self, conversation, user_message, page_context=None):
169157
model="gpt-3.5-turbo",
170158
messages=messages,
171159
tools=tools,
172-
tool_choice="auto",
160+
tool_choice="auto"
173161
)
174162

175163
response_message = response.choices[0].message
@@ -178,41 +166,37 @@ def get_chatgpt_response(self, conversation, user_message, page_context=None):
178166
tool_calls = response_message.model_dump().get("tool_calls", [])
179167

180168
if not tool_calls:
181-
return response_message["content"]
169+
return response_message['content']
182170

183171
# Handle tool calls
184172
# Add the assistant's message with tool calls to the conversation
185-
messages.append(
186-
{
187-
"role": "assistant",
188-
"content": response_message.content or "",
189-
"tool_calls": tool_calls,
190-
}
191-
)
173+
messages.append({
174+
"role": "assistant",
175+
"content": response_message.content or "",
176+
"tool_calls": tool_calls
177+
})
192178

193179
# Process each tool call
194180
for tool_call in tool_calls:
195-
tool_call_id = tool_call["id"]
196-
tool_function_name = tool_call["function"]["name"]
181+
tool_call_id = tool_call['id']
182+
tool_function_name = tool_call['function']['name']
197183
tool_arguments = json.loads(
198-
tool_call["function"].get("arguments", "{}")
199-
)
184+
tool_call['function'].get('arguments', '{}'))
200185

201186
# Execute the tool
202187
results = execute_tool(tool_function_name, tool_arguments)
203188

204189
# Add the tool response message
205-
messages.append(
206-
{
207-
"role": "tool",
208-
"content": str(results), # Convert results to string
209-
"tool_call_id": tool_call_id,
210-
}
211-
)
190+
messages.append({
191+
"role": "tool",
192+
"content": str(results), # Convert results to string
193+
"tool_call_id": tool_call_id
194+
})
212195

213196
# Final API call with tool results
214197
final_response = client.chat.completions.create(
215-
model="gpt-3.5-turbo", messages=messages
198+
model="gpt-3.5-turbo",
199+
messages=messages
216200
)
217201
return final_response.choices[0].message.content
218202
except OpenAI.error.OpenAIError as e:
@@ -231,12 +215,9 @@ def generate_title(self, conversation):
231215
response = client.chat.completions.create(
232216
model="gpt-3.5-turbo",
233217
messages=[
234-
{
235-
"role": "system",
236-
"content": "You are a helpful assistant that generates short, descriptive titles.",
237-
},
238-
{"role": "user", "content": prompt},
239-
],
218+
{"role": "system", "content": "You are a helpful assistant that generates short, descriptive titles."},
219+
{"role": "user", "content": prompt}
220+
]
240221
)
241222

242223
return response.choices[0].message.content.strip()
Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,11 @@
11
from django.urls import path
2-
32
from .views import RuleExtractionAPIView, RuleExtractionAPIOpenAIView
43

4+
55
urlpatterns = [
6-
path(
7-
"v1/api/rule_extraction",
8-
RuleExtractionAPIView.as_view(),
9-
name="rule_extraction",
10-
),
11-
path(
12-
"v1/api/rule_extraction_openai",
13-
RuleExtractionAPIOpenAIView.as_view(),
14-
name="rule_extraction_openai",
15-
),
6+
7+
path('v1/api/rule_extraction', RuleExtractionAPIView.as_view(),
8+
name='rule_extraction'),
9+
path('v1/api/rule_extraction_openai', RuleExtractionAPIOpenAIView.as_view(),
10+
name='rule_extraction_openai')
1611
]

0 commit comments

Comments
 (0)