 from src.llm import get_llm
 import json
 
+## Chat models
+from langchain_openai import ChatOpenAI, AzureChatOpenAI
+from langchain_google_vertexai import ChatVertexAI
+from langchain_groq import ChatGroq
+from langchain_anthropic import ChatAnthropic
+from langchain_fireworks import ChatFireworks
+from langchain_aws import ChatBedrock
+from langchain_community.chat_models import ChatOllama
+
 load_dotenv()
 
 EMBEDDING_MODEL = os.getenv('EMBEDDING_MODEL')
@@ -38,8 +47,8 @@ def get_neo4j_retriever(graph, retrieval_query,document_names,index_name="vector
             graph=graph
         )
         logging.info(f"Successfully retrieved Neo4jVector index '{index_name}'")
+        document_names = list(map(str.strip, json.loads(document_names)))
         if document_names:
-            document_names = list(map(str.strip, json.loads(document_names)))
             retriever = neo_db.as_retriever(search_kwargs={'k': search_k, "score_threshold": score_threshold, 'filter': {'fileName': {'$in': document_names}}})
             logging.info(f"Successfully created retriever for index '{index_name}' with search_k={search_k}, score_threshold={score_threshold} for documents {document_names}")
         else:
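Note on the hunk above: the JSON parsing now runs before the branch, so document_names reaches the truthiness check as a Python list rather than a raw string. Since the string "[]" is truthy but the parsed empty list is not, an empty selection now correctly falls through to the unfiltered retriever. A minimal sketch of the parsing step (the input string is illustrative, not from this PR):

import json

document_names = ' ["doc1.pdf ", "doc2.pdf"] '  # JSON string as received from the request
document_names = list(map(str.strip, json.loads(document_names)))
print(document_names)  # ['doc1.pdf', 'doc2.pdf']; json.loads('[]') yields [], which skips the filter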
@@ -178,17 +187,22 @@ def summarize_messages(llm,history,stored_messages):
     return True
 
 
-def get_total_tokens(model, ai_response):
-    if "gemini" in model:
+def get_total_tokens(ai_response, llm):
+
+    if isinstance(llm, (ChatOpenAI, AzureChatOpenAI, ChatFireworks, ChatGroq)):
+        total_tokens = ai_response.response_metadata['token_usage']['total_tokens']
+    elif isinstance(llm, ChatVertexAI):
         total_tokens = ai_response.response_metadata['usage_metadata']['prompt_token_count']
-    elif "bedrock" in model:
+    elif isinstance(llm, ChatBedrock):
         total_tokens = ai_response.response_metadata['usage']['total_tokens']
-    elif "anthropic" in model:
+    elif isinstance(llm, ChatAnthropic):
         input_tokens = int(ai_response.response_metadata['usage']['input_tokens'])
         output_tokens = int(ai_response.response_metadata['usage']['output_tokens'])
         total_tokens = input_tokens + output_tokens
-    else:
-        total_tokens = ai_response.response_metadata['token_usage']['total_tokens']
+    elif isinstance(llm, ChatOllama):
+        total_tokens = ai_response.response_metadata["prompt_eval_count"]
+    else:
+        total_tokens = 0
     return total_tokens
 
 
@@ -206,7 +220,7 @@ def clear_chat_history(graph,session_id):
 
 def setup_chat(model, graph, session_id, document_names, retrieval_query):
     start_time = time.time()
-    if model in ["diffbot", "LLM_MODEL_CONFIG_ollama_llama3"]:
+    if model in ["diffbot"]:
         model = "openai-gpt-4o"
     llm, model_name = get_llm(model)
     logging.info(f"Model called in chat {model} and model version is {model_name}")
@@ -236,7 +250,7 @@ def process_documents(docs, question, messages, llm,model):
     })
     result = get_sources_and_chunks(sources, docs)
     content = ai_response.content
-    total_tokens = get_total_tokens(model, ai_response)
+    total_tokens = get_total_tokens(ai_response, llm)
 
 
     predict_time = time.time() - start_time
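Taken together, these hunks move token accounting from substring checks on the model name to isinstance checks on the chat-model object, and the call site now passes the llm instance through. A minimal usage sketch, assuming get_llm returns a LangChain chat model (the model name and prompt below are illustrative):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(model="gpt-4o")                   # any provider handled above works
ai_response = llm.invoke("Hello!")                 # AIMessage carrying response_metadata
total_tokens = get_total_tokens(ai_response, llm)  # reads token_usage['total_tokens'] for OpenAI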