|
| 1 | +# Import utils |
| 2 | +from uuid import UUID |
| 3 | +from utils import (build_logger, get_config, hash_token) |
| 4 | + |
| 5 | +# Import misc |
| 6 | +from azure.identity import DefaultAzureCredential |
| 7 | +from models.user import UserModel |
| 8 | +from tenacity import retry, stop_after_attempt, wait_random_exponential |
| 9 | +from typing import Any, Dict, List, AsyncGenerator, Union |
| 10 | +import asyncio |
| 11 | +import openai |
| 12 | + |
| 13 | + |
###
# Init misc
###

logger = build_logger(__name__)

# BUG FIX: asyncio.get_running_loop() raises RuntimeError unless called from
# inside a running coroutine; at module import time no loop is running yet.
# get_event_loop() returns (creating if needed) the process-wide loop, so the
# background task scheduled below starts once the loop is actually run.
loop = asyncio.get_event_loop()
| 20 | + |
| 21 | + |
| 22 | +### |
| 23 | +# Init OpenIA |
| 24 | +### |
| 25 | + |
async def refresh_oai_token_background() -> None:
    """
    Periodically refresh the Azure AD token used as the OpenAI API key.

    The OpenAI SDK does not support token refresh, so we pass the token to the
    SDK manually. Azure AD tokens are valid for 30 minutes, but we refresh
    every 15 minutes to stay safely within that window.

    See: https://github.com/openai/openai-python/pull/350#issuecomment-1489813285
    """
    # Hoisted out of the loop: the credential object is reusable and caches
    # its own internal state; rebuilding it on every refresh was wasteful.
    oai_cred = DefaultAzureCredential()
    while True:
        logger.info("Refreshing OpenAI token")
        # NOTE(review): get_token() is a blocking (sync) call inside a
        # coroutine; consider loop.run_in_executor so it does not stall the
        # event loop while the token request is in flight.
        oai_token = oai_cred.get_token("https://cognitiveservices.azure.com/.default")
        openai.api_key = oai_token.token
        # Sleep for the 15 minute refresh interval described in the docstring
        # (the previous comment incorrectly said 20 minutes).
        await asyncio.sleep(15 * 60)
