Commit 71b4c53

[Paddle-pipelines] Update pipelines examples & update markdown splitters (#6717)
* Update pipelines examples
* Update splitter
1 parent bc8df6e commit 71b4c53

File tree: 6 files changed, +74 -26 lines changed

pipelines/examples/chatbot/chat_markdown_example.py

Lines changed: 10 additions & 7 deletions

@@ -49,6 +49,7 @@
 parser.add_argument('--title_split', default=False, type=bool, help='the markdown file is split by titles')
 parser.add_argument("--api_key", default=None, type=str, help="The API Key.")
 parser.add_argument("--secret_key", default=None, type=str, help="The secret key.")
+parser.add_argument('--indexing', default=False, type=bool, help='Whether indexing is enabled.')
 args = parser.parse_args()
 # yapf: enable

@@ -97,13 +98,15 @@ def chat_markdown_tutorial():
     text_splitter = CharacterTextSplitter(
         separator="\n", chunk_size=args.chunk_size, chunk_overlap=0, filters=["\n"]
     )
-    indexing_pipeline = Pipeline()
-    indexing_pipeline.add_node(component=markdown_converter, name="MarkdownConverter", inputs=["File"])
-    indexing_pipeline.add_node(component=text_splitter, name="Splitter", inputs=["MarkdownConverter"])
-    indexing_pipeline.add_node(component=retriever, name="Retriever", inputs=["Splitter"])
-    indexing_pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Retriever"])
-    files = glob.glob(args.file_paths + "/**/*.md", recursive=True)
-    indexing_pipeline.run(file_paths=files)
+
+    if args.indexing:
+        indexing_pipeline = Pipeline()
+        indexing_pipeline.add_node(component=markdown_converter, name="MarkdownConverter", inputs=["File"])
+        indexing_pipeline.add_node(component=text_splitter, name="Splitter", inputs=["MarkdownConverter"])
+        indexing_pipeline.add_node(component=retriever, name="Retriever", inputs=["Splitter"])
+        indexing_pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Retriever"])
+        files = glob.glob(args.file_paths + "/**/*.md", recursive=True)
+        indexing_pipeline.run(file_paths=files)

     # Query Markdowns
     ernie_bot = ErnieBot(api_key=args.api_key, secret_key=args.secret_key)
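Both chat examples now gate index construction behind the new --indexing flag, so the document store is built once and later runs can skip straight to querying. One caveat worth keeping in mind when calling it: argparse's type=bool converts any non-empty string to True, so the flag is enabled by passing a value and disabled by omitting it. A minimal sketch of that behaviour (not part of the commit):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--indexing", default=False, type=bool, help="Whether indexing is enabled.")

assert parser.parse_args(["--indexing", "True"]).indexing is True   # build the index on this run
assert parser.parse_args([]).indexing is False                      # query-only run
# Note: "--indexing False" would also parse as True, because bool("False") is True.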

pipelines/examples/chatbot/chat_markdown_multi_recall_example.py

Lines changed: 12 additions & 10 deletions

@@ -64,6 +64,7 @@
 parser.add_argument("--es_chunk_size", default=500, type=int, help="Number of docs in one chunk sent to es")
 parser.add_argument("--es_thread_count", default=32, type=int, help="Size of the threadpool to use for the bulk requests")
 parser.add_argument("--es_queue_size", default=32, type=int, help="Size of the task queue between the main thread (producing chunks to send) and the processing threads.")
+parser.add_argument('--indexing', default=False, type=bool, help='Whether indexing is enabled.')
 args = parser.parse_args()
 # yapf: enable

@@ -120,15 +121,16 @@ def chat_markdown_tutorial():
     text_splitter = CharacterTextSplitter(
         separator="\n", chunk_size=args.data_chunk_size, chunk_overlap=0, filters=["\n"]
     )
-    indexing_pipeline = Pipeline()
-    indexing_pipeline.add_node(component=markdown_converter, name="MarkdownConverter", inputs=["File"])
-    indexing_pipeline.add_node(component=text_splitter, name="Splitter", inputs=["MarkdownConverter"])
-    indexing_pipeline.add_node(component=retriever, name="Retriever", inputs=["Splitter"])
-    indexing_pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Retriever"])
-    files = glob.glob(args.file_paths + "/**/*.md", recursive=True)
-    if len(files) == 0:
-        raise Exception("file should not be empty")
-    indexing_pipeline.run(file_paths=files)
+    if args.indexing:
+        indexing_pipeline = Pipeline()
+        indexing_pipeline.add_node(component=markdown_converter, name="MarkdownConverter", inputs=["File"])
+        indexing_pipeline.add_node(component=text_splitter, name="Splitter", inputs=["MarkdownConverter"])
+        indexing_pipeline.add_node(component=retriever, name="Retriever", inputs=["Splitter"])
+        indexing_pipeline.add_node(component=document_store, name="DocumentStore", inputs=["Retriever"])
+        files = glob.glob(args.file_paths + "/**/*.md", recursive=True)
+        if len(files) == 0:
+            raise Exception("file should not be empty")
+        indexing_pipeline.run(file_paths=files)

     # Query Markdowns
     if args.chatbot in ["ernie_bot"]:

@@ -150,7 +152,7 @@ def chat_markdown_tutorial():
         component=TruncatedConversationHistory(max_length=256), name="TruncateHistory", inputs=["Template"]
     )
     query_pipeline.add_node(component=ernie_bot, name="ErnieBot", inputs=["TruncateHistory"])
-    query = "Aistudio最火的项目是哪个?"
+    query = "理财产品的认购期是多久?"
     start_time = time.time()
     prediction = query_pipeline.run(query=query, params={"DenseRetriever": {"top_k": 10}, "Ranker": {"top_k": 5}})
     end_time = time.time()

(The demo query changes from "Which is the most popular project on AI Studio?" to "How long is the subscription period for wealth-management products?".)
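The multi-recall example gets the same --indexing gate, plus an explicit guard against an empty corpus. For reference, this is the standard-library behaviour the new guard relies on (a self-contained sketch; "docs" is a placeholder path):

import glob

# With recursive=True, "**" matches the base directory and any depth of
# subdirectories, so every .md file under the given path is collected.
files = glob.glob("docs" + "/**/*.md", recursive=True)
if len(files) == 0:
    raise Exception("file should not be empty")  # same guard the commit adds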

pipelines/pipelines/nodes/preprocessor/text_splitter.py

Lines changed: 2 additions & 3 deletions

@@ -391,7 +391,6 @@ def split_text(
         # header_stack: List[Dict[str, Union[int, str]]] = []
         header_stack: List[HeaderType] = []
         initial_metadata: Dict[str, str] = {}
-
         for line in lines:
             stripped_line = line.strip()
             # Check each line against each of the header types (e.g., #, ##)

@@ -495,9 +494,9 @@ def _merge_splits(
         # We now want to combine these smaller pieces into medium size
         # chunks to send to the LLM.
         if chunk_size is None:
-            chunk_size = self.chunk_size
+            chunk_size = self._chunk_size
         if chunk_overlap is None:
-            chunk_overlap = self.chunk_overlap
+            chunk_overlap = self._chunk_overlap
         if separator is None:
             separator = self._separator
         separator_len = self._length_function(separator)
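The _merge_splits change swaps the public-looking self.chunk_size / self.chunk_overlap lookups for the private attributes the splitter actually stores, matching the neighbouring self._separator and self._length_function. A hedged reduction of the bug, assuming the class only sets underscore-prefixed fields in __init__:

class SplitterSketch:
    def __init__(self, chunk_size=500, chunk_overlap=0):
        self._chunk_size = chunk_size        # only the private names exist
        self._chunk_overlap = chunk_overlap

    def merge(self, chunk_size=None, chunk_overlap=None):
        if chunk_size is None:
            chunk_size = self._chunk_size    # before the fix: self.chunk_size -> AttributeError
        if chunk_overlap is None:
            chunk_overlap = self._chunk_overlap
        return chunk_size, chunk_overlap

print(SplitterSketch(chunk_size=300).merge())  # (300, 0)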

pipelines/pipelines/utils/preprocessing.py

Lines changed: 1 addition & 1 deletion

@@ -177,7 +177,7 @@ def convert_files_to_dicts_splitter(
             separator=separator,
             chunk_size=chunk_size,
             headers_to_split_on=headers_to_split_on,
-            return_each_line=False,
+            return_each_line=True,
             filters=filters,
         )
     if language == "chinese":
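Flipping return_each_line to True changes the granularity of the markdown header splitter's output. As commonly implemented for this kind of splitter (an assumption here; the exact semantics live in pipelines/pipelines/nodes/preprocessor/text_splitter.py), False merges all lines under a header into one chunk, while True emits each line as its own chunk that still carries the header metadata, giving the retriever finer-grained passages. Illustrative input/output only, with a simplified dict layout:

markdown = "# Products\nTerm deposit A\nTerm deposit B"

# return_each_line=False: one chunk per header section.
merged = [
    {"content": "Term deposit A\nTerm deposit B", "meta": {"Header 1": "Products"}},
]

# return_each_line=True: one chunk per line, each keeping its header metadata.
per_line = [
    {"content": "Term deposit A", "meta": {"Header 1": "Products"}},
    {"content": "Term deposit B", "meta": {"Header 1": "Products"}},
]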

pipelines/requirements.txt

Lines changed: 2 additions & 1 deletion

@@ -29,4 +29,5 @@ boilerpy3
 events
 sseclient-py==1.7.2
 typing_extensions==4.5
-spacy
+spacy
+tritonclient[all]
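The new tritonclient[all] requirement pulls in NVIDIA's Triton Inference Server client; the [all] extra installs both the HTTP and gRPC transports. A quick import smoke test, assuming pip install -r pipelines/requirements.txt has been run:

import tritonclient.grpc as grpcclient
import tritonclient.http as httpclient

# Both transports expose an InferenceServerClient once the extra is installed.
print(httpclient.InferenceServerClient, grpcclient.InferenceServerClient)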

pipelines/utils/offline_ann.py

Lines changed: 47 additions & 4 deletions

@@ -14,7 +14,11 @@
 
 import argparse
 
-from pipelines.document_stores import ElasticsearchDocumentStore, MilvusDocumentStore
+from pipelines.document_stores import (
+    BaiduElasticsearchDocumentStore,
+    ElasticsearchDocumentStore,
+    MilvusDocumentStore,
+)
 from pipelines.nodes import DensePassageRetriever
 from pipelines.utils import convert_files_to_dicts, fetch_archive_from_http, launch_es
 from pipelines.utils.preprocessing import convert_files_to_dicts_splitter

@@ -30,7 +34,9 @@
 parser = argparse.ArgumentParser()
 parser.add_argument("--index_name", default="baike_cities", type=str, help="The index name of the ANN search engine")
 parser.add_argument("--doc_dir", default="data/baike/", type=str, help="The doc path of the corpus")
-parser.add_argument("--search_engine", choices=["elastic", "milvus"], default="elastic", help="The type of ANN search engine.")
+parser.add_argument('--username', type=str, default="", help='Username of ANN search engine')
+parser.add_argument('--password', type=str, default="", help='Password of ANN search engine')
+parser.add_argument("--search_engine", choices=["elastic", "milvus", 'bes'], default="elastic", help="The type of ANN search engine.")
 parser.add_argument("--host", type=str, default="127.0.0.1", help="host ip of ANN search engine")
 parser.add_argument("--port", type=str, default="9200", help="port of ANN search engine")
 parser.add_argument("--embedding_dim", default=768, type=int, help="The embedding_dim of index")

@@ -51,6 +57,9 @@
 parser.add_argument('--filters', type=list, default=['\n'], help="Filter special symbols")
 parser.add_argument('--language', type=str, default='chinese', help="the language of files")
 parser.add_argument('--pooling_mode', choices=['max_tokens', 'mean_tokens', 'mean_sqrt_len_tokens', 'cls_token'], default='cls_token', help='the type of sentence embedding')
+parser.add_argument("--es_chunk_size", default=500, type=int, help="Number of docs in one chunk sent to es")
+parser.add_argument("--es_thread_count", default=32, type=int, help="Size of the threadpool to use for the bulk requests")
+parser.add_argument("--es_queue_size", default=32, type=int, help="Size of the task queue between the main thread (producing chunks to send) and the processing threads.")
 args = parser.parse_args()
 # yapf: enable

@@ -66,13 +75,30 @@ def offline_ann(index_name, doc_dir):
             index_param={"M": 16, "efConstruction": 50},
             index_type="HNSW",
         )
+    elif args.search_engine == "bes":
+
+        document_store = BaiduElasticsearchDocumentStore(
+            host=args.host,
+            port=args.port,
+            username=args.username,
+            password=args.password,
+            embedding_dim=args.embedding_dim,
+            similarity="dot_prod",
+            vector_type="bpack_vector",
+            search_fields=["content", "meta"],
+            index=args.index_name,
+            chunk_size=args.es_chunk_size,
+            thread_count=args.es_thread_count,
+            queue_size=args.es_queue_size,
+        )
+
     else:
         launch_es()
         document_store = ElasticsearchDocumentStore(
             host=args.host,
             port=args.port,
-            username="",
-            password="",
+            username=args.username,
+            password=args.password,
             embedding_dim=args.embedding_dim,
             index=index_name,
             search_fields=args.search_fields,  # 当使用了多路召回并且搜索字段设置了除content的其他字段,构建索引时其他字段也需要设置,例如:['content', 'name']。

(The inline Chinese comment says: when multi-way recall is used and search fields other than "content" are configured, those extra fields also have to be set when the index is built, e.g. ['content', 'name'].)

@@ -128,6 +154,23 @@ def delete_data(index_name):
             index_param={"M": 16, "efConstruction": 50},
             index_type="HNSW",
         )
+    elif args.search_engine == "bes":
+
+        document_store = BaiduElasticsearchDocumentStore(
+            host=args.host,
+            port=args.port,
+            username=args.username,
+            password=args.password,
+            embedding_dim=args.embedding_dim,
+            similarity="dot_prod",
+            vector_type="bpack_vector",
+            search_fields=["content", "meta"],
+            index=args.index_name,
+            chunk_size=args.es_chunk_size,
+            thread_count=args.es_thread_count,
+            queue_size=args.es_queue_size,
+        )
+
     else:
         document_store = ElasticsearchDocumentStore(
             host=args.host,
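With the new bes backend and credential flags, building or deleting an index on Baidu ElasticSearch is driven entirely from the command line. A hypothetical invocation (host, credentials, index name, and corpus directory are placeholders; the es_* values shown are just the defaults):

import subprocess

subprocess.run(
    [
        "python", "pipelines/utils/offline_ann.py",
        "--search_engine", "bes",
        "--host", "your-bes-host", "--port", "9200",
        "--username", "your-username", "--password", "your-password",
        "--index_name", "markdown_docs",
        "--doc_dir", "data/markdowns/",
        "--es_chunk_size", "500", "--es_thread_count", "32", "--es_queue_size", "32",
    ],
    check=True,
)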
