Skip to content

Commit cd99dc5

Browse files
authored
fix: nebula multi-embedding & add BochaAI Search Retriever (#195)
* fix: fine-search bug * feat: multi embedding dimension in nebula * feat: modify multi-dimension * fix: multi-embedding error * feat: update nebula example * fix: print bug * feat: add bocha search * feat: modify bocha search * feat: finish bocha search
1 parent 2dcdf75 commit cd99dc5

File tree

7 files changed

+493
-171
lines changed

7 files changed

+493
-171
lines changed

examples/basic_modules/nebular_example.py

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ def show(nebular_data):
2323

2424
tree_config = Neo4jGraphDBConfig.from_json_file("../../examples/data/config/neo4j_config.json")
2525
tree_config.use_multi_db = True
26-
tree_config.db_name = "nebular-show"
26+
tree_config.db_name = "nebular-show2"
2727

2828
neo4j_db = Neo4jGraphDB(tree_config)
2929
neo4j_db.clear()
@@ -42,6 +42,7 @@ def show(nebular_data):
4242
}
4343
)
4444
embedder = EmbedderFactory.from_config(embedder_config)
45+
embedder_dimension = 3072
4546

4647

4748
def embed_memory_item(memory: str) -> list[float]:
@@ -62,7 +63,7 @@ def example_multi_db(db_name: str = "paper"):
6263
"space": db_name,
6364
"use_multi_db": True,
6465
"auto_create": True,
65-
"embedding_dimension": 3072,
66+
"embedding_dimension": embedder_dimension,
6667
},
6768
)
6869

@@ -121,7 +122,7 @@ def example_shared_db(db_name: str = "shared-traval-group"):
121122
"user_name": user_name,
122123
"use_multi_db": False,
123124
"auto_create": True,
124-
"embedding_dimension": 3072,
125+
"embedding_dimension": embedder_dimension,
125126
},
126127
)
127128

@@ -208,7 +209,7 @@ def example_shared_db(db_name: str = "shared-traval-group"):
208209
"space": db_name,
209210
"user_name": user_list[0],
210211
"auto_create": True,
211-
"embedding_dimension": 3072,
212+
"embedding_dimension": embedder_dimension,
212213
"use_multi_db": False,
213214
},
214215
)
@@ -238,7 +239,7 @@ def run_user_session(
238239
"user_name": user_name,
239240
"use_multi_db": False,
240241
"auto_create": True,
241-
"embedding_dimension": 3072,
242+
"embedding_dimension": embedder_dimension,
242243
},
243244
)
244245
graph = GraphStoreFactory.from_config(config)
@@ -404,10 +405,10 @@ def example_complex_shared_db(db_name: str = "shared-traval-group-complex"):
404405

405406
if __name__ == "__main__":
406407
print("\n=== Example: Multi-DB ===")
407-
example_multi_db(db_name="paper")
408+
example_multi_db(db_name="paper-new")
408409

409410
print("\n=== Example: Single-DB ===")
410-
example_shared_db(db_name="shared_traval_group")
411+
example_shared_db(db_name="shared_traval_group-new")
411412

412413
print("\n=== Example: Single-DB-Complex ===")
413-
example_complex_shared_db(db_name="shared-traval-group-complex-new11")
414+
example_complex_shared_db(db_name="shared-traval-group-complex-new2")
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
"""
Example: Using InternetRetrieverFactory with BochaAISearchRetriever

Set the ``BOCHA_API_KEY`` environment variable to your BochaAI API key before
running; the ``"sk-xxx"`` placeholder is used only when it is unset.
"""

import os

from memos.configs.embedder import EmbedderConfigFactory
from memos.configs.internet_retriever import InternetRetrieverConfigFactory
from memos.embedders.factory import EmbedderFactory
from memos.memories.textual.tree_text_memory.retrieve.internet_retriever_factory import (
    InternetRetrieverFactory,
)


# Number of characters of each retrieved memory shown in the preview below.
PREVIEW_CHARS = 500

# ========= 1. Create an embedder =========
embedder_config = EmbedderConfigFactory.model_validate(
    {
        "backend": "ollama",  # Or "sentence_transformer", etc.
        "config": {
            "model_name_or_path": "nomic-embed-text:latest",
        },
    }
)
embedder = EmbedderFactory.from_config(embedder_config)

