Skip to content

Commit 674c8c3

Browse files
committed
Add docstrings and minor formatting improvements for code clarity
1 parent bdbf1d1 commit 674c8c3

File tree

6 files changed

+158
-81
lines changed

6 files changed

+158
-81
lines changed

server/api/services/conversions_services.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,23 @@
22

33

44
def convert_uuids(data):
5+
"""
6+
Recursively convert UUID objects to strings in nested data structures.
7+
8+
Traverses dictionaries, lists, and other data structures to find UUID objects
9+
and converts them to their string representation for serialization.
10+
11+
Parameters
12+
----------
13+
data : any
14+
The data structure to process (dict, list, UUID, or any other type)
15+
16+
Returns
17+
-------
18+
any
19+
The data structure with all UUID objects converted to strings.
20+
Structure and types are preserved except for UUID -> str conversion.
21+
"""
522
if isinstance(data, dict):
623
return {key: convert_uuids(value) for key, value in data.items()}
724
elif isinstance(data, list):

server/api/services/embedding_services.py

Lines changed: 45 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,63 @@
11
# services/embedding_services.py
2+
3+
from pgvector.django import L2Distance
4+
25
from .sentencetTransformer_model import TransformerModel
6+
37
# Adjust import path as needed
48
from ..models.model_embeddings import Embeddings
5-
from pgvector.django import L2Distance
69

710

8-
def get_closest_embeddings(user, message_data, document_name=None, guid=None, num_results=10):
11+
def get_closest_embeddings(
12+
user, message_data, document_name=None, guid=None, num_results=10
13+
):
14+
"""
15+
Find the closest embeddings to a given message for a specific user.
16+
17+
Parameters
18+
----------
19+
user : User
20+
The user whose uploaded documents will be searched
21+
message_data : str
22+
The input message to find similar embeddings for
23+
document_name : str, optional
24+
Filter results to a specific document name
25+
guid : str, optional
26+
Filter results to a specific document GUID (takes precedence over document_name)
27+
num_results : int, default 10
28+
Maximum number of results to return
29+
30+
Returns
31+
-------
32+
list[dict]
33+
List of dictionaries containing embedding results with keys:
34+
- name: document name
35+
- text: embedded text content
36+
- page_number: page number in source document
37+
- chunk_number: chunk number within the document
38+
- distance: L2 distance from query embedding
39+
- file_id: GUID of the source file
40+
"""
41+
942
#
1043
transformerModel = TransformerModel.get_instance().model
1144
embedding_message = transformerModel.encode(message_data)
1245
# Start building the query based on the message's embedding
13-
closest_embeddings_query = Embeddings.objects.filter(
14-
upload_file__uploaded_by=user
15-
).annotate(
16-
distance=L2Distance(
17-
'embedding_sentence_transformers', embedding_message)
18-
).order_by('distance')
46+
closest_embeddings_query = (
47+
Embeddings.objects.filter(upload_file__uploaded_by=user)
48+
.annotate(
49+
distance=L2Distance("embedding_sentence_transformers", embedding_message)
50+
)
51+
.order_by("distance")
52+
)
1953

2054
# Filter by GUID if provided, otherwise filter by document name if provided
2155
if guid:
2256
closest_embeddings_query = closest_embeddings_query.filter(
23-
upload_file__guid=guid)
57+
upload_file__guid=guid
58+
)
2459
elif document_name:
25-
closest_embeddings_query = closest_embeddings_query.filter(
26-
name=document_name)
60+
closest_embeddings_query = closest_embeddings_query.filter(name=document_name)
2761

2862
# Slice the results to limit to num_results
2963
closest_embeddings_query = closest_embeddings_query[:num_results]

server/api/services/tools/tools.py

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
from django.db import connection
21
from typing import Dict, Any, Callable, List
32
from dataclasses import dataclass
3+
4+
from django.db import connection
5+
46
from .database import ask_database, get_database_info
57

68
database_schema_dict = get_database_info(connection)
@@ -11,13 +13,15 @@
1113
]
1214
)
1315

16+
1417
@dataclass
1518
class ToolFunction:
1619
name: str
1720
func: Callable
1821
description: str
1922
parameters: Dict[str, Any]
2023

