Commit 6ad3ea1

Fix FAISS cache path for non-ASCII dataset names on Windows
1 parent a3c9bae commit 6ad3ea1

1 file changed: +37 -36 lines

models/retriever/faiss_filter.py

Lines changed: 37 additions & 36 deletions
@@ -1,6 +1,7 @@
 import json
 import os
 import time
+import hashlib
 from collections import defaultdict
 from itertools import combinations
 from typing import Dict, List, Set, Tuple
@@ -26,10 +27,10 @@ def __init__(self, dataset, graph: nx.MultiDiGraph, model_name: str = "all-MiniL
         self.cache_dir = cache_dir
         os.makedirs(cache_dir, exist_ok=True)
         self.dataset = dataset
-
-        # Create dataset-specific cache directory
-        dataset_cache_dir = f"{self.cache_dir}/{self.dataset}"
-        os.makedirs(dataset_cache_dir, exist_ok=True)
+        # Use an ASCII-safe directory name to avoid Windows + FAISS failures on Chinese paths
+        safe_suffix = hashlib.md5(str(dataset).encode("utf-8")).hexdigest()[:8]
+        self.dataset_cache_dir = os.path.join(self.cache_dir, safe_suffix)
+        os.makedirs(self.dataset_cache_dir, exist_ok=True)

         self.triple_index = None
         self.comm_index = None
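
The core of the fix is the hunk above: the cache directory is no longer named after the raw dataset string (which, per the commit message, FAISS on Windows can fail to open when it contains non-ASCII characters) but after a short MD5 hash of it, so the on-disk path is always ASCII. A minimal standalone sketch of the same idea; the cache_dir and dataset values here are placeholders, not part of the commit:

    import hashlib
    import os

    def ascii_safe_cache_dir(cache_dir: str, dataset: str) -> str:
        # Derive an ASCII-only directory name from an arbitrary dataset name.
        # The 8-hex-character MD5 prefix matches the commit; collisions are
        # unlikely for a handful of datasets but not impossible in general.
        safe_suffix = hashlib.md5(str(dataset).encode("utf-8")).hexdigest()[:8]
        path = os.path.join(cache_dir, safe_suffix)
        os.makedirs(path, exist_ok=True)
        return path

    # A Chinese dataset name maps to a purely ASCII directory, e.g. ./cache/<8 hex chars>.
    print(ascii_safe_cache_dir("./cache", "中文数据集"))
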
@@ -551,7 +552,7 @@ def clear_embedding_cache(self, max_cache_size: int = 10000):

     def save_embedding_cache(self):
         """Save embedding cache to disk using numpy format to avoid pickle issues"""
-        cache_path = f"{self.cache_dir}/{self.dataset}/node_embedding_cache.pt"
+        cache_path = f"{self.dataset_cache_dir}/node_embedding_cache.pt"
         try:
             if not self.node_embedding_cache:
                 return False
@@ -601,7 +602,7 @@ def save_embedding_cache(self):

     def load_embedding_cache(self):
         """Load the embedding cache from disk"""
-        cache_path = f"{self.cache_dir}/{self.dataset}/node_embedding_cache.pt"
+        cache_path = f"{self.dataset_cache_dir}/node_embedding_cache.pt"
         if os.path.exists(cache_path):
             try:
                 file_size = os.path.getsize(cache_path)
@@ -787,14 +788,14 @@ def _precompute_node_embeddings(self, batch_size: int = 100, force_recompute: bo
     def build_indices(self):
         """Build FAISS Index only if they don't already exist and are consistent with current graph"""
         # Check if all indices and embedding files already exist
-        node_path = f"{self.cache_dir}/{self.dataset}/node.index"
-        relation_path = f"{self.cache_dir}/{self.dataset}/relation.index"
-        triple_path = f"{self.cache_dir}/{self.dataset}/triple.index"
-        comm_path = f"{self.cache_dir}/{self.dataset}/comm.index"
-        node_embed_path = f"{self.cache_dir}/{self.dataset}/node_embeddings.pt"
-        relation_embed_path = f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt"
-        node_map_path = f"{self.cache_dir}/{self.dataset}/node_map.json"
-        dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
+        node_path = f"{self.dataset_cache_dir}/node.index"
+        relation_path = f"{self.dataset_cache_dir}/relation.index"
+        triple_path = f"{self.dataset_cache_dir}/triple.index"
+        comm_path = f"{self.dataset_cache_dir}/comm.index"
+        node_embed_path = f"{self.dataset_cache_dir}/node_embeddings.pt"
+        relation_embed_path = f"{self.dataset_cache_dir}/relation_embeddings.pt"
+        node_map_path = f"{self.dataset_cache_dir}/node_map.json"
+        dim_transform_path = f"{self.dataset_cache_dir}/dim_transform.pt"

         all_exist = (os.path.exists(node_path) and
                      os.path.exists(relation_path) and
@@ -883,7 +884,7 @@ def build_indices(self):

     def _save_dim_transform(self):
         """Save dimension transform state to disk"""
-        dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
+        dim_transform_path = f"{self.dataset_cache_dir}/dim_transform.pt"
         try:
             save_data = {
                 'model_dim': self.model_dim,
@@ -901,7 +902,7 @@ def _save_dim_transform(self):

     def _load_dim_transform(self):
         """Load dimension transform state from disk"""
-        dim_transform_path = f"{self.cache_dir}/{self.dataset}/dim_transform.pt"
+        dim_transform_path = f"{self.dataset_cache_dir}/dim_transform.pt"
         if not os.path.exists(dim_transform_path):
             return False

@@ -956,7 +957,7 @@ def _build_node_index(self):
         # Store embeddings on CPU to save GPU memory
         self.node_embeddings = embeddings.cpu()
         # Save as .pt for consistency across the codebase
-        torch.save(self.node_embeddings, f"{self.cache_dir}/{self.dataset}/node_embeddings.pt")
+        torch.save(self.node_embeddings, f"{self.dataset_cache_dir}/node_embeddings.pt")

         # Build FAISS index
         embeddings_np = embeddings.cpu().numpy()
@@ -965,9 +966,9 @@ def _build_node_index(self):
         faiss.normalize_L2(embeddings_np)
         index.add(embeddings_np)

-        faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/node.index")
+        faiss.write_index(index, f"{self.dataset_cache_dir}/node.index")
         self.node_map = {str(i): n for i, n in enumerate(nodes)}
-        with open(f"{self.cache_dir}/{self.dataset}/node_map.json", 'w') as f:
+        with open(f"{self.dataset_cache_dir}/node_map.json", 'w') as f:
             json.dump(self.node_map, f)

         self.node_index = index
@@ -983,7 +984,7 @@ def _build_relation_index(self):
         # Store embeddings on CPU
         self.relation_embeddings = embeddings.cpu()
         # Save as .pt for consistency across the codebase
-        torch.save(self.relation_embeddings, f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt")
+        torch.save(self.relation_embeddings, f"{self.dataset_cache_dir}/relation_embeddings.pt")

         # Build FAISS index
         embeddings_np = embeddings.cpu().numpy()
@@ -992,9 +993,9 @@ def _build_relation_index(self):
         faiss.normalize_L2(embeddings_np)
         index.add(embeddings_np)

-        faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/relation.index")
+        faiss.write_index(index, f"{self.dataset_cache_dir}/relation.index")
         self.relation_map = {str(i): r for i, r in enumerate(relations)}
-        with open(f"{self.cache_dir}/{self.dataset}/relation_map.json", 'w') as f:
+        with open(f"{self.dataset_cache_dir}/relation_map.json", 'w') as f:
             json.dump(self.relation_map, f)

         self.relation_index = index
@@ -1014,8 +1015,8 @@ def _build_triple_index(self):
         faiss.normalize_L2(embeddings)
         index.add(embeddings)

-        faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/triple.index")
-        with open(f"{self.cache_dir}/{self.dataset}/triple_map.json", 'w') as f:
+        faiss.write_index(index, f"{self.dataset_cache_dir}/triple.index")
+        with open(f"{self.dataset_cache_dir}/triple_map.json", 'w') as f:
             json.dump({i: n for i, n in enumerate(triples)}, f)

         self.triple_index = index
@@ -1050,21 +1051,21 @@ def _build_community_index(self):
         faiss.normalize_L2(embeddings)
         index.add(embeddings)

-        faiss.write_index(index, f"{self.cache_dir}/{self.dataset}/comm.index")
-        with open(f"{self.cache_dir}/{self.dataset}/comm_map.json", 'w') as f:
+        faiss.write_index(index, f"{self.dataset_cache_dir}/comm.index")
+        with open(f"{self.dataset_cache_dir}/comm_map.json", 'w') as f:
             json.dump({i: n for i, n in enumerate(valid_communities)}, f)

         self.comm_index = index
         self.comm_map = {str(i): n for i, n in enumerate(valid_communities)}

     def _load_indices(self):
         logger.info("Starting _load_indices...")
-        triple_path = f"{self.cache_dir}/{self.dataset}/triple.index"
-        comm_path = f"{self.cache_dir}/{self.dataset}/comm.index"
-        node_path = f"{self.cache_dir}/{self.dataset}/node.index"
-        relation_path = f"{self.cache_dir}/{self.dataset}/relation.index"
-        node_embed_path = f"{self.cache_dir}/{self.dataset}/node_embeddings.pt"
-        relation_embed_path = f"{self.cache_dir}/{self.dataset}/relation_embeddings.pt"
+        triple_path = f"{self.dataset_cache_dir}/triple.index"
+        comm_path = f"{self.dataset_cache_dir}/comm.index"
+        node_path = f"{self.dataset_cache_dir}/node.index"
+        relation_path = f"{self.dataset_cache_dir}/relation.index"
+        node_embed_path = f"{self.dataset_cache_dir}/node_embeddings.pt"
+        relation_embed_path = f"{self.dataset_cache_dir}/relation_embeddings.pt"

         logger.debug(f"Checking cache files...")
         logger.debug(f"node_path exists: {os.path.exists(node_path)}")
@@ -1077,22 +1078,22 @@ def _load_indices(self):
         if os.path.exists(node_path):
             logger.debug("Loading node index...")
             self.node_index = faiss.read_index(node_path)
-            with open(f"{self.cache_dir}/{self.dataset}/node_map.json", 'r') as f:
+            with open(f"{self.dataset_cache_dir}/node_map.json", 'r') as f:
                 self.node_map = json.load(f)

         if os.path.exists(relation_path):
             self.relation_index = faiss.read_index(relation_path)
-            with open(f"{self.cache_dir}/{self.dataset}/relation_map.json", 'r') as f:
+            with open(f"{self.dataset_cache_dir}/relation_map.json", 'r') as f:
                 self.relation_map = json.load(f)

         if os.path.exists(triple_path):
             self.triple_index = faiss.read_index(triple_path)
-            with open(f"{self.cache_dir}/{self.dataset}/triple_map.json", 'r') as f:
+            with open(f"{self.dataset_cache_dir}/triple_map.json", 'r') as f:
                 self.triple_map = json.load(f)

         if os.path.exists(comm_path):
             self.comm_index = faiss.read_index(comm_path)
-            with open(f"{self.cache_dir}/{self.dataset}/comm_map.json", 'r') as f:
+            with open(f"{self.dataset_cache_dir}/comm_map.json", 'r') as f:
                 self.comm_map = json.load(f)

         if os.path.exists(node_embed_path):
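
For context on the failure this commit works around: faiss.write_index and faiss.read_index hand the filename to FAISS's C++ I/O layer, which on Windows typically cannot open paths containing characters outside the system code page, even though Python's own file APIs handle them. A rough repro sketch under those assumptions; the exact behaviour depends on the FAISS build and system locale, and on Linux/macOS the write usually succeeds:

    import os

    import faiss
    import numpy as np

    d = 64
    index = faiss.IndexFlatIP(d)
    index.add(np.random.rand(100, d).astype("float32"))

    # Before the fix the cache directory was named after the dataset itself,
    # e.g. "./cache/中文数据集"; on Windows FAISS may fail to open files there.
    unsafe_dir = os.path.join("cache", "中文数据集")
    os.makedirs(unsafe_dir, exist_ok=True)
    try:
        faiss.write_index(index, os.path.join(unsafe_dir, "node.index"))
        print("write_index succeeded")
    except RuntimeError as err:  # FAISS surfaces C++ I/O errors as RuntimeError
        print("write_index failed:", err)
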
