337 changes: 188 additions & 149 deletions ai_ta_backend/database/sql.py
@@ -1,30 +1,33 @@
import os
from datetime import datetime, timedelta
from typing import Dict, List, TypedDict, Union

import supabase
from injector import inject
from tenacity import retry, stop_after_attempt, wait_exponential


class ProjectStats(TypedDict):
  total_messages: int
  total_conversations: int
  unique_users: int
  avg_conversations_per_user: float
  avg_messages_per_user: float
  avg_messages_per_conversation: float


class WeeklyMetric(TypedDict):
  current_week_value: int
  metric_name: str
  percentage_change: float
  previous_week_value: int


class ModelUsage(TypedDict):
  model_name: str
  count: int
  percentage: float


class SQLDatabase:

@@ -126,12 +129,14 @@ def getCountFromLLMConvoMonitor(self, course_name: str, last_id: int):

  def getCountFromDocuments(self, course_name: str, last_id: int):
    if last_id == 0:
      return self.supabase_client.table("documents").select(
          "id", count='exact').eq("course_name", course_name).order('id', desc=False).execute()
    else:
      return self.supabase_client.table("documents").select(
          "id", count='exact').eq("course_name", course_name).gt("id", last_id).order('id', desc=False).execute()

def getDocMapFromProjects(self, course_name: str):
return self.supabase_client.table("projects").select("doc_map_id").eq("course_name", course_name).execute()

@@ -183,157 +188,191 @@ def getAllConversationsForUserAndProject(self, user_email: str, project_name: st

def insertProject(self, project_info):
return self.supabase_client.table("projects").insert(project_info).execute()

def getPreAssignedAPIKeys(self, email: str):
return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails", '["' + email + '"]').execute()

return self.supabase_client.table("pre_authorized_api_keys").select("*").contains("emails",
'["' + email + '"]').execute()

  def getConversationsCreatedAtByCourse(self, course_name: str):
    try:
      count_response = self.supabase_client.table("llm-convo-monitor")\
          .select("created_at", count="exact")\
          .eq("course_name", course_name)\
          .execute()

      total_count = count_response.count if hasattr(count_response, 'count') else 0

      if total_count <= 0:
        print(f"No conversations found for course: {course_name}")
        return [], 0

      all_data = []
      batch_size = 1000
      start = 0

      # Page through the rows in fixed-size batches.
      while start < total_count:
        end = min(start + batch_size - 1, total_count - 1)

        try:
          response = self.supabase_client.table("llm-convo-monitor")\
              .select("created_at")\
              .eq("course_name", course_name)\
              .range(start, end)\
              .execute()

          if not response or not hasattr(response, 'data') or not response.data:
            print(f"No data returned for range {start} to {end}.")
            break

          all_data.extend(response.data)
          start += batch_size

        except Exception as batch_error:
          print(f"Error fetching batch {start}-{end}: {str(batch_error)}")
          # Advance past the failed batch so a persistent error cannot loop forever.
          start += batch_size
          continue

      if not all_data:
        print(f"No conversation data could be retrieved for course: {course_name}")
        return [], 0

      return all_data, len(all_data)

    except Exception as e:
      print(f"Error in getConversationsCreatedAtByCourse for {course_name}: {str(e)}")
      return [], 0

  def getProjectStats(self, project_name: str) -> ProjectStats:
    try:
      response = self.supabase_client.table("project_stats").select("total_messages, total_conversations, unique_users")\
          .eq("project_name", project_name).execute()

      stats: Dict[str, int | float] = {
          "total_messages": 0,
          "total_conversations": 0,
          "unique_users": 0,
          "avg_conversations_per_user": 0.0,
          "avg_messages_per_user": 0.0,
          "avg_messages_per_conversation": 0.0
      }

      if response and hasattr(response, 'data') and response.data:
        base_stats = response.data[0]
        stats.update(base_stats)

        if stats["unique_users"] > 0:
          stats["avg_conversations_per_user"] = float(round(stats["total_conversations"] / stats["unique_users"], 2))
          stats["avg_messages_per_user"] = float(round(stats["total_messages"] / stats["unique_users"], 2))

        if stats["total_conversations"] > 0:
          stats["avg_messages_per_conversation"] = float(
              round(stats["total_messages"] / stats["total_conversations"], 2))

      # Convert stats to proper types before creating ProjectStats
      stats_typed = {
          "total_messages": int(stats["total_messages"]),
          "total_conversations": int(stats["total_conversations"]),
          "unique_users": int(stats["unique_users"]),
          "avg_conversations_per_user": float(stats["avg_conversations_per_user"]),
          "avg_messages_per_user": float(stats["avg_messages_per_user"]),
          "avg_messages_per_conversation": float(stats["avg_messages_per_conversation"])
      }
      return ProjectStats(**stats_typed)

    except Exception as e:
      print(f"Error fetching project stats for {project_name}: {str(e)}")
      return ProjectStats(total_messages=0,
                          total_conversations=0,
                          unique_users=0,
                          avg_conversations_per_user=0.0,
                          avg_messages_per_user=0.0,
                          avg_messages_per_conversation=0.0)

  def getWeeklyTrends(self, project_name: str) -> List[WeeklyMetric]:
    response = self.supabase_client.rpc('calculate_weekly_trends', {'course_name_input': project_name}).execute()

    if response and hasattr(response, 'data'):
      return [
          WeeklyMetric(current_week_value=item['current_week_value'],
                       metric_name=item['metric_name'],
                       percentage_change=item['percentage_change'],
                       previous_week_value=item['previous_week_value']) for item in response.data
      ]

    return []

  def getModelUsageCounts(self, project_name: str) -> List[ModelUsage]:
    response = self.supabase_client.rpc('count_models_by_project', {'project_name_input': project_name}).execute()

    if response and hasattr(response, 'data'):
      total_count = sum(item['count'] for item in response.data if item.get('model'))

      model_counts = []
      for item in response.data:
        if item.get('model'):
          percentage = round((item['count'] / total_count * 100), 2) if total_count > 0 else 0
          model_counts.append(ModelUsage(model_name=item['model'], count=item['count'], percentage=percentage))

      return model_counts

    return []

def getAllProjects(self):
return self.supabase_client.table("projects").select("course_name, doc_map_id, convo_map_id, last_uploaded_doc_id, last_uploaded_convo_id").execute()

return self.supabase_client.table("projects").select(
"course_name, doc_map_id, convo_map_id, last_uploaded_doc_id, last_uploaded_convo_id").execute()

def getConvoMapDetails(self):
return self.supabase_client.rpc("get_convo_maps", params={}).execute()

def getDocMapDetails(self):
return self.supabase_client.rpc("get_doc_map_details", params={}).execute()

def getProjectsWithConvoMaps(self):
return self.supabase_client.table("projects").select("course_name, convo_map_id, last_uploaded_convo_id, conversation_map_index").neq("convo_map_id", None).execute()

return self.supabase_client.table("projects").select(
"course_name, convo_map_id, last_uploaded_convo_id, conversation_map_index").neq("convo_map_id",
None).execute()

def getProjectsWithDocMaps(self):
return self.supabase_client.table("projects").select("course_name, doc_map_id, last_uploaded_doc_id, document_map_index").neq("doc_map_id", None).execute()

return self.supabase_client.table("projects").select(
"course_name, doc_map_id, last_uploaded_doc_id, document_map_index").neq("doc_map_id", None).execute()

def getProjectMapName(self, course_name, field_name):
return self.supabase_client.table("projects").select(field_name).eq("course_name", course_name).execute()


def getConversationsFromLast24Hours(self):
"""Get conversations and their messages from the last 24 hours"""
# First get conversations from the last 24 hours
conversations = self.supabase_client.table("conversations").select("id").gte(
"created_at",
datetime.now() - timedelta(days=1)).limit(1500).execute()

if not conversations or not hasattr(conversations, 'data') or not conversations.data:
return conversations

# Get the conversation IDs
conversation_ids = [conv['id'] for conv in conversations.data]

# Then get the messages for these conversations
messages = self.supabase_client.table("messages").select("conversation_id, content_text, role").in_(
"conversation_id", conversation_ids).limit(1500).execute()

# Group messages by conversation_id
if messages and hasattr(messages, 'data'):
grouped_conversations = {}
for message in messages.data:
conv_id = message['conversation_id']
if conv_id not in grouped_conversations:
grouped_conversations[conv_id] = {'id': conv_id, 'messages': []}
grouped_conversations[conv_id]['messages'].append({
'role': message['role'],
'content_text': message['content_text']
})

# Convert to list of conversations
result = messages
result.data = list(grouped_conversations.values())
return result

return messages

# def getMessagesFromConversations(self, conversations):
# return self.supabase_client.table("messages").select("*").in_("conversation_id", [conversation['id'] for conversation in conversations]).execute()
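  # A minimal usage sketch (not part of this PR): it assumes SQLDatabase can be
  # constructed directly and that the Supabase credentials it needs are already
  # set in the environment; in the app itself the instance is provided through
  # `injector`. Project name and printed fields are illustrative only.
  #
  # if __name__ == "__main__":
  #   db = SQLDatabase()  # hypothetical direct construction
  #   stats = db.getProjectStats("example-course")
  #   print(stats["total_messages"], stats["avg_messages_per_conversation"])
  #   for metric in db.getWeeklyTrends("example-course"):
  #     print(metric["metric_name"], metric["percentage_change"])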