diff --git a/model_loading.py b/model_loading.py
new file mode 100644
index 0000000..e9f009f
--- /dev/null
+++ b/model_loading.py
@@ -0,0 +1,4 @@
+import tensorflow_hub as hub
+
+
+EMBEDDING_MODEL = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
diff --git a/utils.py b/utils.py
index de0be3a..8513fb9 100644
--- a/utils.py
+++ b/utils.py
@@ -7,10 +7,7 @@
 from typing import Dict, Any, List
 from sklearn.metrics.pairwise import cosine_similarity
 from sklearn.cluster import KMeans
-import tensorflow_hub as hub
-
-
-embedding_model = hub.load("https://tfhub.dev/google/universal-sentence-encoder-large/5")
+from model_loading import EMBEDDING_MODEL
 
 
 def open_file(filepath):
@@ -44,7 +41,7 @@ def save_message(root_folder, message: Dict[str, Any]):
 def search_tree(root_folder, query):
     # TODO add a "forks" parameter to allow for branching relevance
     # TODO add a "fuzziness" parameter that can generate a random vector to modify the search query
-    query_embedding = embedding_model([query]).numpy()
+    query_embedding = EMBEDDING_MODEL([query]).numpy()
 
     level = 6
     taxonomy = []
@@ -147,7 +144,7 @@ def process_missing_messages(root_folder: str):
         timestamp = file2_data['timestamp']
 
         combined_text = context + " --- " + response
-        embedding = embedding_model([combined_text]).numpy().tolist()
+        embedding = EMBEDDING_MODEL([combined_text]).numpy().tolist()
 
         message_pair_data = {
             'content': combined_text,
@@ -179,7 +176,7 @@ def create_summaries(root_folder: str, clusters: List[List[str]], target_folder:
         summary = quick_summarize(combined_content)
 
         # Create embedding for summary
-        summary_embedding = embedding_model([summary]).numpy().tolist()
+        summary_embedding = EMBEDDING_MODEL([summary]).numpy().tolist()
 
         # Save summary in target folder
         summary_data = {
@@ -275,7 +272,7 @@ def integrate_new_elements(root_folder: str, target_folder: str, new_elements: L
 
             combined_content = closest_file_data["content"] + " --- " + new_element_data["content"]
             updated_summary = quick_summarize(combined_content)
-            updated_summary_embedding = embedding_model([updated_summary]).numpy().tolist()
+            updated_summary_embedding = EMBEDDING_MODEL([updated_summary]).numpy().tolist()
 
             closest_file_data["content"] = updated_summary
             closest_file_data["vector"] = updated_summary_embedding
@@ -286,7 +283,7 @@ def integrate_new_elements(root_folder: str, target_folder: str, new_elements: L
             # Create a new summary for the new_element
             combined_content = new_element_data["content"]
             new_summary = quick_summarize(combined_content)
-            new_summary_embedding = embedding_model([new_summary]).numpy().tolist()
+            new_summary_embedding = EMBEDDING_MODEL([new_summary]).numpy().tolist()
 
             new_summary_data = {
                 "content": new_summary,