| 41 | + |
| 42 | + |
# Configure the OpenAI SDK for the Azure private endpoint, authenticating via
# Azure AD tokens (kept fresh by refresh_oai_token_background).
openai.api_base = get_config("openai", "api_base", str, required=True)
openai.api_type = "azure_ad"
openai.api_version = "2023-05-15"
# Fixed log typo: "Aure" -> "Azure".
logger.info(f"Using Azure private service ({openai.api_base})")
# The refresh task is scheduled now but only begins running once the event
# loop is started by the host application.
loop.create_task(refresh_oai_token_background())
| 48 | + |
# GPT (chat completion) deployment settings.
OAI_GPT_DEPLOY_ID = get_config("openai", "gpt_deploy_id", str, required=True)
OAI_GPT_MAX_TOKENS = get_config("openai", "gpt_max_tokens", int, required=True)
OAI_GPT_MODEL = get_config(
    "openai", "gpt_model", str, default="gpt-3.5-turbo", required=True
)
# Fixed copy/paste bug: this log line previously said "ADA model" although it
# reports the GPT deployment.
logger.info(
    f'Using OpenAI GPT model "{OAI_GPT_MODEL}" ({OAI_GPT_DEPLOY_ID}) with {OAI_GPT_MAX_TOKENS} tokens max'
)
| 57 | + |
# ADA (embedding) deployment settings.
# NOTE: passing both default= and required=True mirrors the GPT block above;
# presumably the default wins when the key is absent — verify in get_config.
OAI_ADA_MODEL = get_config(
    "openai", "ada_model", str, default="text-embedding-ada-002", required=True
)
OAI_ADA_DEPLOY_ID = get_config("openai", "ada_deploy_id", str, required=True)
OAI_ADA_MAX_TOKENS = get_config("openai", "ada_max_tokens", int, required=True)
logger.info(
    f'Using OpenAI ADA model "{OAI_ADA_MODEL}" ({OAI_ADA_DEPLOY_ID}) with {OAI_ADA_MAX_TOKENS} tokens max'
)
| 66 | + |
| 67 | + |
class OpenAI:
    """Thin wrapper around the OpenAI SDK for embeddings and chat completions.

    All calls target the Azure deployments configured at module level.
    Failures (other than authentication errors, which are handled inline)
    are retried up to 3 times with randomized exponential backoff.
    """

    @retry(
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=0.5, max=30),
    )
    async def vector_from_text(self, prompt: str, user_id: UUID) -> List[float]:
        """Return the ADA embedding vector for *prompt*.

        On AuthenticationError the error is logged and an empty list is
        returned; any other exception propagates and triggers a retry.
        """
        # NOTE(review): this logs raw user content at debug level — confirm
        # that is acceptable for this deployment.
        logger.debug(f"Getting vector for text: {prompt}")
        try:
            # NOTE(review): openai.Embedding.create is a blocking sync call
            # inside an async method; it stalls the event loop while the
            # HTTP request is in flight.
            res = openai.Embedding.create(
                deployment_id=OAI_ADA_DEPLOY_ID,
                input=prompt,
                model=OAI_ADA_MODEL,
                # Raw UUID hex here, whereas completion() hashes the id via
                # hash_token — TODO confirm the inconsistency is intended.
                user=user_id.hex,
            )
        except openai.error.AuthenticationError as e:
            logger.exception(e)
            return []

        return res.data[0].embedding

    @retry(
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=0.5, max=30),
    )
    async def completion(self, messages: List[Dict[str, str]], current_user: UserModel) -> Union[str, None]:
        """Return the assistant reply for *messages*, or None on auth failure.

        On AuthenticationError the error is logged and None is returned;
        other exceptions propagate and trigger a retry.
        """
        try:
            # Use chat completion to get a more natural response and lower the usage cost
            completion = openai.ChatCompletion.create(
                deployment_id=OAI_GPT_DEPLOY_ID,
                messages=messages,
                model=OAI_GPT_MODEL,
                presence_penalty=1,  # Increase the model's likelihood to talk about new topics
                # User id is hashed before being sent to the API (abuse
                # tracking without exposing the raw id).
                user=hash_token(current_user.id.bytes).hex,
            )
            content = completion["choices"][0].message.content
        except openai.error.AuthenticationError as e:
            logger.exception(e)
            return

        return content

    @retry(
        reraise=True,
        stop=stop_after_attempt(3),
        wait=wait_random_exponential(multiplier=0.5, max=30),
    )
    async def completion_stream(self, messages: List[Dict[str, str]], current_user: UserModel) -> AsyncGenerator[Any, None]:
        """Yield the assistant reply for *messages* chunk by chunk.

        On AuthenticationError the error is logged and the generator ends
        without yielding anything.
        """
        # NOTE(review): this is an async *generator* function, so its body
        # (including the create() call below) only runs on first iteration,
        # not when the function is called — tenacity's @retry therefore wraps
        # only the generator construction and is effectively a no-op here.
        # Confirm and consider retrying the create() call explicitly.
        try:
            # Use chat completion to get a more natural response and lower the usage cost
            chunks = openai.ChatCompletion.create(
                deployment_id=OAI_GPT_DEPLOY_ID,
                messages=messages,
                model=OAI_GPT_MODEL,
                presence_penalty=1,  # Increase the model's likelihood to talk about new topics
                stream=True,
                user=hash_token(current_user.id.bytes).hex,
            )
        except openai.error.AuthenticationError as e:
            logger.exception(e)
            return

        # NOTE(review): sync iteration over the streaming response inside an
        # async generator — each network read blocks the event loop.
        for chunk in chunks:
            content = chunk["choices"][0].get("delta", {}).get("content")
            if content is not None:
                yield content
0 commit comments