# ========= 2. Create retriever config for BochaAI =========
retriever_config = InternetRetrieverConfigFactory.model_validate(
    {
        "backend": "bocha",
        "config": {
            # Read the key from the environment so a real key is never
            # committed with the example; falls back to the placeholder.
            "api_key": os.environ.get("BOCHA_API_KEY", "sk-xxx"),
            "max_results": 5,
            "reader": {  # Reader config for chunking web content
                "backend": "simple_struct",
                "config": {  # your simple struct reader config
                },
            },
        },
    }
)

# ========= 3. Build retriever instance via factory =========
retriever = InternetRetrieverFactory.from_config(retriever_config, embedder)

# ========= 4. Run BochaAI Web Search =========
print("=== Scenario 1: Web Search (BochaAI) ===")
query_web = "Alibaba 2024 ESG report"
results_web = retriever.retrieve_from_internet(query_web)

print(f"Retrieved {len(results_web)} memory items.")
for idx, item in enumerate(results_web, 1):
    # Show at most PREVIEW_CHARS characters; add an ellipsis only when the
    # memory text was actually cut short.
    preview = item.memory[:PREVIEW_CHARS]
    suffix = "..." if len(item.memory) > PREVIEW_CHARS else ""
    print(f"[{idx}] {preview}{suffix}")

print("==" * 20)

examples/core_memories/tree_textual_memory.py

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,15 @@
11
import time
22

33
from memos import log
4-
from memos.configs.embedder import EmbedderConfigFactory
54
from memos.configs.mem_reader import SimpleStructMemReaderConfig
65
from memos.configs.memory import TreeTextMemoryConfig
7-
from memos.embedders.factory import EmbedderFactory
86
from memos.mem_reader.simple_struct import SimpleStructMemReader
97
from memos.memories.textual.tree import TreeTextMemory
108

119

1210
logger = log.get_logger(__name__)
1311

1412

15-
embedder_config = EmbedderConfigFactory.model_validate(
16-
{
17-
"backend": "ollama",
18-
"config": {
19-
"model_name_or_path": "nomic-embed-text:latest",
20-
},
21-
}
22-
)
23-
embedder = EmbedderFactory.from_config(embedder_config)
24-
25-
26-
def embed_memory_item(memory: str) -> list[float]:
27-
return embedder.embed([memory])[0]
28-
29-
3013
tree_config = TreeTextMemoryConfig.from_json_file(
3114
"examples/data/config/tree_config_shared_database.json"
3215
)

src/memos/configs/internet_retriever.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,18 @@ class XinyuSearchConfig(BaseInternetRetrieverConfig):
5555
)
5656

5757

58+
class BochaSearchConfig(BaseInternetRetrieverConfig):
    """Configuration class for Bocha Search API.

    Adds result-count limits and a reader configuration on top of the fields
    inherited from ``BaseInternetRetrieverConfig``.
    """

    # Maximum number of results to retrieve (defaults to 20).
    max_results: int = Field(default=20, description="Maximum number of results to retrieve")
    # Number of results requested per API request (defaults to 10).
    num_per_request: int = Field(default=10, description="Number of results per API request")
    # No `...` (required) marker here: `default_factory` supplies the value
    # when `reader` is omitted, so marking the field required as well would be
    # contradictory (pydantic v2 silently ignores it, pydantic v1 rejects it).
    # NOTE(review): confirm MemReaderConfigFactory can be built with no arguments.
    reader: MemReaderConfigFactory = Field(
        default_factory=MemReaderConfigFactory,
        description="Reader configuration",
    )
68+
69+
5870
class InternetRetrieverConfigFactory(BaseConfig):
5971
"""Factory class for creating internet retriever configurations."""
6072

@@ -69,6 +81,7 @@ class InternetRetrieverConfigFactory(BaseConfig):
6981
"google": GoogleCustomSearchConfig,
7082
"bing": BingSearchConfig,
7183
"xinyu": XinyuSearchConfig,
84+
"bocha": BochaSearchConfig,
7285
}
7386

7487
@field_validator("backend")

0 commit comments

Comments
 (0)