Skip to content

Commit 62d0db3

Browse files
authored
add no grad to retrieval SentenceTransformer calls (#10494)
1 parent 5e8f244 commit 62d0db3

File tree

2 files changed

+4
-2
lines changed

2 files changed

+4
-2
lines changed

torch_geometric/llm/utils/backend_utils.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -408,7 +408,8 @@ def apply_retrieval_via_pcst(
408408
:return: Retrieved graph/query data
409409
"""
410410
# PCST relies on numpy and pcst_fast pypi libs, hence to("cpu")
411-
q_emb = model.encode([query]).to("cpu")
411+
with torch.no_grad():
412+
q_emb = model.encode([query]).to("cpu")
412413
textual_nodes = [(int(i), full_textual_nodes[i])
413414
for i in graph["node_idx"]]
414415
textual_nodes = DataFrame(textual_nodes,

torch_geometric/llm/utils/vectorrag.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,8 @@ def query(self, query: Union[str, Tensor]) -> List[str]:
6565
List[str]: Documents retrieved from the vector database.
6666
"""
6767
if isinstance(query, str):
68-
query_enc = self.encoder(query, **self.model_kwargs)
68+
with torch.no_grad():
69+
query_enc = self.encoder(query, **self.model_kwargs)
6970
else:
7071
query_enc = query
7172

0 commit comments

Comments
 (0)