diff --git a/backend/score.py b/backend/score.py index 80eb3a7c2..46c4ba137 100644 --- a/backend/score.py +++ b/backend/score.py @@ -17,6 +17,7 @@ from src.post_processing import create_vector_fulltext_indexes, create_entity_embedding from sse_starlette.sse import EventSourceResponse from src.communities import create_communities +from src.neighbours import get_neighbour_nodes import json from typing import List, Mapping from starlette.middleware.sessions import SessionMiddleware @@ -93,6 +94,10 @@ async def create_source_knowledge_graph_url( try: start = time.time() + payload_json_obj = {'api_name':'url_scan', 'db_url':uri, 'userName':userName, 'database':database, 'source_url':source_url, 'aws_access_key_id':aws_access_key_id, + 'model':model, 'gcs_bucket_name':gcs_bucket_name, 'gcs_bucket_folder':gcs_bucket_folder, 'source_type':source_type, + 'gcs_project_id':gcs_project_id, 'wiki_query':wiki_query, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") if source_url is not None: source = source_url else: @@ -174,6 +179,11 @@ async def extract_knowledge_graph_from_file( """ try: start_time = time.time() + payload_json_obj = {'api_name':'extract', 'db_url':uri, 'userName':userName, 'database':database, 'source_url':source_url, 'aws_access_key_id':aws_access_key_id, + 'model':model, 'gcs_bucket_name':gcs_bucket_name, 'gcs_bucket_folder':gcs_bucket_folder, 'source_type':source_type,'gcs_blob_filename':gcs_blob_filename, + 'file_name':file_name, 'gcs_project_id':gcs_project_id, 'wiki_query':wiki_query,'allowedNodes':allowedNodes,'allowedRelationship':allowedRelationship, + 'language':language ,'retry_condition':retry_condition,'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) graphDb_data_Access = graphDBdataAccess(graph) @@ -238,6 +248,8 @@ async def get_source_list(uri:str, userName:str, password:str, database:str=None """ try: start = time.time() + payload_json_obj = {'api_name':'sources_list', 'db_url':uri, 'userName':userName, 'database':database, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") decoded_password = decode_password(password) if " " in uri: uri = uri.replace(" ","+") @@ -257,6 +269,8 @@ async def get_source_list(uri:str, userName:str, password:str, database:str=None @app.post("/post_processing") async def post_processing(uri=Form(), userName=Form(), password=Form(), database=Form(), tasks=Form(None)): try: + payload_json_obj = {'api_name':'post_processing', 'db_url':uri, 'userName':userName, 'database':database, 'tasks':tasks, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) tasks = set(map(str.strip, json.loads(tasks))) @@ -276,11 +290,11 @@ async def post_processing(uri=Form(), userName=Form(), password=Form(), database logging.info(f'Entity Embeddings created') if "enable_communities" in tasks: - model = "openai_gpt_4o" - await asyncio.to_thread(create_communities, uri, userName, password, database,model) - josn_obj = {'api_name': 'post_processing/create_communities', 'db_url': uri, 'logging_time': formatted_time(datetime.now(timezone.utc))} - logger.log_struct(josn_obj) + await asyncio.to_thread(create_communities, uri, userName, password, database) + json_obj = {'api_name': 'post_processing/create_communities', 'db_url': uri, 'logging_time': formatted_time(datetime.now(timezone.utc))} logging.info(f'created communities') + + logger.log_struct(json_obj) return create_api_response('Success', message='All tasks completed successfully') except Exception as e: @@ -298,6 +312,9 @@ async def chat_bot(uri=Form(),model=Form(None),userName=Form(), password=Form(), logging.info(f"QA_RAG called at {datetime.now()}") qa_rag_start_time = time.time() try: + payload_json_obj = {'api_name':'chat_bot', 'db_url':uri, 'userName':userName, 'database':database, 'question':question,'document_names':document_names, + 'session_id':session_id, 'mode':mode, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") if mode == "graph": graph = Neo4jGraph( url=uri,username=userName,password=password,database=database,sanitize = True, refresh_schema=True) else: @@ -311,7 +328,7 @@ async def chat_bot(uri=Form(),model=Form(None),userName=Form(), password=Form(), logging.info(f"Total Response time is {total_call_time:.2f} seconds") result["info"]["response_time"] = round(total_call_time, 2) - json_obj = {'api_name':'chat_bot','db_url':uri,'session_id':session_id, 'logging_time': formatted_time(datetime.now(timezone.utc)), 'elapsed_api_time':f'{total_call_time:.2f}'} + json_obj = {'api_name':'chat_bot','db_url':uri,'session_id':session_id,'mode':mode, 'logging_time': formatted_time(datetime.now(timezone.utc)), 'elapsed_api_time':f'{total_call_time:.2f}'} logger.log_struct(json_obj, "INFO") return create_api_response('Success',data=result) @@ -328,6 +345,9 @@ async def chat_bot(uri=Form(),model=Form(None),userName=Form(), password=Form(), async def chunk_entities(uri=Form(),userName=Form(), password=Form(), database=Form(), nodedetails=Form(None),entities=Form(),mode=Form()): try: start = time.time() + payload_json_obj = {'api_name':'chunk_entities', 'db_url':uri, 'userName':userName, 'database':database, 'nodedetails':nodedetails,'entities':entities, + 'mode':mode, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") result = await asyncio.to_thread(get_entities_from_chunkids,uri=uri, username=userName, password=password, database=database,nodedetails=nodedetails,entities=entities,mode=mode) end = time.time() elapsed_time = end - start @@ -343,6 +363,25 @@ async def chunk_entities(uri=Form(),userName=Form(), password=Form(), database=F finally: gc.collect() +@app.post("/get_neighbours") +async def get_neighbours(uri=Form(),userName=Form(), password=Form(), database=Form(), elementId=Form(None)): + try: + start = time.time() + result = await asyncio.to_thread(get_neighbour_nodes,uri=uri, username=userName, password=password,database=database, element_id=elementId) + end = time.time() + elapsed_time = end - start + json_obj = {'api_name':'get_neighbours','db_url':uri, 'logging_time': formatted_time(datetime.now(timezone.utc)), 'elapsed_api_time':f'{elapsed_time:.2f}'} + logger.log_struct(json_obj, "INFO") + return create_api_response('Success',data=result,message=f"Total elapsed API time {elapsed_time:.2f}") + except Exception as e: + job_status = "Failed" + message="Unable to extract neighbour nodes for given element ID" + error_message = str(e) + logging.exception(f'Exception in get neighbours :{error_message}') + return create_api_response(job_status, message=message, error=error_message) + finally: + gc.collect() + @app.post("/graph_query") async def graph_query( uri: str = Form(), @@ -352,7 +391,9 @@ async def graph_query( document_names: str = Form(None), ): try: - # print(document_names) + payload_json_obj = {'api_name':'graph_query', 'db_url':uri, 'userName':userName, 'database':database, 'document_names':document_names, + 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") start = time.time() result = await asyncio.to_thread( get_graph_results, @@ -380,6 +421,8 @@ async def graph_query( @app.post("/clear_chat_bot") async def clear_chat_bot(uri=Form(),userName=Form(), password=Form(), database=Form(), session_id=Form(None)): try: + payload_json_obj = {'api_name':'clear_chat_bot', 'db_url':uri, 'userName':userName, 'database':database, 'session_id':session_id, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) result = await asyncio.to_thread(clear_chat_history,graph=graph,session_id=session_id) return create_api_response('Success',data=result) @@ -396,6 +439,8 @@ async def clear_chat_bot(uri=Form(),userName=Form(), password=Form(), database=F async def connect(uri=Form(), userName=Form(), password=Form(), database=Form()): try: start = time.time() + payload_json_obj = {'api_name':'connect', 'db_url':uri, 'userName':userName, 'database':database, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) result = await asyncio.to_thread(connection_check_and_get_vector_dimensions, graph, database) end = time.time() @@ -417,6 +462,9 @@ async def upload_large_file_into_chunks(file:UploadFile = File(...), chunkNumber password=Form(), database=Form()): try: start = time.time() + payload_json_obj = {'api_name':'upload', 'db_url':uri, 'userName':userName, 'database':database, 'chunkNumber':chunkNumber,'totalChunks':totalChunks, + 'original_file_name':originalname,'model':model, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) result = await asyncio.to_thread(upload_file, graph, model, file, chunkNumber, totalChunks, originalname, uri, CHUNK_DIR, MERGED_DIR) end = time.time() @@ -442,6 +490,8 @@ async def upload_large_file_into_chunks(file:UploadFile = File(...), chunkNumber async def get_structured_schema(uri=Form(), userName=Form(), password=Form(), database=Form()): try: start = time.time() + payload_json_obj = {'api_name':'schema', 'db_url':uri, 'userName':userName, 'database':database, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) result = await asyncio.to_thread(get_labels_and_relationtypes, graph) end = time.time() @@ -512,6 +562,9 @@ async def delete_document_and_entities(uri=Form(), deleteEntities=Form()): try: start = time.time() + payload_json_obj = {'api_name':'delete_document_and_entities', 'db_url':uri, 'userName':userName, 'database':database, 'filenames':filenames,'deleteEntities':deleteEntities, + 'source_types':source_types, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) graphDb_data_Access = graphDBdataAccess(graph) result, files_list_size = await asyncio.to_thread(graphDb_data_Access.delete_file_from_graph, filenames, source_types, deleteEntities, MERGED_DIR, uri) @@ -568,6 +621,9 @@ async def get_document_status(file_name, url, userName, password, database): @app.post("/cancelled_job") async def cancelled_job(uri=Form(), userName=Form(), password=Form(), database=Form(), filenames=Form(None), source_types=Form(None)): try: + payload_json_obj = {'api_name':'cancelled_job', 'db_url':uri, 'userName':userName, 'database':database, + 'filenames':filenames,'source_types':source_types,'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) result = manually_cancelled_job(graph,filenames, source_types, MERGED_DIR, uri) @@ -584,6 +640,8 @@ async def cancelled_job(uri=Form(), userName=Form(), password=Form(), database=F @app.post("/populate_graph_schema") async def populate_graph_schema(input_text=Form(None), model=Form(None), is_schema_description_checked=Form(None)): try: + payload_json_obj = {'api_name':'populate_graph_schema', 'model':model, 'is_schema_description_checked':is_schema_description_checked, 'input_text':input_text, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") result = populate_graph_schema_from_text(input_text, model, is_schema_description_checked) return create_api_response('Success',data=result) except Exception as e: @@ -598,6 +656,8 @@ async def populate_graph_schema(input_text=Form(None), model=Form(None), is_sche @app.post("/get_unconnected_nodes_list") async def get_unconnected_nodes_list(uri=Form(), userName=Form(), password=Form(), database=Form()): try: + payload_json_obj = {'api_name':'get_unconnected_nodes_list', 'db_url':uri, 'userName':userName, 'database':database, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") start = time.time() graph = create_graph_database_connection(uri, userName, password, database) graphDb_data_Access = graphDBdataAccess(graph) @@ -619,6 +679,9 @@ async def get_unconnected_nodes_list(uri=Form(), userName=Form(), password=Form( @app.post("/delete_unconnected_nodes") async def delete_orphan_nodes(uri=Form(), userName=Form(), password=Form(), database=Form(),unconnected_entities_list=Form()): try: + payload_json_obj = {'api_name':'delete_unconnected_nodes', 'db_url':uri, 'userName':userName, 'database':database, + 'unconnected_entities_list':unconnected_entities_list, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") start = time.time() graph = create_graph_database_connection(uri, userName, password, database) graphDb_data_Access = graphDBdataAccess(graph) @@ -641,6 +704,8 @@ async def delete_orphan_nodes(uri=Form(), userName=Form(), password=Form(), data async def get_duplicate_nodes(uri=Form(), userName=Form(), password=Form(), database=Form()): try: start = time.time() + payload_json_obj = {'api_name':'get_duplicate_nodes', 'db_url':uri, 'userName':userName, 'database':database, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) graphDb_data_Access = graphDBdataAccess(graph) nodes_list, total_nodes = graphDb_data_Access.get_duplicate_nodes_list() @@ -662,6 +727,9 @@ async def get_duplicate_nodes(uri=Form(), userName=Form(), password=Form(), data async def merge_duplicate_nodes(uri=Form(), userName=Form(), password=Form(), database=Form(),duplicate_nodes_list=Form()): try: start = time.time() + payload_json_obj = {'api_name':'merge_duplicate_nodes', 'db_url':uri, 'userName':userName, 'database':database, + 'duplicate_nodes_list':duplicate_nodes_list, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) graphDb_data_Access = graphDBdataAccess(graph) result = graphDb_data_Access.merge_duplicate_nodes(duplicate_nodes_list) @@ -682,6 +750,9 @@ async def merge_duplicate_nodes(uri=Form(), userName=Form(), password=Form(), da @app.post("/drop_create_vector_index") async def merge_duplicate_nodes(uri=Form(), userName=Form(), password=Form(), database=Form(), isVectorIndexExist=Form()): try: + payload_json_obj = {'api_name':'drop_create_vector_index', 'db_url':uri, 'userName':userName, 'database':database, + 'isVectorIndexExist':isVectorIndexExist, 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) graphDb_data_Access = graphDBdataAccess(graph) result = graphDb_data_Access.drop_create_vector_index(isVectorIndexExist) @@ -698,6 +769,9 @@ async def merge_duplicate_nodes(uri=Form(), userName=Form(), password=Form(), da @app.post("/retry_processing") async def retry_processing(uri=Form(), userName=Form(), password=Form(), database=Form(), file_name=Form(), retry_condition=Form()): try: + payload_json_obj = {'api_name':'retry_processing', 'db_url':uri, 'userName':userName, 'database':database, 'file_name':file_name,'retry_condition':retry_condition, + 'logging_time': formatted_time(datetime.now(timezone.utc))} + logger.log_struct(payload_json_obj, "INFO") graph = create_graph_database_connection(uri, userName, password, database) await asyncio.to_thread(set_status_retry, graph,file_name,retry_condition) #set_status_retry(graph,file_name,retry_condition) @@ -712,22 +786,34 @@ async def retry_processing(uri=Form(), userName=Form(), password=Form(), databas gc.collect() @app.post('/metric') -async def calculate_metric(question=Form(), context=Form(), answer=Form(), model=Form()): +async def calculate_metric(question: str = Form(), + context: str = Form(), + answer: str = Form(), + model: str = Form(), + mode: str = Form()): try: - result = await asyncio.to_thread(get_ragas_metrics, question, context, answer, model) + context_list = [str(item).strip() for item in json.loads(context)] if context else [] + answer_list = [str(item).strip() for item in json.loads(answer)] if answer else [] + mode_list = [str(item).strip() for item in json.loads(mode)] if mode else [] + + result = await asyncio.to_thread( + get_ragas_metrics, question, context_list, answer_list, model + ) if result is None or "error" in result: return create_api_response( 'Failed', message='Failed to calculate evaluation metrics.', error=result.get("error", "Ragas evaluation returned null") ) - return create_api_response('Success', data=result) + data = {mode: {metric: result[metric][i] for metric in result} for i, mode in enumerate(mode_list)} + return create_api_response('Success', data=data) except Exception as e: - job_status = "Failed" - message = "Error while calculating evaluation metrics" - error_message = str(e) - logging.exception(f'{error_message}') - return create_api_response(job_status, message=message, error=error_message) + logging.exception(f"Error while calculating evaluation metrics: {e}") + return create_api_response( + 'Failed', + message="Error while calculating evaluation metrics", + error=str(e) + ) finally: gc.collect() diff --git a/backend/src/QA_integration.py b/backend/src/QA_integration.py index cf7c74f6d..468069531 100644 --- a/backend/src/QA_integration.py +++ b/backend/src/QA_integration.py @@ -276,8 +276,9 @@ def retrieve_documents(doc_retriever, messages): logging.info(f"Documents retrieved in {doc_retrieval_time:.2f} seconds") except Exception as e: - logging.error(f"Error retrieving documents: {e}") - raise + error_message = f"Error retrieving documents: {str(e)}" + logging.error(error_message) + raise RuntimeError(error_message) return docs,transformed_question @@ -434,7 +435,7 @@ def process_chat_response(messages, history, question, model, graph, document_na total_tokens = 0 formatted_docs = "" - question = transformed_question if transformed_question else question + # question = transformed_question if transformed_question else question # metrics = get_ragas_metrics(question,formatted_docs,content) # print(metrics) diff --git a/backend/src/chunkid_entities.py b/backend/src/chunkid_entities.py index 8bb9c2198..31ae07496 100644 --- a/backend/src/chunkid_entities.py +++ b/backend/src/chunkid_entities.py @@ -1,6 +1,7 @@ import logging from src.graph_query import * from src.shared.constants import * +import re def process_records(records): """ @@ -191,7 +192,11 @@ def get_entities_from_chunkids(uri, username, password, database ,nodedetails,en if "entitydetails" in nodedetails and nodedetails["entitydetails"]: entity_ids = [item["id"] for item in nodedetails["entitydetails"]] logging.info(f"chunkid_entities module: Starting for entity ids: {entity_ids}") - return process_entityids(driver, entity_ids) + result = process_entityids(driver, entity_ids) + if "chunk_data" in result.keys(): + for chunk in result["chunk_data"]: + chunk["text"] = re.sub(r'\s+', ' ', chunk["text"]) + return result else: logging.info("chunkid_entities module: No entity ids are passed") return default_response @@ -201,7 +206,11 @@ def get_entities_from_chunkids(uri, username, password, database ,nodedetails,en if "chunkdetails" in nodedetails and nodedetails["chunkdetails"]: chunk_ids = [item["id"] for item in nodedetails["chunkdetails"]] logging.info(f"chunkid_entities module: Starting for chunk ids: {chunk_ids}") - return process_chunkids(driver, chunk_ids, entities) + result = process_chunkids(driver, chunk_ids, entities) + if "chunk_data" in result.keys(): + for chunk in result["chunk_data"]: + chunk["text"] = re.sub(r'\s+', ' ', chunk["text"]) + return result else: logging.info("chunkid_entities module: No chunk ids are passed") return default_response diff --git a/backend/src/communities.py b/backend/src/communities.py index 1b19c689b..d1130150c 100644 --- a/backend/src/communities.py +++ b/backend/src/communities.py @@ -13,7 +13,8 @@ NODE_PROJECTION_ENTITY = "__Entity__" MAX_WORKERS = 10 MAX_COMMUNITY_LEVELS = 3 -MIN_COMMUNITY_SIZE = 1 +MIN_COMMUNITY_SIZE = 1 +COMMUNITY_CREATION_DEFAULT_MODEL = "openai_gpt_4o" CREATE_COMMUNITY_GRAPH_PROJECTION = """ @@ -466,7 +467,7 @@ def clear_communities(gds): raise -def create_communities(uri, username, password, database,model): +def create_communities(uri, username, password, database,model=COMMUNITY_CREATION_DEFAULT_MODEL): try: gds = get_gds_driver(uri, username, password, database) clear_communities(gds) diff --git a/backend/src/neighbours.py b/backend/src/neighbours.py new file mode 100644 index 000000000..08022ecc6 --- /dev/null +++ b/backend/src/neighbours.py @@ -0,0 +1,63 @@ +import logging +from src.graph_query import * + +NEIGHBOURS_FROM_ELEMENT_ID_QUERY = """ +MATCH (n) +WHERE elementId(n) = $element_id + +MATCH (n)<-[rels]->(m) +WITH n, + ([n] + COLLECT(DISTINCT m)) AS allNodes, + COLLECT(DISTINCT rels) AS allRels + +RETURN + [node IN allNodes | + node { + .*, + embedding: null, + text: null, + summary: null, + labels: [coalesce(apoc.coll.removeAll(labels(node), ['__Entity__'])[0], "*")], + element_id: elementId(node), + properties: { + id: CASE WHEN node.id IS NOT NULL THEN node.id ELSE node.fileName END + } + } + ] AS nodes, + + [r IN allRels | + { + start_node_element_id: elementId(startNode(r)), + end_node_element_id: elementId(endNode(r)), + type: type(r), + element_id: elementId(r) + } + ] AS relationships +""" + + +def get_neighbour_nodes(uri, username, password, database, element_id, query=NEIGHBOURS_FROM_ELEMENT_ID_QUERY): + driver = None + + try: + logging.info(f"Querying neighbours for element_id: {element_id}") + driver = get_graphDB_driver(uri, username, password, database) + driver.verify_connectivity() + logging.info("Database connectivity verified.") + + records, summary, keys = driver.execute_query(query,element_id=element_id) + nodes = records[0].get("nodes", []) + relationships = records[0].get("relationships", []) + result = {"nodes": nodes, "relationships": relationships} + + logging.info(f"Successfully retrieved neighbours for element_id: {element_id}") + return result + + except Exception as e: + logging.error(f"Error retrieving neighbours for element_id: {element_id}: {e}") + return {"nodes": [], "relationships": []} + + finally: + if driver is not None: + driver.close() + logging.info("Database driver closed.") \ No newline at end of file diff --git a/backend/src/ragas_eval.py b/backend/src/ragas_eval.py index e177b6d61..8052cb9a2 100644 --- a/backend/src/ragas_eval.py +++ b/backend/src/ragas_eval.py @@ -1,88 +1,44 @@ import os import logging import time -from typing import Dict, Tuple, Optional -import boto3 +from src.llm import get_llm from datasets import Dataset from dotenv import load_dotenv -from langchain_anthropic import ChatAnthropic -from langchain_aws import ChatBedrock -from langchain_community.chat_models import ChatOllama -from langchain_experimental.graph_transformers.diffbot import DiffbotGraphTransformer -from langchain_fireworks import ChatFireworks -from langchain_google_vertexai import ( - ChatVertexAI, - HarmBlockThreshold, - HarmCategory, -) -from langchain_groq import ChatGroq -from langchain_openai import AzureChatOpenAI, ChatOpenAI from ragas import evaluate -from ragas.metrics import answer_relevancy, context_utilization, faithfulness +from ragas.metrics import answer_relevancy, faithfulness from src.shared.common_fn import load_embedding_model - load_dotenv() -RAGAS_MODEL_VERSIONS = { - "openai_gpt_3.5": "gpt-3.5-turbo-16k", - "openai_gpt_4": "gpt-4-turbo-2024-04-09", - "openai_gpt_4o_mini": "gpt-4o-mini-2024-07-18", - "openai_gpt_4o": "gpt-4o-mini-2024-07-18", - "groq_llama3_70b": "groq_llama3_70b", -} -EMBEDDING_MODEL = os.getenv("EMBEDDING_MODEL") +EMBEDDING_MODEL = os.getenv("RAGAS_EMBEDDING_MODEL") EMBEDDING_FUNCTION, _ = load_embedding_model(EMBEDDING_MODEL) - -def get_ragas_llm(model: str) -> Tuple[object, str]: - """Retrieves the specified language model. Improved error handling and structure.""" - env_key = f"LLM_MODEL_CONFIG_{model}" - env_value = os.environ.get(env_key) - logging.info(f"Loading model configuration: {env_key}") - try: - if "openai" in model: - model_name = RAGAS_MODEL_VERSIONS[model] - llm = ChatOpenAI( - api_key=os.environ.get("OPENAI_API_KEY"), model=model_name, temperature=0 - ) - elif "groq" in model: - model_name, base_url, api_key = env_value.split(",") - llm = ChatGroq(api_key=api_key, model_name=model_name, temperature=0) - else: - raise ValueError(f"Unsupported model for evaluation: {model}") - - logging.info(f"Model loaded - Model Version: {model}") - return llm, model_name - except (ValueError, KeyError) as e: - logging.error(f"Error loading LLM: {e}") - raise - - -def get_ragas_metrics( - question: str, context: str, answer: str, model: str -) -> Optional[Dict[str, float]]: +def get_ragas_metrics(question: str, context: list, answer: list, model: str): """Calculates RAGAS metrics.""" try: start_time = time.time() dataset = Dataset.from_dict( - {"question": [question], "answer": [answer], "contexts": [[context]]} + {"question": [question] * len(answer), "answer": answer, "contexts": [[ctx] for ctx in context]} ) - logging.info("Dataset created successfully.") - - llm, model_name = get_ragas_llm(model=model) + logging.info("Evaluation dataset created successfully.") + if ("diffbot" in model) or ("ollama" in model): + raise ValueError(f"Unsupported model for evaluation: {model}") + else: + llm, model_name = get_llm(model=model) + logging.info(f"Evaluating with model: {model_name}") - + score = evaluate( dataset=dataset, - metrics=[faithfulness, answer_relevancy, context_utilization], + metrics=[faithfulness, answer_relevancy], llm=llm, embeddings=EMBEDDING_FUNCTION, ) - + score_dict = ( - score.to_pandas()[["faithfulness", "answer_relevancy", "context_utilization"]] + score.to_pandas()[["faithfulness", "answer_relevancy"]] + .fillna(0) .round(4) - .to_dict(orient="records")[0] + .to_dict(orient="list") ) end_time = time.time() logging.info(f"Evaluation completed in: {end_time - start_time:.2f} seconds") @@ -90,7 +46,7 @@ def get_ragas_metrics( except ValueError as e: if "Unsupported model for evaluation" in str(e): logging.error(f"Unsupported model error: {e}") - return {"error": str(e)} # Return the specific error message as a dictionary + return {"error": str(e)} logging.exception(f"ValueError during metrics evaluation: {e}") return {"error": str(e)} except Exception as e: diff --git a/backend/src/shared/constants.py b/backend/src/shared/constants.py index 33b27b20a..cde354f16 100644 --- a/backend/src/shared/constants.py +++ b/backend/src/shared/constants.py @@ -121,7 +121,7 @@ RETURN d AS doc, [chunk IN chunks | - chunk {.*, embedding: null} + chunk {.*, embedding: null, element_id: elementId(chunk)} ] AS chunks, [ node IN nodes | @@ -168,10 +168,10 @@ CHAT_EMBEDDING_FILTER_SCORE_THRESHOLD = 0.10 CHAT_TOKEN_CUT_OFF = { - ("openai-gpt-3.5",'azure_ai_gpt_35',"gemini-1.0-pro","gemini-1.5-pro", "gemini-1.5-flash","groq-llama3",'groq_llama3_70b','anthropic_claude_3_5_sonnet','fireworks_llama_v3_70b','bedrock_claude_3_5_sonnet', ) : 4, - ("openai-gpt-4","diffbot" ,'azure_ai_gpt_4o',"openai-gpt-4o", "openai-gpt-4o-mini") : 28, + ('openai_gpt_3.5','azure_ai_gpt_35',"gemini_1.0_pro","gemini_1.5_pro", "gemini_1.5_flash","groq-llama3",'groq_llama3_70b','anthropic_claude_3_5_sonnet','fireworks_llama_v3_70b','bedrock_claude_3_5_sonnet', ) : 4, + ("openai-gpt-4","diffbot" ,'azure_ai_gpt_4o',"openai_gpt_4o", "openai_gpt_4o_mini") : 28, ("ollama_llama3") : 2 -} +} ### CHAT TEMPLATES CHAT_SYSTEM_TEMPLATE = """ @@ -473,14 +473,16 @@ .*, embedding: null, fileName: d.fileName, - fileSource: d.fileSource + fileSource: d.fileSource, + element_id: elementId(c) } ] AS chunks, [ community IN communities WHERE community IS NOT NULL | community { .*, - embedding: null + embedding: null, + element_id:elementId(community) } ] AS communities, [ @@ -551,7 +553,7 @@ WHERE elementId(community) IN $communityids WITH collect(distinct community) AS communities RETURN [community IN communities | - community {.*, embedding: null, elementid: elementId(community)}] AS communities + community {.*, embedding: null, element_id: elementId(community)}] AS communities """ ## CHAT MODES diff --git a/frontend/src/components/ChatBot/ChatInfoModal.tsx b/frontend/src/components/ChatBot/ChatInfoModal.tsx index 7fb819a32..757503fb1 100644 --- a/frontend/src/components/ChatBot/ChatInfoModal.tsx +++ b/frontend/src/components/ChatBot/ChatInfoModal.tsx @@ -13,7 +13,7 @@ import { import { DocumentDuplicateIconOutline, ClipboardDocumentCheckIconOutline } from '@neo4j-ndl/react/icons'; import '../../styling/info.css'; import Neo4jRetrievalLogo from '../../assets/images/Neo4jRetrievalLogo.png'; -import { Entity, ExtendedNode, UserCredentials, chatInfoMessage } from '../../types'; +import { ExtendedNode, UserCredentials, chatInfoMessage } from '../../types'; import { useContext, useEffect, useMemo, useState } from 'react'; import GraphViewButton from '../Graph/GraphViewButton'; import { chunkEntitiesAPI } from '../../services/ChunkEntitiesInfo'; @@ -23,13 +23,14 @@ import { tokens } from '@neo4j-ndl/base'; import ChunkInfo from './ChunkInfo'; import EntitiesInfo from './EntitiesInfo'; import SourcesInfo from './SourcesInfo'; -import CommunitiesInfo from './Communities'; +import CommunitiesInfo from './CommunitiesInfo'; import { chatModeLables, supportedLLmsForRagas } from '../../utils/Constants'; import { Relationship } from '@neo4j-nvl/base'; import { getChatMetrics } from '../../services/GetRagasMetric'; import MetricsTab from './MetricsTab'; import { Stack } from '@mui/material'; -import { capitalizeWithUnderscore } from '../../utils/Utils'; +import { capitalizeWithUnderscore, getNodes } from '../../utils/Utils'; +import MultiModeMetrics from './MultiModeMetrics'; const ChatInfoModal: React.FC = ({ sources, @@ -46,14 +47,6 @@ const ChatInfoModal: React.FC = ({ metriccontexts, metricquestion, metricmodel, - saveNodes, - saveChunks, - saveChatRelationships, - saveCommunities, - saveInfoEntitites, - saveMetrics, - toggleInfoLoading, - toggleMetricsLoading, nodes, chunks, infoEntities, @@ -62,6 +55,18 @@ const ChatInfoModal: React.FC = ({ relationships, infoLoading, metricsLoading, + activeChatmodes, + metricError, + multiModelMetrics, + saveNodes, + saveChunks, + saveChatRelationships, + saveCommunities, + saveInfoEntitites, + saveMetrics, + toggleInfoLoading, + toggleMetricsLoading, + saveMultimodemetrics, }) => { const { breakpoints } = tokens; const isTablet = useMediaQuery(`(min-width:${breakpoints.xs}) and (max-width: ${breakpoints.lg})`); @@ -73,6 +78,8 @@ const ChatInfoModal: React.FC = ({ const [, copy] = useCopyToClipboard(); const [copiedText, setcopiedText] = useState(false); const [showMetricsTable, setShowMetricsTable] = useState(Boolean(metricDetails)); + const [showMultiModeMetrics, setShowMultiModeMetrics] = useState(Boolean(multiModelMetrics.length)) + const [multiModeError, setMultiModeError] = useState(''); const actions: CypherCodeBlockProps['actions'] = useMemo( () => [ @@ -124,32 +131,11 @@ const ChatInfoModal: React.FC = ({ .filter((rel: any) => nodeIds.has(rel.end_node_element_id) && nodeIds.has(rel.start_node_element_id)); const communitiesData = response?.data?.data?.community_data; const chunksData = response?.data?.data?.chunk_data; - - saveInfoEntitites( - nodesData.map((n: Entity) => { - if (!n.labels.length && mode === chatModeLables.entity_vector) { - return { - ...n, - labels: ['Entity'], - }; - } - return n; - }) - ); - saveNodes( - nodesData.map((n: ExtendedNode) => { - if (!n.labels.length && mode === chatModeLables.entity_vector) { - return { - ...n, - labels: ['Entity'], - }; - } - return n ?? []; - }) - ); + saveInfoEntitites(getNodes(nodesData, mode)); + saveNodes(getNodes(nodesData, mode)); saveChatRelationships(relationshipsData ?? []); saveCommunities( - (communitiesData || []) + (communitiesData ?? []) .map((community: { element_id: string }) => { const communityScore = nodeDetails?.communitydetails?.find( (c: { id: string }) => c.id === community.element_id @@ -161,7 +147,6 @@ const ChatInfoModal: React.FC = ({ }) .sort((a: any, b: any) => b.score - a.score) ); - saveChunks( chunksData .map((chunk: any) => { @@ -182,29 +167,76 @@ const ChatInfoModal: React.FC = ({ } () => { setcopiedText(false); - toggleMetricsLoading(); + if (metricsLoading) { + toggleMetricsLoading(); + } }; - }, [nodeDetails, mode, error]); + }, [nodeDetails, mode, error, metricsLoading]); const onChangeTabs = (tabId: number) => { setActiveTab(tabId); }; const loadMetrics = async () => { - setShowMetricsTable(true); - try { - toggleMetricsLoading(); - const response = await getChatMetrics(metricquestion, metriccontexts, metricanswer, metricmodel); - toggleMetricsLoading(); - if (response.data.status === 'Success') { - saveMetrics({ ...response.data.data, error: '' }); + if (activeChatmodes) { + if (Object.keys(activeChatmodes).length <= 1) { + setShowMetricsTable(true); + const [defaultMode] = Object.keys(activeChatmodes); + try { + toggleMetricsLoading(); + const response = await getChatMetrics(metricquestion, [metriccontexts], [metricanswer], metricmodel, [ + defaultMode, + ]); + toggleMetricsLoading(); + if (response.data.status === 'Success') { + const data = response; + saveMetrics(data.data.data[defaultMode]); + } else { + throw new Error(response.data.error); + } + } catch (error) { + if (error instanceof Error) { + toggleMetricsLoading(); + console.log('Error in getting chat metrics', error); + saveMetrics({ faithfulness: 0, answer_relevancy: 0, error: error.message }); + } + } } else { - throw new Error(response.data.error); - } - } catch (error) { - if (error instanceof Error) { + setShowMultiModeMetrics(true) toggleMetricsLoading(); - console.log('Error in getting chat metrics', error); - saveMetrics({ error: error.message, faithfulness: 0, answer_relevancy: 0, context_utilization: 0 }); + const contextarray = Object.values(activeChatmodes).map((r) => { + return r.metric_contexts; + }); + const answerarray = Object.values(activeChatmodes).map((r) => { + return r.metric_answer; + }); + const modesarray = Object.keys(activeChatmodes).map((mode) => { + return mode; + }); + try { + const responses = await getChatMetrics( + metricquestion, + contextarray as string[], + answerarray as string[], + metricmodel, + modesarray + ); + toggleMetricsLoading(); + if (responses.data.status === 'Success') { + const modewisedata = responses.data.data; + const metricsdata = Object.entries(modewisedata).map(([mode, scores]) => { + return { mode, answer_relevancy: scores.answer_relevancy, faithfulness: scores.faithfulness }; + }); + saveMultimodemetrics(metricsdata); + } else { + throw new Error(responses.data.error); + } + } catch (error) { + toggleMetricsLoading(); + console.log('Error in getting chat metrics', error); + if (error instanceof Error) { + setMultiModeError(error.message); + } + } } } }; @@ -301,14 +333,19 @@ const ChatInfoModal: React.FC = ({ Answer Relevancy: Determines How well the answer addresses the user's question. - - Context Utilization: Determines How effectively the system uses the - retrieved information to answer thequestion. - - {showMetricsTable && } - {!metricDetails && ( + {showMultiModeMetrics && activeChatmodes != null && Object.keys(activeChatmodes).length > 1 && ( + + )} + {showMetricsTable && activeChatmodes != null && Object.keys(activeChatmodes).length <= 1 && ( + + )} + {!metricDetails && activeChatmodes != undefined && Object.keys(activeChatmodes).length <= 1 && ( )} + {!multiModelMetrics.length && activeChatmodes != undefined && Object.keys(activeChatmodes).length > 1 && ( + + )} @@ -348,9 +395,14 @@ const ChatInfoModal: React.FC = ({ <> )} - {activeTab == 4 && nodes?.length && relationships?.length ? ( + {activeTab == 4 && nodes?.length && relationships?.length && mode !== chatModeLables.graph ? ( - + ) : ( <> diff --git a/frontend/src/components/ChatBot/ChatModesSwitch.tsx b/frontend/src/components/ChatBot/ChatModesSwitch.tsx index 5bde10d26..a958b1a06 100644 --- a/frontend/src/components/ChatBot/ChatModesSwitch.tsx +++ b/frontend/src/components/ChatBot/ChatModesSwitch.tsx @@ -24,6 +24,7 @@ export default function ChatModesSwitch({ size='small' clean onClick={() => switchToOtherMode(currentModeIndex - 1)} + aria-label='left' > @@ -39,6 +40,7 @@ export default function ChatModesSwitch({ size='small' clean onClick={() => switchToOtherMode(currentModeIndex + 1)} + aria-label='right' > diff --git a/frontend/src/components/ChatBot/Chatbot.tsx b/frontend/src/components/ChatBot/Chatbot.tsx index 4a5d7c1b8..3c1992645 100644 --- a/frontend/src/components/ChatBot/Chatbot.tsx +++ b/frontend/src/components/ChatBot/Chatbot.tsx @@ -22,9 +22,10 @@ import { ExtendedNode, ExtendedRelationship, Messages, - MetricsState, ResponseMode, UserCredentials, + metricstate, + multimodelmetric, nodeDetailsProps, } from '../../types'; import { useCredentials } from '../../context/UserCredentials'; @@ -74,13 +75,14 @@ const Chatbot: FC = (props) => { const [nodes, setNodes] = useState([]); const [relationships, setRelationships] = useState([]); const [chunks, setChunks] = useState([]); - const [metricDetails, setMetricDetails] = useState(null); + const [metricDetails, setMetricDetails] = useState(null); const [infoEntities, setInfoEntities] = useState([]); const [communities, setCommunities] = useState([]); const [infoLoading, toggleInfoLoading] = useReducer((s) => !s, false); const [metricsLoading, toggleMetricsLoading] = useReducer((s) => !s, false); const downloadLinkRef = useRef(null); const [activeChat, setActiveChat] = useState(null); + const [multiModelMetrics, setMultiModelMetrics] = useState([]); const [_, copy] = useCopyToClipboard(); const { speak, cancel, speaking } = useSpeechSynthesis({ @@ -112,7 +114,10 @@ const Chatbot: FC = (props) => { const saveChunks = (chatChunks: Chunk[]) => { setChunks(chatChunks); }; - const saveMetrics = (metricInfo: MetricsState) => { + const saveMultimodemetrics = (metrics: multimodelmetric[]) => { + setMultiModelMetrics(metrics); + }; + const saveMetrics = (metricInfo: metricstate) => { setMetricDetails(metricInfo); }; const saveCommunities = (chatCommunities: Community[]) => { @@ -393,13 +398,16 @@ const Chatbot: FC = (props) => { setActiveChat(chat); if ( (previousActiveChat != null && chat.id != previousActiveChat?.id) || - (previousActiveChat != null && previousActiveChat.currentMode != chat.currentMode) + (previousActiveChat != null && chat.currentMode != previousActiveChat.currentMode) ) { setNodes([]); setChunks([]); setInfoEntities([]); setMetricDetails(null); } + if (previousActiveChat != null && chat.id != previousActiveChat?.id) { + setMultiModelMetrics([]); + } }, []); const speechHandler = useCallback((chat: Messages) => { @@ -635,7 +643,8 @@ const Chatbot: FC = (props) => { infoEntities={infoEntities} relationships={relationships} chunks={chunks} - metricDetails={metricDetails} + metricDetails={activeChat != undefined && metricDetails != null ? metricDetails : undefined} + metricError={activeChat != undefined && metricDetails != null ? (metricDetails.error as string) : ''} communities={communities} infoLoading={infoLoading} metricsLoading={metricsLoading} @@ -647,6 +656,9 @@ const Chatbot: FC = (props) => { saveNodes={saveNodes} toggleInfoLoading={toggleInfoLoading} toggleMetricsLoading={toggleMetricsLoading} + saveMultimodemetrics={saveMultimodemetrics} + activeChatmodes={activeChat?.modes} + multiModelMetrics={multiModelMetrics} /> diff --git a/frontend/src/components/ChatBot/ChunkInfo.tsx b/frontend/src/components/ChatBot/ChunkInfo.tsx index 05fa229a1..68196fb3a 100644 --- a/frontend/src/components/ChatBot/ChunkInfo.tsx +++ b/frontend/src/components/ChatBot/ChunkInfo.tsx @@ -1,5 +1,5 @@ -import { FC, useContext } from 'react'; -import { ChunkProps } from '../../types'; +import { FC, useContext, useState } from 'react'; +import { ChunkProps, UserCredentials } from '../../types'; import { Box, LoadingSpinner, TextLink, Typography } from '@neo4j-ndl/react'; import { DocumentTextIconOutline, GlobeAltIconOutline } from '@neo4j-ndl/react/icons'; import wikipedialogo from '../../assets/images/wikipedia.svg'; @@ -10,9 +10,31 @@ import ReactMarkdown from 'react-markdown'; import { generateYouTubeLink, getLogo, isAllowedHost } from '../../utils/Utils'; import { ThemeWrapperContext } from '../../context/ThemeWrapper'; import { chatModeLables } from '../../utils/Constants'; +import { useCredentials } from '../../context/UserCredentials'; +import GraphViewModal from '../Graph/GraphViewModal'; +import { handleGraphNodeClick } from './chatInfo'; const ChunkInfo: FC = ({ loading, chunks, mode }) => { const themeUtils = useContext(ThemeWrapperContext); + const { userCredentials } = useCredentials(); + const [neoNodes, setNeoNodes] = useState([]); + const [neoRels, setNeoRels] = useState([]); + const [openGraphView, setOpenGraphView] = useState(false); + const [viewPoint, setViewPoint] = useState(''); + const [loadingGraphView, setLoadingGraphView] = useState(false); + + const handleChunkClick = (elementId: string, viewMode: string) => { + handleGraphNodeClick( + userCredentials as UserCredentials, + elementId, + viewMode, + setNeoNodes, + setNeoRels, + setOpenGraphView, + setViewPoint, + setLoadingGraphView + ); + }; return ( <> @@ -22,25 +44,40 @@ const ChunkInfo: FC = ({ loading, chunks, mode }) => { ) : chunks?.length > 0 ? (
-
    +
      {chunks.map((chunk) => (
    • {chunk?.page_number ? ( <>
      - - - {chunk?.fileName} - + <> + + + {chunk?.fileName} + +
      {mode !== chatModeLables.global_vector && mode !== chatModeLables.entity_vector && - mode !== chatModeLables.graph && ( + mode !== chatModeLables.graph && + chunk.score && ( Similarity Score: {chunk?.score} )} +
      + Page: {chunk?.page_number} +
      +
      + handleChunkClick(chunk.element_id, 'Chunk')} + >{'Graph'} + +
      ) : chunk?.url && chunk?.start_time ? ( <> @@ -58,7 +95,18 @@ const ChunkInfo: FC = ({ loading, chunks, mode }) => { {mode !== chatModeLables.global_vector && mode !== chatModeLables.entity_vector && mode !== chatModeLables.graph && ( - Similarity Score: {chunk?.score} + <> + Similarity Score: {chunk?.score} +
      + handleChunkClick(chunk.element_id, 'Chunk')} + >{'Graph'} + +
      + )} ) : chunk?.url && new URL(chunk.url).host === 'wikipedia.org' ? ( @@ -70,7 +118,18 @@ const ChunkInfo: FC = ({ loading, chunks, mode }) => { {mode !== chatModeLables.global_vector && mode !== chatModeLables.entity_vector && mode !== chatModeLables.graph && ( - Similarity Score: {chunk?.score} + <> + Similarity Score: {chunk?.score} +
      + handleChunkClick(chunk.element_id, 'Chunk')} + >{'Graph'} + +
      + )} ) : chunk?.url && new URL(chunk.url).host === 'storage.googleapis.com' ? ( @@ -82,7 +141,18 @@ const ChunkInfo: FC = ({ loading, chunks, mode }) => { {mode !== chatModeLables.global_vector && mode !== chatModeLables.entity_vector && mode !== chatModeLables.graph && ( - Similarity Score: {chunk?.score} + <> + Similarity Score: {chunk?.score} +
      + handleChunkClick(chunk.element_id, 'Chunk')} + >{'Graph'} + +
      + )} ) : chunk?.url && chunk?.url.startsWith('s3://') ? ( @@ -94,7 +164,18 @@ const ChunkInfo: FC = ({ loading, chunks, mode }) => { {mode !== chatModeLables.global_vector && mode !== chatModeLables.entity_vector && mode !== chatModeLables.graph && ( - Similarity Score: {chunk?.score} + <> + Similarity Score: {chunk?.score} +
      + handleChunkClick(chunk.element_id, 'Chunk')} + >{'Graph'} + +
      + )} ) : chunk?.url && @@ -110,7 +191,18 @@ const ChunkInfo: FC = ({ loading, chunks, mode }) => { {mode !== chatModeLables.global_vector && mode !== chatModeLables.entity_vector && mode !== chatModeLables.graph && ( - Similarity Score: {chunk?.score} + <> + Similarity Score: {chunk?.score} +
      + handleChunkClick(chunk.element_id, 'Chunk')} + >{'Graph'} + +
      + )} ) : ( @@ -126,16 +218,29 @@ const ChunkInfo: FC = ({ loading, chunks, mode }) => { className='mr-2' /> )} - - {chunk.fileName} - + <> + + {chunk.fileName} + +
      + handleChunkClick(chunk.element_id, 'Chunk')} + >{'Graph'} + +
      +
)} - {chunk?.text} +
+ {chunk?.text} +
))} @@ -143,8 +248,16 @@ const ChunkInfo: FC = ({ loading, chunks, mode }) => { ) : ( No Chunks Found )} + {openGraphView && ( + + )} ); }; - export default ChunkInfo; diff --git a/frontend/src/components/ChatBot/CommonChatActions.tsx b/frontend/src/components/ChatBot/CommonChatActions.tsx index 80d904f6a..d1a5c6243 100644 --- a/frontend/src/components/ChatBot/CommonChatActions.tsx +++ b/frontend/src/components/ChatBot/CommonChatActions.tsx @@ -30,6 +30,7 @@ export default function CommonActions({ label='Retrieval Information' disabled={chat.isTyping || chat.isLoading} onClick={() => detailsHandler(chat, activeChat)} + aria-label='Retrieval Information' > {buttonCaptions.details} @@ -40,6 +41,7 @@ export default function CommonActions({ text={chat.copying ? tooltips.copied : tooltips.copy} onClick={() => copyHandler(chat.modes[chat.currentMode]?.message, chat.id)} disabled={chat.isTyping || chat.isLoading} + aria-label='copy text' > @@ -50,6 +52,7 @@ export default function CommonActions({ text={chat.speaking ? tooltips.stopSpeaking : tooltips.textTospeech} disabled={listMessages.some((msg) => msg.speaking && msg.id !== chat.id)} label={chat.speaking ? 'stop speaking' : 'text to speech'} + aria-label='speech' > {chat.speaking ? : } diff --git a/frontend/src/components/ChatBot/Communities.tsx b/frontend/src/components/ChatBot/Communities.tsx deleted file mode 100644 index 11869d3d4..000000000 --- a/frontend/src/components/ChatBot/Communities.tsx +++ /dev/null @@ -1,41 +0,0 @@ -import { Box, LoadingSpinner, Flex, Typography } from '@neo4j-ndl/react'; -import { FC } from 'react'; -import ReactMarkdown from 'react-markdown'; -import { CommunitiesProps } from '../../types'; - -const CommunitiesInfo: FC = ({ loading, communities }) => { - console.log('communities', communities); - return ( - <> - {loading ? ( - - - - ) : communities?.length > 0 ? ( -
-
    - {communities.map((community, index) => ( -
  • -
    - - ID : - {community.id} - - - Score : - {community.score} - - {community.summary} -
    -
  • - ))} -
-
- ) : ( - No Communities Found - )} - - ); -}; - -export default CommunitiesInfo; diff --git a/frontend/src/components/ChatBot/CommunitiesInfo.tsx b/frontend/src/components/ChatBot/CommunitiesInfo.tsx index 088e4e370..1a769a1f3 100644 --- a/frontend/src/components/ChatBot/CommunitiesInfo.tsx +++ b/frontend/src/components/ChatBot/CommunitiesInfo.tsx @@ -1,10 +1,33 @@ -import { Box, LoadingSpinner, Flex, Typography } from '@neo4j-ndl/react'; -import { FC } from 'react'; +import { Box, LoadingSpinner, Flex, Typography, TextLink } from '@neo4j-ndl/react'; +import { FC, useState } from 'react'; import ReactMarkdown from 'react-markdown'; -import { CommunitiesProps } from '../../types'; +import { CommunitiesProps, UserCredentials } from '../../types'; import { chatModeLables } from '../../utils/Constants'; +import { useCredentials } from '../../context/UserCredentials'; +import GraphViewModal from '../Graph/GraphViewModal'; +import { handleGraphNodeClick } from './chatInfo'; const CommunitiesInfo: FC = ({ loading, communities, mode }) => { + const { userCredentials } = useCredentials(); + const [neoNodes, setNeoNodes] = useState([]); + const [neoRels, setNeoRels] = useState([]); + const [openGraphView, setOpenGraphView] = useState(false); + const [viewPoint, setViewPoint] = useState(''); + const [loadingGraphView, setLoadingGraphView] = useState(false); + + const handleCommunityClick = (elementId: string, viewMode: string) => { + handleGraphNodeClick( + userCredentials as UserCredentials, + elementId, + viewMode, + setNeoNodes, + setNeoRels, + setOpenGraphView, + setViewPoint, + setLoadingGraphView + ); + }; + return ( <> {loading ? ( @@ -13,13 +36,16 @@ const CommunitiesInfo: FC = ({ loading, communities, mode }) = ) : communities?.length > 0 ? (
-
    +
      {communities.map((community, index) => (
    • - ID : - {community.id} + handleCommunityClick(community.element_id, 'chatInfoView')} + >{`ID : ${community.id}`} {mode === chatModeLables.global_vector && community.score && ( @@ -36,6 +62,15 @@ const CommunitiesInfo: FC = ({ loading, communities, mode }) = ) : ( No Communities Found )} + {openGraphView && ( + + )} ); }; diff --git a/frontend/src/components/ChatBot/EntitiesInfo.tsx b/frontend/src/components/ChatBot/EntitiesInfo.tsx index 6c6e0784e..80e4fdafa 100644 --- a/frontend/src/components/ChatBot/EntitiesInfo.tsx +++ b/frontend/src/components/ChatBot/EntitiesInfo.tsx @@ -1,11 +1,21 @@ -import { Box, GraphLabel, LoadingSpinner, Typography } from '@neo4j-ndl/react'; -import { FC, useMemo } from 'react'; -import { EntitiesProps, GroupedEntity } from '../../types'; +import { Box, GraphLabel, LoadingSpinner, TextLink, Typography } from '@neo4j-ndl/react'; +import { FC, useMemo, useState } from 'react'; +import { EntitiesProps, GroupedEntity, UserCredentials } from '../../types'; import { calcWordColor } from '@neo4j-devtools/word-color'; import { graphLabels } from '../../utils/Constants'; import { parseEntity } from '../../utils/Utils'; +import { useCredentials } from '../../context/UserCredentials'; +import GraphViewModal from '../Graph/GraphViewModal'; +import { handleGraphNodeClick } from './chatInfo'; const EntitiesInfo: FC = ({ loading, mode, graphonly_entities, infoEntities }) => { + const { userCredentials } = useCredentials(); + const [neoNodes, setNeoNodes] = useState([]); + const [neoRels, setNeoRels] = useState([]); + const [openGraphView, setOpenGraphView] = useState(false); + const [viewPoint, setViewPoint] = useState(''); + const [loadingGraphView, setLoadingGraphView] = useState(false); + const groupedEntities = useMemo<{ [key: string]: GroupedEntity }>(() => { const items = infoEntities.reduce((acc, entity) => { const { label, text } = parseEntity(entity); @@ -18,7 +28,6 @@ const EntitiesInfo: FC = ({ loading, mode, graphonly_entities, in }, {} as Record; color: string }>); return items; }, [infoEntities]); - const labelCounts = useMemo(() => { const counts: { [label: string]: number } = {}; for (let index = 0; index < infoEntities?.length; index++) { @@ -33,26 +42,55 @@ const EntitiesInfo: FC = ({ loading, mode, graphonly_entities, in const sortedLabels = useMemo(() => { return Object.keys(labelCounts).sort((a, b) => labelCounts[b] - labelCounts[a]); }, [labelCounts]); + + const handleEntityClick = (elementId: string, viewMode: string) => { + handleGraphNodeClick( + userCredentials as UserCredentials, + elementId, + viewMode, + setNeoNodes, + setNeoRels, + setOpenGraphView, + setViewPoint, + setLoadingGraphView + ); + }; + return ( <> {loading ? ( - ) : Object.keys(groupedEntities)?.length > 0 || Object.keys(graphonly_entities)?.length > 0 ? ( + ) : (mode !== 'graph' && Object.keys(groupedEntities)?.length > 0) || + (mode == 'graph' && Object.keys(graphonly_entities)?.length > 0) ? (
        {mode == 'graph' ? graphonly_entities.map((label, index) => (
      • -
        - { - // @ts-ignore - label[Object.keys(label)[0]].id ?? Object.keys(label)[0] - } -
        +
          + {Object.keys(label).map((key) => ( +
        • + + {key} + + + { + // @ts-ignore + label[key].id ?? label[key] + } + +
        • + ))} +
      • )) : sortedLabels.map((label, index) => { @@ -62,20 +100,32 @@ const EntitiesInfo: FC = ({ loading, mode, graphonly_entities, in key={index} className='flex items-center mb-2 text-ellipsis whitespace-nowrap max-w-[100%)] overflow-hidden' > - e.preventDefault()} - > + {label === '__Community__' ? graphLabels.community : label} ({labelCounts[label]}) - {Array.from(entity.texts).slice(0, 3).join(', ')} + {Array.from(entity.texts) + .slice(0, 3) + .map((text, idx) => { + const matchingEntity = infoEntities.find( + (e) => e.labels.includes(label) && parseEntity(e).text === text + ); + const textId = matchingEntity?.element_id; + return ( + + handleEntityClick(textId!, 'chatInfoView')} + className={loadingGraphView ? 'cursor-wait' : 'cursor-pointer'} + > + {text} + + {Array.from(entity.texts).length > 1 ? ',' : ''} + + ); + })} ); @@ -84,8 +134,16 @@ const EntitiesInfo: FC = ({ loading, mode, graphonly_entities, in ) : ( No Entities Found )} + {openGraphView && ( + + )} ); }; - export default EntitiesInfo; diff --git a/frontend/src/components/ChatBot/MetricsTab.tsx b/frontend/src/components/ChatBot/MetricsTab.tsx index 17b5e67ae..55d37db4c 100644 --- a/frontend/src/components/ChatBot/MetricsTab.tsx +++ b/frontend/src/components/ChatBot/MetricsTab.tsx @@ -1,5 +1,4 @@ import { Banner, Box, DataGrid, DataGridComponents, Typography } from '@neo4j-ndl/react'; -import { MetricsState } from '../../types'; import { memo, useMemo, useRef } from 'react'; import { useReactTable, @@ -13,9 +12,16 @@ import { capitalize } from '../../utils/Utils'; function MetricsTab({ metricsLoading, metricDetails, + error, }: { metricsLoading: boolean; - metricDetails: MetricsState | null; + metricDetails: + | { + faithfulness: number; + answer_relevancy: number; + } + | undefined; + error: string; }) { const columnHelper = createColumnHelper<{ metric: string; score: number }>(); const tableRef = useRef(null); @@ -53,11 +59,9 @@ function MetricsTab({ const table = useReactTable({ data: metricDetails != null && !metricsLoading - ? Object.entries(metricDetails) - .slice(0, Object.keys(metricDetails).length - 1) - .map(([key, value]) => { - return { metric: key, score: value }; - }) + ? Object.entries(metricDetails).map(([key, value]) => { + return { metric: key, score: value }; + }) : [], columns, getCoreRowModel: getCoreRowModel(), @@ -72,8 +76,8 @@ function MetricsTab({ }); return ( - {metricDetails != null && metricDetails?.error?.trim() != '' ? ( - {metricDetails?.error} + {error != undefined && error?.trim() != '' ? ( + {error} ) : ( (null); + + const columnHelper = createColumnHelper(); + const columns = useMemo( + () => [ + columnHelper.accessor((row) => row.mode, { + id: 'Mode', + cell: (info) => { + const metric = info.getValue(); + const capitilizedMetric = metric.includes('_') + ? metric + .split('_') + .map((w) => capitalize(w)) + .join(' ') + : capitalize(metric); + return ( +
        + {capitilizedMetric} +
        + ); + }, + header: () => Mode, + footer: (info) => info.column.id, + }), + columnHelper.accessor((row) => row.answer_relevancy, { + id: 'Answer Relevancy', + cell: (info) => { + return {info.getValue().toFixed(2)}; + }, + header: () => Answer Relevancy, + }), + columnHelper.accessor((row) => row.faithfulness, { + id: 'Score', + cell: (info) => { + return {info.getValue().toFixed(2)}; + }, + header: () => Faithfulness, + }), + ], + [] + ); + const table = useReactTable({ + data, + columns, + getCoreRowModel: getCoreRowModel(), + getFilteredRowModel: getFilteredRowModel(), + getPaginationRowModel: getPaginationRowModel(), + enableGlobalFilter: false, + autoResetPageIndex: false, + enableRowSelection: true, + enableMultiRowSelection: true, + enableSorting: true, + getSortedRowModel: getSortedRowModel(), + }); + return ( + + {error?.trim() != '' ? ( + {error} + ) : ( + , + PaginationNumericButton: ({ isSelected, innerProps, ...restProps }) => { + return ( + + ); + }, + }} + /> + )} + + ); +} diff --git a/frontend/src/components/ChatBot/SourcesInfo.tsx b/frontend/src/components/ChatBot/SourcesInfo.tsx index 91fee511f..35d76aefb 100644 --- a/frontend/src/components/ChatBot/SourcesInfo.tsx +++ b/frontend/src/components/ChatBot/SourcesInfo.tsx @@ -1,5 +1,5 @@ import { FC, useContext } from 'react'; -import { SourcesProps } from '../../types'; +import { Chunk, SourcesProps } from '../../types'; import { Box, LoadingSpinner, TextLink, Typography } from '@neo4j-ndl/react'; import { DocumentTextIconOutline, GlobeAltIconOutline } from '@neo4j-ndl/react/icons'; import { getLogo, isAllowedHost, youtubeLinkValidation } from '../../utils/Utils'; @@ -10,17 +10,31 @@ import youtubelogo from '../../assets/images/youtube.svg'; import gcslogo from '../../assets/images/gcs.webp'; import s3logo from '../../assets/images/s3logo.png'; +const filterUniqueChunks = (chunks: Chunk[]) => { + const chunkSource = new Set(); + return chunks.filter((chunk) => { + const sourceCheck = `${chunk.fileName}-${chunk.fileSource}`; + if (chunkSource.has(sourceCheck)) { + return false; + } + chunkSource.add(sourceCheck); + return true; + + }); +}; + const SourcesInfo: FC = ({ loading, mode, chunks, sources }) => { const themeUtils = useContext(ThemeWrapperContext); + const uniqueChunks = chunks ? filterUniqueChunks(chunks) : []; return ( <> {loading ? ( - ) : mode === 'entity search+vector' && chunks?.length ? ( + ) : mode === 'entity search+vector' && uniqueChunks.length ? (
          - {chunks + {uniqueChunks .map((c) => ({ fileName: c.fileName, fileSource: c.fileSource })) .map((s, index) => { return ( @@ -133,5 +147,4 @@ const SourcesInfo: FC = ({ loading, mode, chunks, sources }) => { ); }; - export default SourcesInfo; diff --git a/frontend/src/components/ChatBot/chatInfo.ts b/frontend/src/components/ChatBot/chatInfo.ts new file mode 100644 index 000000000..c7e990ae7 --- /dev/null +++ b/frontend/src/components/ChatBot/chatInfo.ts @@ -0,0 +1,40 @@ +import { getNeighbors } from '../../services/GraphQuery'; +import { NeoNode, NeoRelationship, UserCredentials } from '../../types'; + +export const handleGraphNodeClick = async ( + userCredentials: UserCredentials, + elementId: string, + viewMode: string, + setNeoNodes: React.Dispatch>, + setNeoRels: React.Dispatch>, + setOpenGraphView: React.Dispatch>, + setViewPoint: React.Dispatch>, + setLoadingGraphView?: React.Dispatch> +) => { + if (setLoadingGraphView) { + setLoadingGraphView(true); + } + try { + const result = await getNeighbors(userCredentials, elementId); + if (result && result.data.data.nodes.length > 0) { + let { nodes } = result.data.data; + if (viewMode === 'Chunk') { + nodes = nodes.filter((node: NeoNode) => node.labels.length === 1 && node.properties.id !== null); + } + const nodeIds = new Set(nodes.map((node: NeoNode) => node.element_id)); + const relationships = result.data.data.relationships.filter( + (rel: NeoRelationship) => nodeIds.has(rel.end_node_element_id) && nodeIds.has(rel.start_node_element_id) + ); + setNeoNodes(nodes); + setNeoRels(relationships); + setOpenGraphView(true); + setViewPoint('chatInfoView'); + } + } catch (error: any) { + console.error('Error fetching neighbors:', error); + } finally { + if (setLoadingGraphView) { + setLoadingGraphView(false); + } + } +}; diff --git a/frontend/src/components/Content.tsx b/frontend/src/components/Content.tsx index 3eeace857..eae9caa77 100644 --- a/frontend/src/components/Content.tsx +++ b/frontend/src/components/Content.tsx @@ -308,7 +308,7 @@ const Content: React.FC = ({ userCredentials as UserCredentials, fileItem.fileSource, fileItem.retryOption ?? '', - fileItem.source_url, + fileItem.sourceUrl, localStorage.getItem('accesskey'), localStorage.getItem('secretkey'), fileItem.name ?? '', @@ -316,9 +316,9 @@ const Content: React.FC = ({ fileItem.gcsBucketFolder ?? '', selectedNodes.map((l) => l.value), selectedRels.map((t) => t.value), - fileItem.google_project_id, + fileItem.googleProjectId, fileItem.language, - fileItem.access_token + fileItem.accessToken ); if (apiResponse?.status === 'Failed') { @@ -571,8 +571,8 @@ const Content: React.FC = ({ ...f, status: 'Reprocess', processingProgress: isStartFromBegining ? 0 : f.processingProgress, - NodesCount: isStartFromBegining ? 0 : f.NodesCount, - relationshipCount: isStartFromBegining ? 0 : f.relationshipCount, + NodesCount: isStartFromBegining ? 0 : f.nodesCount, + relationshipCount: isStartFromBegining ? 0 : f.relationshipsCount, } : f; }); diff --git a/frontend/src/components/DataSources/AWS/S3Modal.tsx b/frontend/src/components/DataSources/AWS/S3Modal.tsx index 221fb0c9c..69ac6e7bd 100644 --- a/frontend/src/components/DataSources/AWS/S3Modal.tsx +++ b/frontend/src/components/DataSources/AWS/S3Modal.tsx @@ -30,10 +30,10 @@ const S3Modal: React.FC = ({ hideModal, open }) => { const submitHandler = async (url: string) => { const defaultValues: CustomFileBase = { - processing: 0, + processingTotalTime: 0, status: 'New', - NodesCount: 0, - relationshipCount: 0, + nodesCount: 0, + relationshipsCount: 0, type: 'PDF', model: model, fileSource: 's3 bucket', @@ -81,7 +81,7 @@ const S3Modal: React.FC = ({ hideModal, open }) => { copiedFilesData.unshift({ name: item.fileName, size: item.fileSize, - source_url: item.url, + sourceUrl: item.url, // total_pages: 'N/A', id: uuidv4(), ...defaultValues, @@ -92,9 +92,9 @@ const S3Modal: React.FC = ({ hideModal, open }) => { copiedFilesData.unshift({ ...tempFileData, status: defaultValues.status, - NodesCount: defaultValues.NodesCount, - relationshipCount: defaultValues.relationshipCount, - processing: defaultValues.processing, + nodesCount: defaultValues.nodesCount, + relationshipsCount: defaultValues.relationshipsCount, + processingTotalTime: defaultValues.processingTotalTime, model: defaultValues.model, fileSource: defaultValues.fileSource, processingProgress: defaultValues.processingProgress, diff --git a/frontend/src/components/DataSources/GCS/GCSModal.tsx b/frontend/src/components/DataSources/GCS/GCSModal.tsx index 007ffab2b..039934535 100644 --- a/frontend/src/components/DataSources/GCS/GCSModal.tsx +++ b/frontend/src/components/DataSources/GCS/GCSModal.tsx @@ -23,10 +23,10 @@ const GCSModal: React.FC = ({ hideModal, open, openGCSModal }) => const { setFilesData, model, filesData } = useFileContext(); const defaultValues: CustomFileBase = { - processing: 0, + processingTotalTime: 0, status: 'New', - NodesCount: 0, - relationshipCount: 0, + nodesCount: 0, + relationshipsCount: 0, type: 'TEXT', model: model, fileSource: 'gcs bucket', @@ -101,9 +101,9 @@ const GCSModal: React.FC = ({ hideModal, open, openGCSModal }) => size: item.fileSize ?? 0, gcsBucket: item.gcsBucketName, gcsBucketFolder: item.gcsBucketFolder, - google_project_id: item.gcsProjectId, + googleProjectId: item.gcsProjectId, id: uuidv4(), - access_token: codeResponse.access_token, + accessToken: codeResponse.access_token, ...defaultValues, }); } else { @@ -112,13 +112,13 @@ const GCSModal: React.FC = ({ hideModal, open, openGCSModal }) => copiedFilesData.unshift({ ...tempFileData, status: defaultValues.status, - NodesCount: defaultValues.NodesCount, - relationshipCount: defaultValues.relationshipCount, - processing: defaultValues.processing, + nodesCount: defaultValues.nodesCount, + relationshipsCount: defaultValues.relationshipsCount, + processingTotalTime: defaultValues.processingTotalTime, model: defaultValues.model, fileSource: defaultValues.fileSource, processingProgress: defaultValues.processingProgress, - access_token: codeResponse.access_token, + accessToken: codeResponse.access_token, }); } } diff --git a/frontend/src/components/DataSources/Local/DropZone.tsx b/frontend/src/components/DataSources/Local/DropZone.tsx index 0278e679e..bddbb65f0 100644 --- a/frontend/src/components/DataSources/Local/DropZone.tsx +++ b/frontend/src/components/DataSources/Local/DropZone.tsx @@ -24,13 +24,13 @@ const DropZone: FunctionComponent = () => { setIsLoading(false); if (f.length) { const defaultValues: CustomFileBase = { - processing: 0, + processingTotalTime: 0, status: 'None', - NodesCount: 0, - relationshipCount: 0, + nodesCount: 0, + relationshipsCount: 0, model: model, fileSource: 'local file', - uploadprogess: 0, + uploadProgress: 0, processingProgress: undefined, retryOptionStatus: false, retryOption: '', @@ -46,7 +46,7 @@ const DropZone: FunctionComponent = () => { // @ts-ignore type: `${file.name.substring(file.name.lastIndexOf('.') + 1, file.name.length).toUpperCase()}`, size: file.size, - uploadprogess: file.size && file?.size < chunkSize ? 100 : 0, + uploadProgress: file.size && file?.size < chunkSize ? 100 : 0, id: uuidv4(), ...defaultValues, }); @@ -56,9 +56,9 @@ const DropZone: FunctionComponent = () => { copiedFilesData.unshift({ ...tempFileData, status: defaultValues.status, - NodesCount: defaultValues.NodesCount, - relationshipCount: defaultValues.relationshipCount, - processing: defaultValues.processing, + nodesCount: defaultValues.nodesCount, + relationshipsCount: defaultValues.relationshipsCount, + processingTotalTime: defaultValues.processingTotalTime, model: defaultValues.model, fileSource: defaultValues.fileSource, processingProgress: defaultValues.processingProgress, diff --git a/frontend/src/components/DataSources/Local/DropZoneForSmallLayouts.tsx b/frontend/src/components/DataSources/Local/DropZoneForSmallLayouts.tsx index 17c97d0bc..d7fb1e56d 100644 --- a/frontend/src/components/DataSources/Local/DropZoneForSmallLayouts.tsx +++ b/frontend/src/components/DataSources/Local/DropZoneForSmallLayouts.tsx @@ -170,13 +170,13 @@ export default function DropZoneForSmallLayouts() { setIsLoading(false); if (f.length) { const defaultValues: CustomFileBase = { - processing: 0, + processingTotalTime: 0, status: 'None', - NodesCount: 0, - relationshipCount: 0, + nodesCount: 0, + relationshipsCount: 0, model: model, fileSource: 'local file', - uploadprogess: 0, + uploadProgress: 0, processingProgress: undefined, retryOption: '', retryOptionStatus: false, @@ -192,7 +192,7 @@ export default function DropZoneForSmallLayouts() { // @ts-ignore type: `${file.name.substring(file.name.lastIndexOf('.') + 1, file.name.length).toUpperCase()}`, size: file.size, - uploadprogess: file.size && file?.size < chunkSize ? 100 : 0, + uploadProgress: file.size && file?.size < chunkSize ? 100 : 0, id: uuidv4(), ...defaultValues, }); @@ -202,9 +202,9 @@ export default function DropZoneForSmallLayouts() { copiedFilesData.unshift({ ...tempFileData, status: defaultValues.status, - NodesCount: defaultValues.NodesCount, - relationshipCount: defaultValues.relationshipCount, - processing: defaultValues.processing, + nodesCount: defaultValues.nodesCount, + relationshipsCount: defaultValues.relationshipsCount, + processingTotalTime: defaultValues.processingTotalTime, model: defaultValues.model, fileSource: defaultValues.fileSource, processingProgress: defaultValues.processingProgress, diff --git a/frontend/src/components/Dropdown.tsx b/frontend/src/components/Dropdown.tsx index 48188f38b..087e0d71c 100644 --- a/frontend/src/components/Dropdown.tsx +++ b/frontend/src/components/Dropdown.tsx @@ -39,7 +39,9 @@ const DropdownComponent: React.FC = ({ }; }), placeholder: placeholder || 'Select an option', - defaultValue: defaultValue ? { label: capitalize(defaultValue), value: defaultValue } : undefined, + defaultValue: defaultValue + ? { label: capitalizeWithUnderscore(defaultValue), value: defaultValue } + : undefined, menuPlacement: 'auto', isDisabled: isDisabled, value: value, diff --git a/frontend/src/components/FileTable.tsx b/frontend/src/components/FileTable.tsx index e5aa5ada4..2032afd28 100644 --- a/frontend/src/components/FileTable.tsx +++ b/frontend/src/components/FileTable.tsx @@ -7,7 +7,6 @@ import { ProgressBar, StatusIndicator, TextLink, - Tip, Typography, useCopyToClipboard, } from '@neo4j-ndl/react'; @@ -71,6 +70,7 @@ const FileTable = forwardRef((props, ref) => { const skipPageResetRef = useRef(false); const [_, copy] = useCopyToClipboard(); const { colorMode } = useContext(ThemeWrapperContext); + const [copyRow, setCopyRow] = useState(false); const tableRef = useRef(null); @@ -83,8 +83,13 @@ const FileTable = forwardRef((props, ref) => { } ); - const handleCopy = (message: string) => { - copy(message); + const handleCopy = (rowData: any) => { + const rowString = JSON.stringify(rowData, null, 2); + copy(rowString); + setCopyRow(true); + setTimeout(() => { + setCopyRow(false); + }, 5000); }; const columns = useMemo( () => [ @@ -136,8 +141,8 @@ const FileTable = forwardRef((props, ref) => {
          @@ -154,18 +159,14 @@ const FileTable = forwardRef((props, ref) => { cell: (info) => { if (info.getValue() != 'Processing') { return ( - -
          - - - {info.getValue()} - - {(info.getValue() === 'Completed' || - info.getValue() === 'Failed' || - (info.getValue() === 'Cancelled' && !isReadOnlyUser)) && ( +
          + + {info.getValue()} + {(info.getValue() === 'Completed' || info.getValue() === 'Failed' || info.getValue() === 'Cancelled') && + !isReadOnlyUser && ( ((props, ref) => { )} -
          - - {info.row.original?.status === 'Failed' && ( - - handleCopy(info.row.original?.errorMessage ?? '')} - > - - - - )} - +
          ); } else if (info.getValue() === 'Processing' && info.row.original.processingProgress === undefined) { return ( @@ -324,7 +307,7 @@ const FileTable = forwardRef((props, ref) => { }, }, }), - columnHelper.accessor((row) => row.uploadprogess, { + columnHelper.accessor((row) => row.uploadProgress, { id: 'uploadprogess', cell: (info: CellContext) => { if (parseInt(info.getValue()) === 100 || info.row.original?.status === 'New') { @@ -371,7 +354,7 @@ const FileTable = forwardRef((props, ref) => { return ( - + {info.row.original.fileSource} @@ -504,13 +487,13 @@ const FileTable = forwardRef((props, ref) => { }, }, }), - columnHelper.accessor((row) => row.NodesCount, { + columnHelper.accessor((row) => row.nodesCount, { id: 'NodesCount', cell: (info) => {info.getValue()}, header: () => Nodes, footer: (info) => info.column.id, }), - columnHelper.accessor((row) => row.relationshipCount, { + columnHelper.accessor((row) => row.relationshipsCount, { id: 'relationshipCount', cell: (info) => {info.getValue()}, header: () => Relations, @@ -531,9 +514,24 @@ const FileTable = forwardRef((props, ref) => { > + { + const copied={...info.row.original}; + delete copied.accessToken; + handleCopy(copied); + }} + > + + ), - header: () => View, + header: () => Actions, footer: (info) => info.column.id, }), ], @@ -667,19 +665,19 @@ const FileTable = forwardRef((props, ref) => { type: item?.fileType?.includes('.') ? item?.fileType?.substring(1)?.toUpperCase() ?? 'None' : item?.fileType?.toUpperCase() ?? 'None', - NodesCount: item?.nodeCount ?? 0, - processing: item?.processingTime ?? 'None', - relationshipCount: item?.relationshipCount ?? 0, + nodesCount: item?.nodeCount ?? 0, + processingTotalTime: item?.processingTime ?? 'None', + relationshipsCount: item?.relationshipCount ?? 0, status: waitingFile ? 'Waiting' : getFileSourceStatus(item), model: item?.model ?? model, id: !waitingFile ? uuidv4() : waitingFile.id, - source_url: item?.url != 'None' && item?.url != '' ? item.url : '', + sourceUrl: item?.url != 'None' && item?.url != '' ? item.url : '', fileSource: item?.fileSource ?? 'None', gcsBucket: item?.gcsBucket, gcsBucketFolder: item?.gcsBucketFolder, errorMessage: item?.errorMessage, - uploadprogess: item?.uploadprogress ?? 0, - google_project_id: item?.gcsProjectId, + uploadProgress: item?.uploadprogress ?? 0, + googleProjectId: item?.gcsProjectId, language: item?.language ?? '', processingProgress: item?.processed_chunk != undefined && @@ -687,7 +685,7 @@ const FileTable = forwardRef((props, ref) => { !isNaN(Math.floor((item?.processed_chunk / item?.total_chunks) * 100)) ? Math.floor((item?.processed_chunk / item?.total_chunks) * 100) : undefined, - access_token: item?.access_token ?? '', + accessToken: item?.accessToken ?? '', retryOption: item.retry_condition ?? '', retryOptionStatus: false, }); diff --git a/frontend/src/components/Graph/GraphPropertiesTable.tsx b/frontend/src/components/Graph/GraphPropertiesTable.tsx index 7600ccd35..fa270455b 100644 --- a/frontend/src/components/Graph/GraphPropertiesTable.tsx +++ b/frontend/src/components/Graph/GraphPropertiesTable.tsx @@ -2,7 +2,6 @@ import { GraphLabel, Typography } from '@neo4j-ndl/react'; import { GraphPropertiesTableProps } from '../../types'; const GraphPropertiesTable = ({ propertiesWithTypes }: GraphPropertiesTableProps): JSX.Element => { - console.log('props', propertiesWithTypes); return (
          diff --git a/frontend/src/components/Graph/GraphViewButton.tsx b/frontend/src/components/Graph/GraphViewButton.tsx index 50eb13972..8f051ce0b 100644 --- a/frontend/src/components/Graph/GraphViewButton.tsx +++ b/frontend/src/components/Graph/GraphViewButton.tsx @@ -3,7 +3,7 @@ import { Button } from '@neo4j-ndl/react'; import GraphViewModal from './GraphViewModal'; import { GraphViewButtonProps } from '../../types'; -const GraphViewButton: React.FC = ({ nodeValues, relationshipValues }) => { +const GraphViewButton: React.FC = ({ nodeValues, relationshipValues, fill, label }) => { const [openGraphView, setOpenGraphView] = useState(false); const [viewPoint, setViewPoint] = useState(''); @@ -13,7 +13,9 @@ const GraphViewButton: React.FC = ({ nodeValues, relations }; return ( <> - + = ({ if (open) { setLoading(true); setGraphType([]); - if (viewPoint !== 'chatInfoView') { + if (viewPoint !== graphLabels.chatInfoView) { graphApi(); } else { const { finalNodes, finalRels, schemeVal } = processGraphData(nodeValues ?? [], relationshipValues ?? []); @@ -176,7 +176,9 @@ const GraphViewModal: React.FunctionComponent = ({ }, [open]); useEffect(() => { - handleSearch(debouncedQuery); + if (debouncedQuery) { + handleSearch(debouncedQuery); + } }, [debouncedQuery]); const initGraph = ( @@ -244,7 +246,7 @@ const GraphViewModal: React.FunctionComponent = ({ setNodes(updatedNodes); setRelationships(updatedRelationships); }, - [nodes] + [nodes, relationships] ); // Unmounting the component @@ -377,7 +379,7 @@ const GraphViewModal: React.FunctionComponent = ({
          - ) : graphType.length === 0 ? ( + ) : graphType.length === 0 && checkBoxView ? (
          diff --git a/frontend/src/components/Popups/GraphEnhancementDialog/Deduplication/index.tsx b/frontend/src/components/Popups/GraphEnhancementDialog/Deduplication/index.tsx index f5a021e30..420fbe590 100644 --- a/frontend/src/components/Popups/GraphEnhancementDialog/Deduplication/index.tsx +++ b/frontend/src/components/Popups/GraphEnhancementDialog/Deduplication/index.tsx @@ -12,13 +12,24 @@ import { Row, getSortedRowModel, } from '@tanstack/react-table'; -import { Checkbox, DataGrid, DataGridComponents, Flex, Tag, Typography, useMediaQuery } from '@neo4j-ndl/react'; +import { + Checkbox, + DataGrid, + DataGridComponents, + Flex, + Tag, + TextLink, + Typography, + useMediaQuery, +} from '@neo4j-ndl/react'; import Legend from '../../../UI/Legend'; import { DocumentIconOutline } from '@neo4j-ndl/react/icons'; import { calcWordColor } from '@neo4j-devtools/word-color'; import ButtonWithToolTip from '../../../UI/ButtonWithToolTip'; import mergeDuplicateNodes from '../../../../services/MergeDuplicateEntities'; import { tokens } from '@neo4j-ndl/base'; +import GraphViewModal from '../../../Graph/GraphViewModal'; +import { handleGraphNodeClick } from '../../../ChatBot/chatInfo'; export default function DeduplicationTab() { const { breakpoints } = tokens; @@ -30,6 +41,11 @@ export default function DeduplicationTab() { const [isLoading, setLoading] = useState(false); const [mergeAPIloading, setmergeAPIloading] = useState(false); const tableRef = useRef(null); + const [neoNodes, setNeoNodes] = useState([]); + const [neoRels, setNeoRels] = useState([]); + const [openGraphView, setOpenGraphView] = useState(false); + const [viewPoint, setViewPoint] = useState(''); + const fetchDuplicateNodes = useCallback(async () => { try { setLoading(true); @@ -89,6 +105,19 @@ export default function DeduplicationTab() { ); }); }; + + const handleDuplicateNodeClick = (elementId: string, viewMode: string) => { + handleGraphNodeClick( + userCredentials as UserCredentials, + elementId, + viewMode, + setNeoNodes, + setNeoRels, + setOpenGraphView, + setViewPoint + ); + }; + const columns = useMemo( () => [ { @@ -121,7 +150,13 @@ export default function DeduplicationTab() { cell: (info) => { return (
          - {info.getValue()} + handleDuplicateNodeClick(info.row.id, 'chatInfoView')} + title={info.getValue()} + > + {info.getValue()} +
          ); }, @@ -225,83 +260,94 @@ export default function DeduplicationTab() { ? `Merge Duplicate Nodes (${table.getSelectedRowModel().rows.length})` : 'Select Node(s) to Merge'; return ( -
          - - - - Refine Your Knowledge Graph: Merge Duplicate Entities: - - - Identify and merge similar entries like "Apple" and "Apple Inc." to eliminate redundancy and improve the - accuracy and clarity of your knowledge graph. - + <> +
          + + + + Refine Your Knowledge Graph: Merge Duplicate Entities: + + + Identify and merge similar entries like "Apple" and "Apple Inc." to eliminate redundancy and improve the + accuracy and clarity of your knowledge graph. + + + {duplicateNodes.length > 0 && ( + + Total Duplicate Nodes: {duplicateNodes.length} + + )} - {duplicateNodes.length > 0 && ( - - Total Duplicate Nodes: {duplicateNodes.length} - - )} - - , - PaginationNumericButton: ({ isSelected, innerProps, ...restProps }) => { - return ( - - ); - }, - }} - /> - - { - await clickHandler(); - await fetchDuplicateNodes(); + - {selectedFilesCheck} - - -
          + rootProps={{ + className: 'max-h-[355px] !overflow-y-auto', + }} + isLoading={isLoading} + components={{ + Body: (props) => , + PaginationNumericButton: ({ isSelected, innerProps, ...restProps }) => { + return ( + + ); + }, + }} + /> + + { + await clickHandler(); + await fetchDuplicateNodes(); + }} + size='large' + loading={mergeAPIloading} + text={ + isLoading + ? 'Fetching Duplicate Nodes' + : !isLoading && !duplicateNodes.length + ? 'No Nodes Found' + : !table.getSelectedRowModel().rows.length + ? 'No Nodes Selected' + : mergeAPIloading + ? 'Merging' + : `Merge Selected Nodes (${table.getSelectedRowModel().rows.length})` + } + label='Merge Duplicate Node Button' + disabled={!table.getSelectedRowModel().rows.length} + placement='top' + > + {selectedFilesCheck} + + +
          + {openGraphView && ( + + )} + ); } diff --git a/frontend/src/components/Popups/GraphEnhancementDialog/DeleteTabForOrphanNodes/index.tsx b/frontend/src/components/Popups/GraphEnhancementDialog/DeleteTabForOrphanNodes/index.tsx index 24de576c5..bcc2597f1 100644 --- a/frontend/src/components/Popups/GraphEnhancementDialog/DeleteTabForOrphanNodes/index.tsx +++ b/frontend/src/components/Popups/GraphEnhancementDialog/DeleteTabForOrphanNodes/index.tsx @@ -1,4 +1,4 @@ -import { Checkbox, DataGrid, DataGridComponents, Flex, Typography, useMediaQuery } from '@neo4j-ndl/react'; +import { Checkbox, DataGrid, DataGridComponents, Flex, TextLink, Typography, useMediaQuery } from '@neo4j-ndl/react'; import { useCallback, useEffect, useMemo, useRef, useState } from 'react'; import { UserCredentials, orphanNodeProps } from '../../../../types'; import { getOrphanNodes } from '../../../../services/GetOrphanNodes'; @@ -19,6 +19,8 @@ import { } from '@tanstack/react-table'; import DeletePopUp from '../../DeletePopUp/DeletePopUp'; import { tokens } from '@neo4j-ndl/base'; +import GraphViewModal from '../../../Graph/GraphViewModal'; +import { handleGraphNodeClick } from '../../../ChatBot/chatInfo'; export default function DeletePopUpForOrphanNodes({ deleteHandler, loading, @@ -35,6 +37,10 @@ export default function DeletePopUpForOrphanNodes({ const [rowSelection, setRowSelection] = useState>({}); const tableRef = useRef(null); const [showDeletePopUp, setshowDeletePopUp] = useState(false); + const [neoNodes, setNeoNodes] = useState([]); + const [neoRels, setNeoRels] = useState([]); + const [openGraphView, setOpenGraphView] = useState(false); + const [viewPoint, setViewPoint] = useState(''); const fetchOrphanNodes = useCallback(async () => { try { @@ -66,6 +72,18 @@ export default function DeletePopUpForOrphanNodes({ }, [userCredentials]); const columnHelper = createColumnHelper(); + const handleOrphanNodeClick = (elementId: string, viewMode: string) => { + handleGraphNodeClick( + userCredentials as UserCredentials, + elementId, + viewMode, + setNeoNodes, + setNeoRels, + setOpenGraphView, + setViewPoint + ); + }; + const columns = useMemo( () => [ { @@ -98,7 +116,13 @@ export default function DeletePopUpForOrphanNodes({ cell: (info) => { return (
          - {info.getValue()} + handleOrphanNodeClick(info.row.id, 'chatInfoView')} + title={info.getValue()} + > + {info.getValue()} +
          ); }, @@ -190,94 +214,105 @@ export default function DeletePopUpForOrphanNodes({ }; return ( -
          - {showDeletePopUp && ( - setshowDeletePopUp(false)} - loading={loading} - view='settingsView' - /> - )} + <>
          - - - - Orphan Nodes Deletion (100 nodes per batch) - - {totalOrphanNodes > 0 && ( + {showDeletePopUp && ( + setshowDeletePopUp(false)} + loading={loading} + view='settingsView' + /> + )} +
          + + - Total Nodes: {totalOrphanNodes} + Orphan Nodes Deletion (100 nodes per batch) - )} - - - - This feature helps improve the accuracy of your knowledge graph by identifying and removing entities that - are not connected to any other information. These "lonely" entities can be remnants of past analyses or - errors in data processing. By removing them, we can create a cleaner and more efficient knowledge graph - that leads to more relevant and informative responses. - + {totalOrphanNodes > 0 && ( + + Total Nodes: {totalOrphanNodes} + + )} + + + + This feature helps improve the accuracy of your knowledge graph by identifying and removing entities + that are not connected to any other information. These "lonely" entities can be remnants of past + analyses or errors in data processing. By removing them, we can create a cleaner and more efficient + knowledge graph that leads to more relevant and informative responses. + + +
          + , + PaginationNumericButton: ({ isSelected, innerProps, ...restProps }) => { + return ( + + ); + }, + }} + /> + + setshowDeletePopUp(true)} + size='large' + loading={loading} + text={ + isLoading + ? 'Fetching Orphan Nodes' + : !isLoading && !orphanNodes.length + ? 'No Nodes Found' + : !table.getSelectedRowModel().rows.length + ? 'No Nodes Selected' + : `Delete Selected Nodes (${table.getSelectedRowModel().rows.length})` + } + label='Orphan Node deletion button' + disabled={!table.getSelectedRowModel().rows.length} + placement='top' + > + {selectedFilesCheck} +
          - , - PaginationNumericButton: ({ isSelected, innerProps, ...restProps }) => { - return ( - - ); - }, - }} - /> - - setshowDeletePopUp(true)} - size='large' - loading={loading} - text={ - isLoading - ? 'Fetching Orphan Nodes' - : !isLoading && !orphanNodes.length - ? 'No Nodes Found' - : !table.getSelectedRowModel().rows.length - ? 'No Nodes Selected' - : `Delete Selected Nodes (${table.getSelectedRowModel().rows.length})` - } - label='Orphan Node deletion button' - disabled={!table.getSelectedRowModel().rows.length} - placement='top' - > - {selectedFilesCheck} - - -
          + {openGraphView && ( + + )} + ); } diff --git a/frontend/src/components/Popups/GraphEnhancementDialog/EnitityExtraction/EntityExtractionSetting.tsx b/frontend/src/components/Popups/GraphEnhancementDialog/EnitityExtraction/EntityExtractionSetting.tsx index 8a7362d26..67dd10f9e 100644 --- a/frontend/src/components/Popups/GraphEnhancementDialog/EnitityExtraction/EntityExtractionSetting.tsx +++ b/frontend/src/components/Popups/GraphEnhancementDialog/EnitityExtraction/EntityExtractionSetting.tsx @@ -18,7 +18,7 @@ export default function EntityExtractionSetting({ openTextSchema, settingView, onContinue, - colseEnhanceGraphSchemaDialog, + closeEnhanceGraphSchemaDialog, }: { view: 'Dialog' | 'Tabs'; open?: boolean; @@ -26,7 +26,7 @@ export default function EntityExtractionSetting({ openTextSchema: () => void; settingView: 'contentView' | 'headerView'; onContinue?: () => void; - colseEnhanceGraphSchemaDialog?: () => void; + closeEnhanceGraphSchemaDialog?: () => void; }) { const { breakpoints } = tokens; const { @@ -240,10 +240,30 @@ export default function EntityExtractionSetting({ ); localStorage.setItem('selectedSchemas', JSON.stringify({ db: userCredentials?.uri, selectedOptions: [] })); showNormalToast(`Successfully Removed the Schema settings`); - if (view === 'Dialog' && onClose != undefined) { - onClose(); + if (view === 'Tabs' && closeEnhanceGraphSchemaDialog != undefined) { + closeEnhanceGraphSchemaDialog(); } }; + const handleApply = () => { + setIsSchema(true); + localStorage.setItem('isSchema', JSON.stringify(true)); + showNormalToast(`Successfully Applied the Schema settings`); + if (view === 'Tabs' && closeEnhanceGraphSchemaDialog != undefined) { + closeEnhanceGraphSchemaDialog(); + } + localStorage.setItem( + 'selectedNodeLabels', + JSON.stringify({ db: userCredentials?.uri, selectedOptions: selectedNodes }) + ); + localStorage.setItem( + 'selectedRelationshipLabels', + JSON.stringify({ db: userCredentials?.uri, selectedOptions: selectedRels }) + ); + localStorage.setItem( + 'selectedSchemas', + JSON.stringify({ db: userCredentials?.uri, selectedOptions: selectedSchemas }) + ); + }; // Load selectedSchemas from local storage on mount useEffect(() => { @@ -297,9 +317,8 @@ export default function EntityExtractionSetting({ options: nodeLabelOptions, onChange: onChangenodes, value: selectedNodes, - classNamePrefix: `${ - isTablet ? 'tablet_entity_extraction_Tab_node_label' : 'entity_extraction_Tab_node_label' - }`, + classNamePrefix: `${isTablet ? 'tablet_entity_extraction_Tab_node_label' : 'entity_extraction_Tab_node_label' + }`, }} type='creatable' /> @@ -313,9 +332,8 @@ export default function EntityExtractionSetting({ options: relationshipTypeOptions, onChange: onChangerels, value: selectedRels, - classNamePrefix: `${ - isTablet ? 'tablet_entity_extraction_Tab_relationship_label' : 'entity_extraction_Tab_relationship_label' - }`, + classNamePrefix: `${isTablet ? 'tablet_entity_extraction_Tab_relationship_label' : 'entity_extraction_Tab_relationship_label' + }`, }} type='creatable' /> @@ -343,8 +361,8 @@ export default function EntityExtractionSetting({ if (view === 'Dialog' && onClose != undefined) { onClose(); } - if (view === 'Tabs' && colseEnhanceGraphSchemaDialog != undefined) { - colseEnhanceGraphSchemaDialog(); + if (view === 'Tabs' && closeEnhanceGraphSchemaDialog != undefined) { + closeEnhanceGraphSchemaDialog(); } openTextSchema(); }} @@ -372,6 +390,15 @@ export default function EntityExtractionSetting({ {buttonCaptions.clearSettings} )} + + {buttonCaptions.applyGraphSchema} +
          diff --git a/frontend/src/components/Popups/GraphEnhancementDialog/index.tsx b/frontend/src/components/Popups/GraphEnhancementDialog/index.tsx index c4d9ee2f7..c7b621748 100644 --- a/frontend/src/components/Popups/GraphEnhancementDialog/index.tsx +++ b/frontend/src/components/Popups/GraphEnhancementDialog/index.tsx @@ -102,7 +102,7 @@ export default function GraphEnhancementDialog({ openTextSchema={() => { setShowTextFromSchemaDialog({ triggeredFrom: 'enhancementtab', show: true }); }} - colseEnhanceGraphSchemaDialog={onClose} + closeEnhanceGraphSchemaDialog={onClose} settingView='headerView' />
          diff --git a/frontend/src/hooks/useSourceInput.tsx b/frontend/src/hooks/useSourceInput.tsx index eae33e3d4..7a8b8e7b1 100644 --- a/frontend/src/hooks/useSourceInput.tsx +++ b/frontend/src/hooks/useSourceInput.tsx @@ -47,10 +47,10 @@ export default function useSourceInput( const submitHandler = useCallback( async (url: string) => { const defaultValues: CustomFileBase = { - processing: 0, + processingTotalTime: 0, status: 'New', - NodesCount: 0, - relationshipCount: 0, + nodesCount: 0, + relationshipsCount: 0, type: 'TEXT', model: model, fileSource: fileSource, @@ -122,7 +122,7 @@ export default function useSourceInput( ...defaultValues, }; if (isWikiQuery) { - baseValues.wiki_query = item.fileName; + baseValues.wikiQuery = item.fileName; } copiedFilesData.unshift(baseValues); } else { @@ -131,9 +131,9 @@ export default function useSourceInput( copiedFilesData.unshift({ ...tempFileData, status: defaultValues.status, - NodesCount: defaultValues.NodesCount, - relationshipCount: defaultValues.relationshipCount, - processing: defaultValues.processing, + nodesCount: defaultValues.nodesCount, + relationshipsCount: defaultValues.relationshipsCount, + processingTotalTime: defaultValues.processingTotalTime, model: defaultValues.model, fileSource: defaultValues.fileSource, processingProgress: defaultValues.processingProgress, diff --git a/frontend/src/services/GetRagasMetric.ts b/frontend/src/services/GetRagasMetric.ts index a72c69e5d..90e365d00 100644 --- a/frontend/src/services/GetRagasMetric.ts +++ b/frontend/src/services/GetRagasMetric.ts @@ -1,12 +1,19 @@ import { MetricsResponse } from '../types'; import api from '../API/Index'; -export const getChatMetrics = async (question: string, context: string, answer: string, model: string) => { +export const getChatMetrics = async ( + question: string, + context: string[], + answer: string[], + model: string, + mode: string[] +) => { const formData = new FormData(); formData.append('question', question); - formData.append('context', `[${context}]`); - formData.append('answer', answer); + formData.append('context', JSON.stringify(context)); + formData.append('answer', JSON.stringify(answer)); formData.append('model', model); + formData.append('mode', JSON.stringify(mode)); try { const response = await api.post(`/metric`, formData); return response; diff --git a/frontend/src/services/GraphQuery.ts b/frontend/src/services/GraphQuery.ts index f792f4aec..d9277a504 100644 --- a/frontend/src/services/GraphQuery.ts +++ b/frontend/src/services/GraphQuery.ts @@ -1,7 +1,7 @@ import { UserCredentials } from '../types'; import api from '../API/Index'; -const graphQueryAPI = async ( +export const graphQueryAPI = async ( userCredentials: UserCredentials, query_type: string, document_names: (string | undefined)[] | undefined @@ -26,4 +26,24 @@ const graphQueryAPI = async ( throw error; } }; -export default graphQueryAPI; + +export const getNeighbors = async (userCredentials: UserCredentials, elementId: string) => { + try { + const formData = new FormData(); + formData.append('uri', userCredentials?.uri ?? ''); + formData.append('database', userCredentials?.database ?? ''); + formData.append('userName', userCredentials?.userName ?? ''); + formData.append('password', userCredentials?.password ?? ''); + formData.append('elementId', elementId); + + const response = await api.post(`/get_neighbours`, formData, { + headers: { + 'Content-Type': 'multipart/form-data', + }, + }); + return response; + } catch (error) { + console.log('Error Posting the Question:', error); + throw error; + } +}; diff --git a/frontend/src/types.ts b/frontend/src/types.ts index 0f8f913d3..02dbc0a2b 100644 --- a/frontend/src/types.ts +++ b/frontend/src/types.ts @@ -8,24 +8,24 @@ import { BannerType } from '@neo4j-ndl/react'; import Queue from './utils/Queue'; export interface CustomFileBase extends Partial { - processing: number | string; + processingTotalTime: number | string; status: string; - NodesCount: number; - relationshipCount: number; + nodesCount: number; + relationshipsCount: number; model: string; fileSource: string; - source_url?: string; - wiki_query?: string; + sourceUrl?: string; + wikiQuery?: string; gcsBucket?: string; gcsBucketFolder?: string; errorMessage?: string; - uploadprogess?: number; + uploadProgress?: number; processingStatus?: boolean; - google_project_id?: string; + googleProjectId?: string; language?: string; processingProgress?: number; - access_token?: string; - checked?: boolean; + accessToken?: string; + isChecked?: boolean; retryOptionStatus: boolean; retryOption: string; } @@ -45,12 +45,12 @@ export type UserCredentials = { database: string; } & { [key: string]: any }; -export interface SourceNode extends Omit { +export interface SourceNode extends Omit { fileName: string; fileSize: number; fileType: string; nodeCount?: number; - processingTime?: string; + processingTime: string; relationshipCount?: number; url?: string; awsAccessKeyId?: string; @@ -61,7 +61,7 @@ export interface SourceNode extends Omit { retry_condition?: string; } -export type ExtractParams = Pick & { +export type ExtractParams = Pick & { file?: File; aws_access_key_id?: string | null; aws_secret_access_key?: string | null; @@ -410,11 +410,18 @@ export interface duplicateNodesData extends Partial { export interface OrphanNodeResponse extends Partial { data: orphanNodeProps[]; } -export type metricdetails = { +export type metricstate = { faithfulness: number; answer_relevancy: number; - context_utilization: number; + error?: string; }; +export type metricdetails = Record; + +export interface multimodelmetric { + mode: string; + answer_relevancy: number; + faithfulness: number; +} export interface MetricsResponse extends Omit { data: metricdetails; } @@ -430,10 +437,6 @@ export interface SourceListServerData { message?: string; } -export interface MetricsState extends metricdetails { - error?: string; -} - export interface chatInfoMessage extends Partial { sources: string[]; model: string; @@ -452,19 +455,32 @@ export interface chatInfoMessage extends Partial { nodes: ExtendedNode[]; relationships: ExtendedRelationship[]; chunks: Chunk[]; - metricDetails: MetricsState | null; + metricDetails: + | { + faithfulness: number; + answer_relevancy: number; + } + | undefined; + metricError: string; infoEntities: Entity[]; communities: Community[]; infoLoading: boolean; metricsLoading: boolean; + activeChatmodes: + | { + [key: string]: ResponseMode; + } + | undefined; + multiModelMetrics: multimodelmetric[]; saveInfoEntitites: (entities: Entity[]) => void; saveNodes: (chatNodes: ExtendedNode[]) => void; saveChatRelationships: (chatRels: ExtendedRelationship[]) => void; saveChunks: (chatChunks: Chunk[]) => void; - saveMetrics: (metricInfo: MetricsState) => void; + saveMetrics: (metricInfo: metricstate) => void; saveCommunities: (chatCommunities: Community[]) => void; toggleInfoLoading: React.DispatchWithoutAction; toggleMetricsLoading: React.DispatchWithoutAction; + saveMultimodemetrics: (metrics: multimodelmetric[]) => void; } export interface eventResponsetypes extends Omit { @@ -517,6 +533,7 @@ export type Community = { level: number; community_rank: number; score?: number; + element_id: string; }; export type GroupedEntity = { texts: Set; @@ -556,6 +573,7 @@ export interface Chunk { fileSource: string; score?: string; fileType: string; + element_id: string; } export interface SpeechSynthesisProps { @@ -610,6 +628,18 @@ export interface ExtendedNode extends Node { }; } +export interface NeoNode { + element_id: string; + labels: string[]; + properties: Record; +} +export interface NeoRelationship { + element_id: string; + start_node_element_id: string; + end_node_element_id: string; + type: string; +} + export interface ExtendedRelationship extends Relationship { count?: number; } @@ -657,6 +687,9 @@ export interface S3File { export interface GraphViewButtonProps { nodeValues?: ExtendedNode[]; relationshipValues?: ExtendedRelationship[]; + fill?: 'text' | 'filled' | 'outlined'; + label: string; + viewType: string; } export interface DrawerChatbotProps { isExpanded: boolean; @@ -712,6 +745,9 @@ export type CommunitiesProps = { loading: boolean; communities: Community[]; mode: string; + + // nodeValues: ExtendedNode[]; + // relationshipValues: ExtendedRelationship[]; }; export interface entity { @@ -804,3 +840,19 @@ export type GraphPropertiesPanelProps = { inspectedItem: BasicNode | BasicRelationship; newScheme: Scheme; }; + +export type withId = { + id: string; +}; + +export interface GraphViewHandlerProps { + nodeValues?: ExtendedNode[]; + relationshipValues?: ExtendedRelationship[]; + fill?: 'text' | 'filled' | 'outlined'; + label?: string; + viewType?: string; + buttonLabel: string; + graphonly_entities?: []; + entityInfo?: Entity[]; + mode?: string; +} diff --git a/frontend/src/utils/Constants.ts b/frontend/src/utils/Constants.ts index 971227724..8ab428999 100644 --- a/frontend/src/utils/Constants.ts +++ b/frontend/src/utils/Constants.ts @@ -32,7 +32,20 @@ export const defaultLLM = llms?.includes('openai_gpt_4o') : llms?.includes('gemini_1.5_pro') ? 'gemini_1.5_pro' : 'diffbot'; -export const supportedLLmsForRagas = ['openai_gpt_3.5', 'openai_gpt_4o', 'openai_gpt_4o_mini', 'groq_llama3_70b']; +export const supportedLLmsForRagas = [ + 'openai_gpt_3.5', + 'openai_gpt_4', + 'openai_gpt_4o', + 'openai_gpt_4o_mini', + 'gemini_1.5_pro', + 'gemini_1.5_flash', + 'azure_ai_gpt_35', + 'azure_ai_gpt_4o', + 'groq_llama3_70b', + 'anthropic_claude_3_5_sonnet', + 'fireworks_llama_v3_70b', + 'bedrock_claude_3_5_sonnet', +]; export const chatModeLables = { vector: 'vector', graph: 'graph', @@ -112,6 +125,7 @@ export const tooltips = { clearChat: 'Clear Chat History', continue: 'Continue', clearGraphSettings: 'Clear configured Graph Schema', + applySettings: 'Apply Graph Schema', }; export const buttonCaptions = { @@ -135,6 +149,7 @@ export const buttonCaptions = { continueSettings: 'Continue', clearSettings: 'Clear Schema', ask: 'Ask', + applyGraphSchema: 'Apply', }; export const POST_PROCESSING_JOBS: { title: string; description: string }[] = [ diff --git a/frontend/src/utils/Utils.ts b/frontend/src/utils/Utils.ts index 27645ac7e..226fd03f9 100644 --- a/frontend/src/utils/Utils.ts +++ b/frontend/src/utils/Utils.ts @@ -332,7 +332,6 @@ export const filterData = ( filteredNodes = allNodes; filteredRelations = allRelationships; filteredScheme = scheme; - console.log('entity', filteredScheme); } return { filteredNodes, filteredRelations, filteredScheme }; }; @@ -509,3 +508,14 @@ export function downloadClickHandler( downloadLinkRef.current.click(); } } +export function getNodes(nodesData: Array, mode: string) { + return nodesData.map((n) => { + if (!n.labels.length && mode === chatModeLables.entity_vector) { + return { + ...n, + labels: ['Entity'], + }; + } + return n; + }); +}