24+
2125
def create_tool_dict(tool: ToolFunction) -> Dict[str, Any]:
2226
return {
2327
"type": "function",
@@ -28,10 +32,11 @@ def create_tool_dict(tool: ToolFunction) -> Dict[str, Any]:
2832
"type": "object",
2933
"properties": tool.parameters,
3034
"required": list(tool.parameters.keys()),
31-
}
32-
}
35+
},
36+
},
3337
}
3438

39+
3540
TOOL_FUNCTIONS = [
3641
ToolFunction(
3742
name="ask_database",
@@ -56,60 +61,58 @@ def create_tool_dict(tool: ToolFunction) -> Dict[str, Any]:
5661
SQL should be written using this database schema:
5762
{database_schema_string}
5863
The query should be returned in plain text, not in JSON.
59-
"""
64+
""",
6065
}
61-
}
66+
},
6267
),
6368
]
6469

6570
# Automatically generate the tool_functions dictionary and tools list
66-
tool_functions: Dict[str, Callable] = {
67-
tool.name: tool.func for tool in TOOL_FUNCTIONS
68-
}
71+
tool_functions: Dict[str, Callable] = {tool.name: tool.func for tool in TOOL_FUNCTIONS}
72+
73+
tools: List[Dict[str, Any]] = [create_tool_dict(tool) for tool in TOOL_FUNCTIONS]
6974

70-
tools: List[Dict[str, Any]] = [
71-
create_tool_dict(tool) for tool in TOOL_FUNCTIONS
72-
]
7375

7476
def validate_tool_inputs(tool_function_name, tool_arguments):
7577
"""Validate the inputs for the execute_tool function."""
7678
if not isinstance(tool_function_name, str) or not tool_function_name:
7779
raise ValueError("Invalid tool function name")
78-
80+
7981
if not isinstance(tool_arguments, dict):
8082
raise ValueError("Tool arguments must be a dictionary")
81-
83+
8284
# Check if the tool_function_name exists in the tools
8385
tool = next((t for t in tools if t["function"]["name"] == tool_function_name), None)
8486
if not tool:
8587
raise ValueError(f"Tool function '{tool_function_name}' does not exist")
86-
88+
8789
# Validate the tool arguments based on the tool's parameters
8890
parameters = tool["function"].get("parameters", {})
8991
required_params = parameters.get("required", [])
9092
for param in required_params:
9193
if param not in tool_arguments:
9294
raise ValueError(f"Missing required parameter: {param}")
93-
95+
9496
# Check if the parameter types match the expected types
9597
properties = parameters.get("properties", {})
9698
for param, prop in properties.items():
97-
expected_type = prop.get('type')
99+
expected_type = prop.get("type")
98100
if param in tool_arguments:
99-
if expected_type == 'string' and not isinstance(tool_arguments[param], str):
101+
if expected_type == "string" and not isinstance(tool_arguments[param], str):
100102
raise ValueError(f"Parameter '{param}' must be of type string")
101-
103+
104+
102105
def execute_tool(function_name: str, arguments: Dict[str, Any]) -> str:
103106
"""
104107
Execute the appropriate function based on the function name.
105-
108+
106109
:param function_name: The name of the function to execute
107110
:param arguments: A dictionary of arguments to pass to the function
108111
:return: The result of the function execution
109112
"""
110113
# Validate tool inputs
111114
validate_tool_inputs(function_name, arguments)
112-
115+
113116
try:
114117
return tool_functions[function_name](**arguments)
115118
except Exception as e:
Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
from django.urls import path, include
2-
from api.views.conversations import views
32
from rest_framework.routers import DefaultRouter
4-
# from views import ConversationViewSet
3+
4+
from api.views.conversations import views
55

66
router = DefaultRouter()
7-
router.register(r'conversations', views.ConversationViewSet,
8-
basename='conversation')
7+
router.register(r"conversations", views.ConversationViewSet, basename="conversation")
98

109
urlpatterns = [
1110
path("chatgpt/extract_text/", views.extract_text, name="post_web_text"),
12-
path("chatgpt/", include(router.urls))
11+
path("chatgpt/", include(router.urls)),
1312
]

0 commit comments

Comments
 (0)