Skip to content

Commit a24c223

Browse files
Modified document select and type checking for models (#518)
* Modified document selection and type checking for models; * added tokens for llama3; * added manual token cutoff
1 parent 86842e9 commit a24c223

File tree

2 files changed

+27
-12
lines changed

2 files changed

+27
-12
lines changed

backend/src/QA_integration_new.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,15 @@
2323
from src.llm import get_llm
2424
import json
2525

26+
## Chat models
27+
from langchain_openai import ChatOpenAI, AzureChatOpenAI
28+
from langchain_google_vertexai import ChatVertexAI
29+
from langchain_groq import ChatGroq
30+
from langchain_anthropic import ChatAnthropic
31+
from langchain_fireworks import ChatFireworks
32+
from langchain_aws import ChatBedrock
33+
from langchain_community.chat_models import ChatOllama
34+
2635
load_dotenv()
2736

2837
EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL')
@@ -38,8 +47,8 @@ def get_neo4j_retriever(graph, retrieval_query,document_names,index_name="vector
3847
graph=graph
3948
)
4049
logging.info(f"Successfully retrieved Neo4jVector index '{index_name}'")
50+
document_names= list(map(str.strip, json.loads(document_names)))
4151
if document_names:
42-
document_names= list(map(str.strip, json.loads(document_names)))
4352
retriever = neo_db.as_retriever(search_kwargs={'k': search_k, "score_threshold": score_threshold,'filter':{'fileName': {'$in': document_names}}})
4453
logging.info(f"Successfully created retriever for index '{index_name}' with search_k={search_k}, score_threshold={score_threshold} for documents {document_names}")
4554
else:
@@ -178,17 +187,22 @@ def summarize_messages(llm,history,stored_messages):
178187
return True
179188

180189

181-
def get_total_tokens(model, ai_response):
182-
if "gemini" in model:
190+
def get_total_tokens(ai_response,llm):
191+
192+
if isinstance(llm,(ChatOpenAI,AzureChatOpenAI,ChatFireworks,ChatGroq)):
193+
total_tokens = ai_response.response_metadata['token_usage']['total_tokens']
194+
elif isinstance(llm,(ChatVertexAI)):
183195
total_tokens = ai_response.response_metadata['usage_metadata']['prompt_token_count']
184-
elif "bedrock" in model:
196+
elif isinstance(llm,(ChatBedrock)):
185197
total_tokens = ai_response.response_metadata['usage']['total_tokens']
186-
elif "anthropic" in model:
198+
elif isinstance(llm,(ChatAnthropic)):
187199
input_tokens = int(ai_response.response_metadata['usage']['input_tokens'])
188200
output_tokens = int(ai_response.response_metadata['usage']['output_tokens'])
189201
total_tokens = input_tokens + output_tokens
190-
else:
191-
total_tokens = ai_response.response_metadata['token_usage']['total_tokens']
202+
elif isinstance(llm,(ChatOllama)):
203+
total_tokens = ai_response.response_metadata["prompt_eval_count"]
204+
else:
205+
total_tokens = 0
192206
return total_tokens
193207

194208

@@ -206,7 +220,7 @@ def clear_chat_history(graph,session_id):
206220

207221
def setup_chat(model, graph, session_id, document_names,retrieval_query):
208222
start_time = time.time()
209-
if model in ["diffbot", "LLM_MODEL_CONFIG_ollama_llama3"]:
223+
if model in ["diffbot"]:
210224
model = "openai-gpt-4o"
211225
llm,model_name = get_llm(model)
212226
logging.info(f"Model called in chat {model} and model version is {model_name}")
@@ -236,7 +250,7 @@ def process_documents(docs, question, messages, llm,model):
236250
})
237251
result = get_sources_and_chunks(sources, docs)
238252
content = ai_response.content
239-
total_tokens = get_total_tokens(model, ai_response)
253+
total_tokens = get_total_tokens(ai_response,llm)
240254

241255

242256
predict_time = time.time() - start_time

backend/src/shared/constants.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222
CHAT_DOC_SPLIT_SIZE = 3000
2323
CHAT_EMBEDDING_FILTER_SCORE_THRESHOLD = 0.10
2424
CHAT_TOKEN_CUT_OFF = {
25-
("openai-gpt-3.5","gemini-1.0-pro","gemini-1.5-pro","groq-llama3" ) : 4,
26-
("openai-gpt-4","diffbot" , "openai-gpt-4o") : 28
27-
}
25+
("openai-gpt-3.5",'azure_ai_gpt_35',"gemini-1.0-pro","gemini-1.5-pro","groq-llama3",'groq_llama3_70b','anthropic_claude_3_5_sonnet','fireworks_llama_v3_70b','bedrock_claude_3_5_sonnet', ) : 4,
26+
("openai-gpt-4","diffbot" ,'azure_ai_gpt_4o',"openai-gpt-4o") : 28,
27+
("ollama_llama3",) : 2
28+
}
2829

2930

3031
### CHAT TEMPLATES

0 commit comments

Comments (0)