Skip to content

Commit 2b33c73

Browse files
authored
Restore embedding model caching (#581)
* Restore embedding model caching
* Remove import
1 parent b6af8f6 commit 2b33c73

File tree

3 files changed

+5
-18
lines changed

3 files changed

+5
-18
lines changed

gramps_webapi/api/search/__init__.py

Lines changed: 2 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -26,7 +26,6 @@
2626
from flask import current_app
2727

2828
from .indexer import SearchIndexer, SemanticSearchIndexer, SearchIndexerBase
29-
from .embeddings import embedding_function_factory
3029

3130

3231
def get_search_indexer(tree: str, semantic: bool = False) -> SearchIndexerBase:
@@ -50,15 +49,10 @@ def get_search_indexer(tree: str, semantic: bool = False) -> SearchIndexerBase:
5049
if not path.exists() and not path.parent.exists():
5150
path.parent.mkdir(parents=True, exist_ok=True)
5251
if semantic:
53-
model = current_app.config.get("VECTOR_EMBEDDING_MODEL")
52+
model = current_app.config.get("_INITIALIZED_VECTOR_EMBEDDING_MODEL")
5453
if not model:
5554
raise ValueError("VECTOR_EMBEDDING_MODEL option not set")
56-
try:
57-
embedding_function = embedding_function_factory(model)
58-
except OSError:
59-
raise ValueError(f"Failed initializing model {model}")
60-
# cache on app instance
6155
return SemanticSearchIndexer(
62-
db_url=db_url, tree=tree, embedding_function=embedding_function
56+
db_url=db_url, tree=tree, embedding_function=model.encode
6357
)
6458
return SearchIndexer(db_url=db_url, tree=tree)

gramps_webapi/api/search/embeddings.py

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -3,15 +3,6 @@
33
from ..util import get_logger
44

55

6-
def embedding_function_factory(model_name: str):
7-
model = load_model(model_name)
8-
9-
def embedding_function(queries: list[str]):
10-
return model.encode(queries)
11-
12-
return embedding_function
13-
14-
156
def load_model(model_name: str):
167
"""Load the sentence transformer model.
178

gramps_webapi/app.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -191,7 +191,9 @@ def close_user_db_connection(exception) -> None:
191191
user_db.session.remove() # pylint: disable=no-member
192192

193193
if app.config.get("VECTOR_EMBEDDING_MODEL"):
194-
load_model(app.config["VECTOR_EMBEDDING_MODEL"])
194+
app.config["_INITIALIZED_VECTOR_EMBEDDING_MODEL"] = load_model(
195+
app.config["VECTOR_EMBEDDING_MODEL"]
196+
)
195197

196198
@app.route("/ready", methods=["GET"])
197199
def ready():

0 commit comments

Comments
 (0)