Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 2 additions & 3 deletions aperag/auth/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,7 @@ async def get_user_from_api_key(key):
if user is not None:
return user
try:
# api_key = ApiKeyToken.objects.get(key=key)
api_key = await sync_to_async(ApiKeyToken.objects.get)(key=key)
api_key = await ApiKeyToken.objects.aget(key=key)
except ApiKeyToken.DoesNotExist:
return None
if api_key.status == ApiKeyStatus.DELETED:
Expand Down Expand Up @@ -94,7 +93,7 @@ async def authenticate(self, request, token, scheme):
request.META[KEY_USER_ID] = get_user_from_token(token)
elif scheme == self.api_key_scheme:
user = await get_user_from_api_key(token)
if user == None:
if user is None:
return None
request.META[KEY_USER_ID] = user
return token
Expand Down
3 changes: 0 additions & 3 deletions aperag/chat/sse/sse.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from asgiref.sync import sync_to_async
from channels.generic.http import AsyncHttpConsumer
from ninja import NinjaAPI

from aperag.chat.history.redis import RedisChatMessageHistory
from aperag.chat.utils import fail_response, get_async_redis_client, start_response, stop_response, success_response
Expand All @@ -27,8 +26,6 @@

logger = logging.getLogger(__name__)

api = NinjaAPI(version="1.0.0", urls_namespace="events")


class ServerSentEventsConsumer(AsyncHttpConsumer):
async def handle(self, body):
Expand Down
6 changes: 2 additions & 4 deletions aperag/views/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,11 @@

import json

from ninja import NinjaAPI, Router
from ninja import Router

from config import settings
from aperag.db.ops import query_config
from aperag.views.utils import success

api = NinjaAPI(version="1.0.0", urls_namespace="config")
from config import settings

router = Router()

Expand Down
35 changes: 17 additions & 18 deletions aperag/views/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,6 @@

logger = logging.getLogger(__name__)

api = NinjaAPI(version="1.0.0", auth=GlobalHTTPAuth(), urls_namespace="collection")
router = Router()


Expand Down Expand Up @@ -161,13 +160,13 @@ class ConnectionInfo(Schema):


@router.get("/user_info")
def get_user_info(request):
async def get_user_info(request):
user = get_user(request)
return success({"is_admin": user == settings.ADMIN_USER})


@router.get("/models")
def list_models(request):
async def list_models(request):
response = []
model_families = yaml.safe_load(settings.MODEL_FAMILIES)
for model_family in model_families:
Expand All @@ -193,7 +192,7 @@ def list_models(request):


@router.get("/prompt_templates")
def list_prompt_templates(request):
async def list_prompt_templates(request):
language = request.headers.get('Lang', "zh-CN")
if language == "zh-CN":
return success(MULTI_ROLE_ZH_PROMPT_TEMPLATES)
Expand Down Expand Up @@ -434,7 +433,7 @@ async def update_collection(request, collection_id, collection: CollectionIn):
bot_ids = []
async for bot in bots:
bot_ids.append(bot.id)

return success(instance.view(bot_ids=bot_ids))


Expand Down Expand Up @@ -467,27 +466,27 @@ async def create_questions(request, collection_id):
return fail(HTTPStatus.NOT_FOUND, "Collection not found")
if collection.status == CollectionStatus.QUESTION_PENDING:
return fail(HTTPStatus.BAD_REQUEST, "Collection is generating questions")

collection.status = CollectionStatus.QUESTION_PENDING
await collection.asave()

documents = await sync_to_async(collection.document_set.exclude)(status=DocumentStatus.DELETED)
generate_tasks = []
async for document in documents:
generate_tasks.append(generate_questions.si(document.id))
generate_group = group(*generate_tasks)
callback_chain = chain(generate_group, update_collection_status.s(collection.id))
callback_chain.delay()
return success({})

return success({})

@router.put("/collections/{collection_id}/questions")
async def update_question(request, collection_id, question_in: QuestionIn):
user = get_user(request)
collection = await query_collection(user, collection_id)
if collection is None:
return fail(HTTPStatus.NOT_FOUND, "Collection not found")

# ceate question
if not question_in.id:
question_instance = Question(
Expand All @@ -499,13 +498,13 @@ async def update_question(request, collection_id, question_in: QuestionIn):
else:
question_instance = await query_question(user, question_in.id)
if question_instance is None:
return fail(HTTPStatus.NOT_FOUND, "Question not found")
return fail(HTTPStatus.NOT_FOUND, "Question not found")

question_instance.question = question_in.question
question_instance.answer = question_in.answer if question_in.answer else ""
question_instance.status = QuestionStatus.PENDING
await sync_to_async(question_instance.documents.clear)()

if question_in.relate_documents:
for document_id in question_in.relate_documents:
document = await query_document(user, collection_id, document_id)
Expand Down Expand Up @@ -688,7 +687,7 @@ async def update_document(
await instance.asave()
# if user add labels for a document, we need to update index
update_index_for_document.delay(instance.id)

related_questions = await sync_to_async(document.question_set.exclude)(status=QuestionStatus.DELETED)
async for question in related_questions:
question.status = QuestionStatus.WARNING
Expand All @@ -712,13 +711,13 @@ async def delete_document(request, collection_id, document_id):
await document.asave()

remove_index.delay(document.id)

related_questions = await sync_to_async(document.question_set.exclude)(status=QuestionStatus.DELETED)
async for question in related_questions:
question.documents.remove(document)
question.status = QuestionStatus.WARNING
await question.asave()

return success(document.view())


Expand All @@ -736,13 +735,13 @@ async def delete_documents(request, collection_id, document_ids: List[str]):
document.gmt_deleted = timezone.now()
await document.asave()
remove_index.delay(document.id)

related_questions = await sync_to_async(document.question_set.exclude)(status=QuestionStatus.DELETED)
async for question in related_questions:
question.documents.remove(document)
question.status = QuestionStatus.WARNING
await question.asave()

ok.append(document.id)
except Exception as e:
logger.exception(e)
Expand Down
Loading