Skip to content

Commit ad4223f

Browse files
committed
Merge branch 'main' into sandbox-env
2 parents 1e67f56 + 3b01383 commit ad4223f

File tree

24 files changed

+2675
-7
lines changed

24 files changed

+2675
-7
lines changed

src/bohrium/_base_client.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,33 @@ def _request(
185185
logger.info(f"Requesting {method} {url}")
186186
merged_headers = self._build_headers(headers)
187187
merged_params = self._build_params(kwargs.get("params"))
188+
189+
# 处理文件上传
190+
request_kwargs = {
191+
"method": method.upper(),
192+
"url": url,
193+
"params": merged_params,
194+
}
195+
196+
# 处理超时参数
197+
if "timeout" in kwargs:
198+
request_kwargs["timeout"] = kwargs["timeout"]
199+
200+
if json is not None:
201+
request_kwargs["json"] = json
202+
request_kwargs["headers"] = merged_headers
203+
elif "files" in kwargs:
204+
# 当有files参数时,不使用json参数,而是使用files和data
205+
# 不设置headers,让httpx自动处理multipart/form-data
206+
request_kwargs["files"] = kwargs["files"]
207+
if "data" in kwargs:
208+
request_kwargs["data"] = kwargs["data"]
209+
elif "data" in kwargs:
210+
request_kwargs["data"] = kwargs["data"]
211+
request_kwargs["headers"] = merged_headers
212+
else:
213+
request_kwargs["headers"] = merged_headers
214+
188215
try:
189216
return self._client.request(
190217
method.upper(),

src/bohrium/_client.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@
1414

1515
class Bohrium(SyncAPIClient):
1616
job: resources.Job
17+
sigma_search: resources.SigmaSearch
18+
uni_parser: resources.UniParser
19+
knowledge_base: resources.KnowledgeBase
20+
paper: resources.Paper
1721

1822
# client options
1923
access_key: str
20-
project_id: Union[str, None]
24+
project_id: Optional[str]
2125

2226
def __init__(
2327
self,
@@ -42,11 +46,6 @@ def __init__(
4246
if project_id is None:
4347
project_id = os.environ.get("BOHRIUM_PROJECT_ID")
4448

45-
if project_id is None:
46-
raise BohriumError(
47-
"The project_id client option must be set either by passing project_id to the client or by setting the BOHRIUM_PROJECT_ID environment variable"
48-
)
49-
5049
self.project_id = project_id
5150

5251
if base_url is None:
@@ -65,6 +64,10 @@ def __init__(
6564
)
6665

6766
self.job = resources.Job(self)
67+
self.sigma_search = resources.SigmaSearch(self)
68+
self.uni_parser = resources.UniParser(self)
69+
self.knowledge_base = resources.KnowledgeBase(self)
70+
self.paper = resources.Paper(self)
6871

6972
@property
7073
@override

src/bohrium/resources/__init__.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,16 @@
11

22
from .job import Job, AsyncJob
3+
from .sigma_search import SigmaSearch, AsyncSigmaSearch
4+
from .uni_parser import UniParser, AsyncUniParser
5+
from .knowledge_base import KnowledgeBase, AsyncKnowledgeBase
6+
from .paper import Paper, AsyncPaper
37
from .tiefblue import Tiefblue
4-
__all__ = ["Job", "AsyncJob", "Tiefblue"]
8+
9+
__all__ = [
10+
"Job", "AsyncJob", "Tiefblue"
11+
"SigmaSearch", "AsyncSigmaSearch",
12+
"UniParser", "AsyncUniParser",
13+
"KnowledgeBase", "AsyncKnowledgeBase",
14+
"Paper", "AsyncPaper"
15+
]
16+
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .knowledge_base import KnowledgeBase, AsyncKnowledgeBase
2+
3+
__all__ = ["KnowledgeBase", "AsyncKnowledgeBase"]
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import logging
2+
from typing import Optional, List, Dict, Any, Union
3+
from pprint import pprint
4+
5+
from ..._resource import AsyncAPIResource, SyncAPIResource
6+
from ..._response import APIResponse
7+
from ...types.knowledge_base.knowledge_base import (
8+
HybridRecallRequest,
9+
PaperRecallRequest,
10+
PaperInfo,
11+
ChunkSearchRequest
12+
)
13+
14+
log = logging.getLogger(__name__)
15+
16+
17+
class KnowledgeBase(SyncAPIResource):
18+
"""知识库相关接口"""
19+
20+
def hybrid_recall(
21+
self,
22+
knowledge_base_id: int,
23+
text: str,
24+
k: int = 200,
25+
keywords: Optional[Dict[str, float]] = None,
26+
**kwargs
27+
):
28+
"""知识库混合召回"""
29+
log.info(f"hybrid recall from knowledge base: {knowledge_base_id}")
30+
31+
data = {
32+
"knowledge_base_id": knowledge_base_id,
33+
"text": text,
34+
"k": k
35+
}
36+
37+
if keywords:
38+
data["keywords"] = keywords
39+
if kwargs:
40+
data.update(kwargs)
41+
42+
response = self._client.post("/openapi/v1/knowledge/recall/hybrid", json=data)
43+
log.info(response.json())
44+
return APIResponse(response).json.get("data")
45+
46+
def paper_recall(
47+
self,
48+
text: str,
49+
k: int,
50+
papers: List[Dict[str, str]],
51+
**kwargs
52+
):
53+
"""单篇论文召回"""
54+
log.info(f"paper recall: {len(papers)} papers")
55+
56+
data = {
57+
"text": text,
58+
"k": k,
59+
"papers": papers
60+
}
61+
62+
if kwargs:
63+
data.update(kwargs)
64+
65+
response = self._client.post("/openapi/v1/knowledge/recall/papers", json=data)
66+
log.info(response.json())
67+
return APIResponse(response).json.get("data")
68+
69+
def get_file_tree(
70+
self,
71+
folder_id: str,
72+
**kwargs
73+
):
74+
"""获取单篇切片文件树"""
75+
log.info(f"get file tree for folder: {folder_id}")
76+
77+
params = {"folderId": folder_id}
78+
if kwargs:
79+
params.update(kwargs)
80+
81+
response = self._client.get("/openapi/v1/knowledge/folder/file_tree", params=params)
82+
log.info(response.json())
83+
return APIResponse(response).json.get("data")
84+
85+
def search_by_md5_paper_id(
86+
self,
87+
md5: str,
88+
paper_id: str = "",
89+
page_num: int = 0,
90+
page_size: int = 9999,
91+
**kwargs
92+
):
93+
"""根据md5和paper_id搜索chunk信息"""
94+
log.info(f"search chunk by md5: {md5}, paper_id: {paper_id}")
95+
96+
data = {
97+
"md5": md5,
98+
"paper_id": paper_id,
99+
"page_num": page_num,
100+
"page_size": page_size
101+
}
102+
103+
if kwargs:
104+
data.update(kwargs)
105+
106+
response = self._client.post("/openapi/v1/knowledge/box/search_by_md5_paper_id", json=data)
107+
log.info(response.json())
108+
return APIResponse(response).json.get("data")
109+
110+
111+
112+
class AsyncKnowledgeBase(AsyncAPIResource):
113+
"""异步知识库相关接口"""
114+
pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .paper import Paper, AsyncPaper
2+
3+
__all__ = ["Paper", "AsyncPaper"]
Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
import logging
2+
from typing import Optional, List, Dict, Any
3+
from pprint import pprint
4+
5+
from ..._resource import AsyncAPIResource, SyncAPIResource
6+
from ..._response import APIResponse
7+
from ...types.paper.paper import PaperRAGRequest
8+
9+
log = logging.getLogger(__name__)
10+
11+
12+
class Paper(SyncAPIResource):
13+
"""论文相关接口"""
14+
15+
def rag_pass_keyword(
16+
self,
17+
type: int,
18+
rerank: int,
19+
question: str,
20+
page_size: int,
21+
words: Optional[List[str]] = None,
22+
start_time: Optional[str] = None,
23+
end_time: Optional[str] = None,
24+
**kwargs
25+
):
26+
"""论文RAG关键词检索"""
27+
log.info(f"paper rag pass keyword: type={type}, rerank={rerank}")
28+
29+
data = {
30+
"type": type,
31+
"rerank": rerank,
32+
"question": question,
33+
"pageSize": page_size
34+
}
35+
36+
if words:
37+
data["words"] = words
38+
if start_time:
39+
data["startTime"] = start_time
40+
if end_time:
41+
data["endTime"] = end_time
42+
if kwargs:
43+
data.update(kwargs)
44+
45+
response = self._client.post("/openapi/v1/paper/rag/pass/keyword", json=data)
46+
log.info(response.json())
47+
return APIResponse(response).json.get("data")
48+
49+
50+
class AsyncPaper(AsyncAPIResource):
51+
"""异步论文相关接口"""
52+
pass
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .sigma_search import SigmaSearch, AsyncSigmaSearch
2+
3+
__all__ = ["SigmaSearch", "AsyncSigmaSearch"]

0 commit comments

Comments
 (0)