
Commit 930a986

Add Ernie Search base model into pipelines (#3906)
* Add Ernie Search base version into pipelines
* Add more comments
1 parent c08d47c commit 930a986

File tree: 9 files changed, +189 −12 lines changed

paddlenlp/transformers/ernie/modeling.py

Lines changed: 31 additions & 0 deletions
@@ -715,6 +715,33 @@ class ErniePretrainedModel(PretrainedModel):
             "vocab_size": 30522,
             "pad_token_id": 0,
         },
+        "ernie-search-base-dual-encoder-marco-en": {
+            "attention_probs_dropout_prob": 0.1,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 768,
+            "initializer_range": 0.02,
+            "max_position_embeddings": 512,
+            "num_attention_heads": 12,
+            "num_hidden_layers": 12,
+            "type_vocab_size": 4,
+            "vocab_size": 30522,
+            "pad_token_id": 0,
+        },
+        "ernie-search-large-cross-encoder-marco-en": {
+            "attention_probs_dropout_prob": 0.1,
+            "intermediate_size": 4096,
+            "hidden_act": "gelu",
+            "hidden_dropout_prob": 0.1,
+            "hidden_size": 1024,
+            "initializer_range": 0.02,
+            "max_position_embeddings": 512,
+            "num_attention_heads": 16,
+            "num_hidden_layers": 24,
+            "type_vocab_size": 4,
+            "vocab_size": 30522,
+            "pad_token_id": 0,
+        },
     }
     resource_files_names = {"model_state": "model_state.pdparams"}
     pretrained_resource_files_map = {
@@ -800,6 +827,10 @@ class ErniePretrainedModel(PretrainedModel):
             "https://paddlenlp.bj.bcebos.com/models/transformers/rocketqa/rocketqav2_en_marco_query_encoder.pdparams",
             "rocketqav2-en-marco-para-encoder":
             "https://paddlenlp.bj.bcebos.com/models/transformers/rocketqa/rocketqav2_en_marco_para_encoder.pdparams",
+            "ernie-search-base-dual-encoder-marco-en":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_search/ernie_search_base_dual_encoder_marco_en.pdparams",
+            "ernie-search-large-cross-encoder-marco-en":
+            "https://paddlenlp.bj.bcebos.com/models/transformers/ernie_search/ernie_search_large_cross_encoder_marco_en.pdparams",
         }
     }
     base_model_prefix = "ernie"
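With the configurations and weight URLs registered above, the new checkpoints can be pulled by name like any other ERNIE model. A minimal sketch (the snippet itself is not part of this commit):

```python
from paddlenlp.transformers import ErnieModel

# First use downloads the registered .pdparams weights by name.
model = ErnieModel.from_pretrained("ernie-search-base-dual-encoder-marco-en")
```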

paddlenlp/transformers/ernie/tokenizer.py

Lines changed: 12 additions & 0 deletions
@@ -70,6 +70,8 @@
     "rocketqav2-en-marco-cross-encoder": 512,
     "rocketqav2-en-marco-query-encoder": 512,
     "rocketqav2-en-marco-para-encoder": 512,
+    "ernie-search-base-dual-encoder-marco-en": 512,
+    "ernie-search-large-cross-encoder-marco-en": 512,
 }

@@ -211,6 +213,10 @@ class ErnieTokenizer(PretrainedTokenizer):
             "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt",
             "rocketqav2-en-marco-para-encoder":
             "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt",
+            "ernie-search-base-dual-encoder-marco-en":
+            "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_base/vocab.txt",
+            "ernie-search-large-cross-encoder-marco-en":
+            "https://bj.bcebos.com/paddlenlp/models/transformers/ernie_v2_large/vocab.txt",
         }
     }
     pretrained_init_configuration = {
@@ -343,6 +349,12 @@ class ErnieTokenizer(PretrainedTokenizer):
         "rocketqav2-en-marco-para-encoder": {
             "do_lower_case": True
         },
+        "ernie-search-base-dual-encoder-marco-en": {
+            "do_lower_case": True
+        },
+        "ernie-search-large-cross-encoder-marco-en": {
+            "do_lower_case": True
+        },
     }

     max_model_input_sizes = PRETRAINED_POSITIONAL_EMBEDDINGS_SIZES
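Both new names resolve to the existing ERNIE 2.0 English vocabularies, so no new tokenizer assets are needed. A hedged sketch:

```python
from paddlenlp.transformers import ErnieTokenizer

tokenizer = ErnieTokenizer.from_pretrained(
    "ernie-search-base-dual-encoder-marco-en")
encoded = tokenizer("what is dense passage retrieval?")
print(encoded["input_ids"])  # lower-cased word-piece ids; positions cap at 512 per the table above
```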

paddlenlp/transformers/semantic_search/modeling.py

Lines changed: 47 additions & 4 deletions
@@ -23,12 +23,25 @@

 class ErnieEncoder(ErniePretrainedModel):

-    def __init__(self, ernie, dropout=None, num_classes=2):
+    def __init__(self,
+                 ernie,
+                 dropout=None,
+                 output_emb_size=None,
+                 num_classes=2):
         super(ErnieEncoder, self).__init__()
         self.ernie = ernie  # allow ernie to be config
         self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
         self.classifier = nn.Linear(self.ernie.config["hidden_size"],
                                     num_classes)
+        # Compatible with ERNIE-Search, which adds an extra linear layer
+        self.output_emb_size = output_emb_size
+        if output_emb_size is not None and output_emb_size > 0:
+            weight_attr = paddle.ParamAttr(
+                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
+            self.emb_reduce_linear = paddle.nn.Linear(
+                self.ernie.config["hidden_size"],
+                output_emb_size,
+                weight_attr=weight_attr)
         self.apply(self.init_weights)

     def init_weights(self, layer):
@@ -79,21 +92,23 @@ def __init__(self,
                  query_model_name_or_path=None,
                  title_model_name_or_path=None,
                  share_parameters=False,
+                 output_emb_size=None,
                  dropout=None,
                  reinitialize=False,
                  use_cross_batch=False):

         super().__init__()
         self.query_ernie, self.title_ernie = None, None
         self.use_cross_batch = use_cross_batch
+        self.output_emb_size = output_emb_size
         if query_model_name_or_path is not None:
             self.query_ernie = ErnieEncoder.from_pretrained(
-                query_model_name_or_path)
+                query_model_name_or_path, output_emb_size=output_emb_size)
         if share_parameters:
             self.title_ernie = self.query_ernie
         elif title_model_name_or_path is not None:
             self.title_ernie = ErnieEncoder.from_pretrained(
-                title_model_name_or_path)
+                title_model_name_or_path, output_emb_size=output_emb_size)
         assert (self.query_ernie is not None) or (self.title_ernie is not None), \
             "At least one of query_ernie and title_ernie should not be None"

@@ -125,16 +140,27 @@ def get_pooled_embedding(self,
                              position_ids=None,
                              attention_mask=None,
                              is_query=True):
+        """Get the first feature of each sequence for classification"""
         assert (is_query and self.query_ernie is not None) or (not is_query and self.title_ernie), \
             "Please check whether your parameter for `is_query` is consistent with the DualEncoder initialization."
         if is_query:
             sequence_output, _ = self.query_ernie(input_ids, token_type_ids,
                                                   position_ids, attention_mask)
+            if self.output_emb_size is not None and self.output_emb_size > 0:
+                cls_embedding = self.query_ernie.emb_reduce_linear(
+                    sequence_output[:, 0])
+            else:
+                cls_embedding = sequence_output[:, 0]

         else:
             sequence_output, _ = self.title_ernie(input_ids, token_type_ids,
                                                   position_ids, attention_mask)
-        return sequence_output[:, 0]
+            if self.output_emb_size is not None and self.output_emb_size > 0:
+                cls_embedding = self.title_ernie.emb_reduce_linear(
+                    sequence_output[:, 0])
+            else:
+                cls_embedding = sequence_output[:, 0]
+        return cls_embedding

     def cosine_sim(self,
                    query_input_ids,
@@ -272,6 +298,7 @@ def matching(self,
                  position_ids=None,
                  attention_mask=None,
                  return_prob_distributation=False):
+        """Use the pooled_output as the feature for pointwise prediction, e.g. RocketQAv1"""
         _, pooled_output = self.ernie(input_ids,
                                       token_type_ids=token_type_ids,
                                       position_ids=position_ids,
@@ -288,6 +315,7 @@ def matching_v2(self,
                     token_type_ids=None,
                     position_ids=None,
                     attention_mask=None):
+        """Use the cls token embedding as the feature for listwise prediction, e.g. RocketQAv2"""
         sequence_output, _ = self.ernie(input_ids,
                                         token_type_ids=token_type_ids,
                                         position_ids=position_ids,
@@ -296,6 +324,21 @@ def matching_v2(self,
         probs = self.ernie.classifier(pooled_output)
         return probs

+    def matching_v3(self,
+                    input_ids,
+                    token_type_ids=None,
+                    position_ids=None,
+                    attention_mask=None):
+        """Use the pooled_output as the feature for listwise prediction, e.g. ERNIE-Search"""
+        sequence_output, pooled_output = self.ernie(
+            input_ids,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask)
+        pooled_output = self.ernie.dropout(pooled_output)
+        probs = self.ernie.classifier(pooled_output)
+        return probs
+
     def forward(self,
                 input_ids,
                 token_type_ids=None,
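Putting the pieces together: `output_emb_size` threads from `ErnieDualEncoder` into each `ErnieEncoder` tower, and `get_pooled_embedding` routes the [CLS] vector through `emb_reduce_linear` when the projection is enabled. A sketch of that path (the input plumbing below is illustrative, not from this commit):

```python
import paddle
from paddlenlp.transformers import ErnieTokenizer
from paddlenlp.transformers.semantic_search.modeling import ErnieDualEncoder

# Shared-parameter dual encoder with the extra projection layer enabled.
model = ErnieDualEncoder(
    query_model_name_or_path="ernie-search-base-dual-encoder-marco-en",
    share_parameters=True,
    output_emb_size=768)

tokenizer = ErnieTokenizer.from_pretrained(
    "ernie-search-base-dual-encoder-marco-en")
encoded = tokenizer("how long is the great wall of china?")
input_ids = paddle.to_tensor([encoded["input_ids"]])
token_type_ids = paddle.to_tensor([encoded["token_type_ids"]])

# The (optionally projected) [CLS] embedding from the query tower.
query_emb = model.get_pooled_embedding(input_ids, token_type_ids,
                                       is_query=True)
```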

pipelines/API.md

Lines changed: 2 additions & 0 deletions
@@ -14,6 +14,7 @@
 | rocketqa-zh-micro-query-encoder | Chinese | 4-layer, 384-hidden, 12-heads, 23M parameters. Trained on DuReader retrieval text. |
 | rocketqa-zh-nano-query-encoder | Chinese | 4-layer, 312-hidden, 12-heads, 18M parameters. Trained on DuReader retrieval text. |
 | rocketqav2-en-marco-query-encoder | English | 12-layer, 768-hidden, 12-heads, 118M parameters. Trained on MSMARCO. |
+| ernie-search-base-dual-encoder-marco-en | English | 12-layer, 768-hidden, 12-heads, 118M parameters. Trained on MSMARCO. |

## ErnieRanker

@@ -27,6 +28,7 @@
 | rocketqa-micro-cross-encoder | Chinese | 4-layer, 384-hidden, 12-heads, 23M parameters. Trained on DuReader retrieval text. |
 | rocketqa-nano-cross-encoder | Chinese | 4-layer, 312-hidden, 12-heads, 18M parameters. Trained on DuReader retrieval text. |
 | rocketqav2-en-marco-cross-encoder | English | 12-layer, 768-hidden, 12-heads, 118M parameters. Trained on MSMARCO. |
+| ernie-search-large-cross-encoder-marco-en | English | 24-layer, 1024-hidden, 16-heads. Trained on MSMARCO. |

## ErnieReader

pipelines/pipelines/nodes/retriever/dense.py

Lines changed: 13 additions & 4 deletions
@@ -49,7 +49,9 @@ def __init__(
             Path, str] = "rocketqa-zh-dureader-para-encoder",
         params_path: Optional[str] = "",
         model_version: Optional[str] = None,
-        output_emb_size=256,
+        output_emb_size: Optional[int] = None,
+        reinitialize: bool = False,
+        share_parameters: bool = False,
         max_seq_len_query: int = 64,
         max_seq_len_passage: int = 256,
         top_k: int = 10,
@@ -98,7 +100,7 @@ def __init__(
         :param progress_bar: Whether to show a tqdm progress bar or not.
                              Can be helpful to disable in production deployments to keep the logs clean.
         """
-        # save init parameters to enable export of component config as YAML
+        # Save init parameters to enable export of component config as YAML
         self.set_config(
             document_store=document_store,
             query_embedding_model=query_embedding_model,
@@ -110,6 +112,9 @@ def __init__(
             use_gpu=use_gpu,
             batch_size=batch_size,
             embed_title=embed_title,
+            reinitialize=reinitialize,
+            share_parameters=share_parameters,
+            output_emb_size=output_emb_size,
             similarity_function=similarity_function,
             progress_bar=progress_bar,
         )
@@ -150,8 +155,12 @@ def __init__(
             self.passage_tokenizer = AutoTokenizer.from_pretrained(
                 query_embedding_model)
         else:
-            self.ernie_dual_encoder = ErnieDualEncoder(query_embedding_model,
-                                                       passage_embedding_model)
+            self.ernie_dual_encoder = ErnieDualEncoder(
+                query_embedding_model,
+                passage_embedding_model,
+                output_emb_size=output_emb_size,
+                reinitialize=reinitialize,
+                share_parameters=share_parameters)
             self.query_tokenizer = AutoTokenizer.from_pretrained(
                 query_embedding_model)
             self.passage_tokenizer = AutoTokenizer.from_pretrained(
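Downstream, the retriever simply forwards the new knobs into `ErnieDualEncoder`. A minimal sketch, assuming a local Elasticsearch instance and an `msmarco` index as in the YAML file further below:

```python
from pipelines.document_stores import ElasticsearchDocumentStore
from pipelines.nodes import DensePassageRetriever

document_store = ElasticsearchDocumentStore(host="localhost",
                                            port=9200,
                                            index="msmarco",
                                            embedding_dim=768)

retriever = DensePassageRetriever(
    document_store=document_store,
    query_embedding_model="ernie-search-base-dual-encoder-marco-en",
    share_parameters=True,  # one tower encodes both queries and passages
    output_emb_size=768,
    embed_title=False)

docs = retriever.retrieve("what is the capital of france?", top_k=10)
```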

pipelines/rest_api/config.py

Lines changed: 1 addition & 0 deletions
@@ -19,6 +19,7 @@
     "PIPELINE_YAML_PATH",
     str((Path(__file__).parent / "pipeline" / "pipelines.yaml").absolute()))
 QUERY_PIPELINE_NAME = os.getenv("QUERY_PIPELINE_NAME", "query")
+QUERY_QA_PAIRS_NAME = os.getenv('QUERY_QA_PAIRS_NAME', 'query_qa_pairs')
 INDEXING_PIPELINE_NAME = os.getenv("INDEXING_PIPELINE_NAME", "indexing")
 INDEXING_QA_GENERATING_PIPELINE_NAME = os.getenv(
     "INDEXING_QA_GENERATING_PIPELINE_NAME", "indexing_qa_generating")

pipelines/rest_api/controller/search.py

Lines changed: 6 additions & 3 deletions
@@ -25,7 +25,7 @@

 import pipelines
 from pipelines.pipelines.base import Pipeline
-from rest_api.config import PIPELINE_YAML_PATH, QUERY_PIPELINE_NAME
+from rest_api.config import PIPELINE_YAML_PATH, QUERY_PIPELINE_NAME, QUERY_QA_PAIRS_NAME
 from rest_api.config import LOG_LEVEL, CONCURRENT_REQUEST_PER_WORKER
 from rest_api.schema import QueryRequest, QueryResponse, DocumentRequest, DocumentResponse, QueryImageResponse, QueryQAPairResponse, QueryQAPairRequest
 from rest_api.controller.utils import RequestLimiter
@@ -42,8 +42,11 @@
 PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH),
                                    pipeline_name=QUERY_PIPELINE_NAME)

-QA_PAIR_PIPELINE = Pipeline.load_from_yaml(Path(PIPELINE_YAML_PATH),
-                                           pipeline_name="query_qa_pairs")
+try:
+    QA_PAIR_PIPELINE = Pipeline.load_from_yaml(
+        Path(PIPELINE_YAML_PATH), pipeline_name=QUERY_QA_PAIRS_NAME)
+except Exception:
+    logger.warning(f"Pipeline '{QUERY_QA_PAIRS_NAME}' is not defined in the YAML; skipping it.")
 DOCUMENT_STORE = PIPELINE.get_document_store()
 logging.info(f"Loaded pipeline nodes: {PIPELINE.graph.nodes.keys()}")

Lines changed: 69 additions & 0 deletions
@@ -0,0 +1,69 @@
+version: '1.1.0'
+
+components:    # define all the building blocks for the Pipeline
+  - name: DocumentStore
+    type: ElasticsearchDocumentStore    # consider Milvus2DocumentStore or WeaviateDocumentStore for scaling to a large number of documents
+    params:
+      host: localhost
+      port: 9200
+      index: msmarco
+      embedding_dim: 768
+  - name: Retriever
+    type: DensePassageRetriever
+    params:
+      document_store: DocumentStore    # params can reference other components defined in the YAML
+      top_k: 10
+      query_embedding_model: ernie-search-base-dual-encoder-marco-en    # an example of using the ERNIE-Search models
+      share_parameters: True
+      output_emb_size: 768
+      embed_title: False
+  - name: Ranker    # custom name for the component; helpful for visualization & debugging
+    type: ErnieRanker    # pipelines class name for the component
+    params:
+      model_name_or_path: rocketqav2-en-marco-cross-encoder
+      top_k: 3
+      use_en: True
+      reinitialize: True
+  - name: TextFileConverter
+    type: TextConverter
+  - name: ImageFileConverter
+    type: ImageToTextConverter
+  - name: PDFFileConverter
+    type: PDFToTextConverter
+  - name: DocxFileConverter
+    type: DocxToTextConverter
+  - name: Preprocessor
+    type: PreProcessor
+    params:
+      split_by: word
+      split_length: 1000
+  - name: FileTypeClassifier
+    type: FileTypeClassifier
+
+pipelines:
+  - name: query
+    type: Query
+    nodes:
+      - name: Retriever
+        inputs: [Query]
+      - name: Ranker
+        inputs: [Retriever]
+  - name: indexing
+    type: Indexing
+    nodes:
+      - name: FileTypeClassifier
+        inputs: [File]
+      - name: TextFileConverter
+        inputs: [FileTypeClassifier.output_1]
+      - name: PDFFileConverter
+        inputs: [FileTypeClassifier.output_2]
+      - name: DocxFileConverter
+        inputs: [FileTypeClassifier.output_4]
+      - name: ImageFileConverter
+        inputs: [FileTypeClassifier.output_6]
+      - name: Preprocessor
+        inputs: [PDFFileConverter, TextFileConverter, DocxFileConverter, ImageFileConverter]
+      - name: Retriever
+        inputs: [Preprocessor]
+      - name: DocumentStore
+        inputs: [Retriever]
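The same YAML can also be exercised outside the REST API. A hedged sketch (the file name below is an assumption, since this diff view did not capture the new file's path):

```python
from pathlib import Path
from pipelines.pipelines.base import Pipeline

# Load the "query" pipeline defined above: Retriever -> Ranker.
pipeline = Pipeline.load_from_yaml(Path("semantic_search_msmarco.yaml"),
                                   pipeline_name="query")
prediction = pipeline.run(query="who proposed dense passage retrieval?",
                          params={"Retriever": {"top_k": 10},
                                  "Ranker": {"top_k": 3}})
```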

pipelines/utils/offline_ann.py

Lines changed: 8 additions & 1 deletion
@@ -82,7 +82,13 @@
 parser.add_argument(
     '--delete_index',
     action='store_true',
-    help='whether to delete existing index while updating index')
+    help='Whether to delete the existing index when updating it')
+
+parser.add_argument(
+    '--share_parameters',
+    action='store_true',
+    help='Whether the query and title models share the same parameters'
+)

 args = parser.parse_args()

@@ -126,6 +132,7 @@ def offline_ann(index_name, doc_dir):
         passage_embedding_model=args.passage_embedding_model,
         params_path=args.params_path,
         output_emb_size=args.embedding_dim,
+        share_parameters=args.share_parameters,
         max_seq_len_query=64,
         max_seq_len_passage=256,
         batch_size=16,
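End to end, indexing with the shared-parameter ERNIE-Search encoder might be driven like this; the flag values are illustrative and the corpus directory is hypothetical, with flag names taken from the `args.*` usages in the diff above:

```python
import subprocess

subprocess.run([
    "python", "utils/offline_ann.py",
    "--index_name", "msmarco",
    "--doc_dir", "data/msmarco",  # hypothetical corpus directory
    "--query_embedding_model", "ernie-search-base-dual-encoder-marco-en",
    "--embedding_dim", "768",
    "--share_parameters",
    "--delete_index",
], check=True)
```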
