@@ -23,12 +23,25 @@

 class ErnieEncoder(ErniePretrainedModel):

-    def __init__(self, ernie, dropout=None, num_classes=2):
+    def __init__(self,
+                 ernie,
+                 dropout=None,
+                 output_emb_size=None,
+                 num_classes=2):
         super(ErnieEncoder, self).__init__()
         self.ernie = ernie  # allow ernie to be config
         self.dropout = nn.Dropout(dropout if dropout is not None else 0.1)
         self.classifier = nn.Linear(self.ernie.config["hidden_size"],
                                     num_classes)
+        # Compatible with ERNIE-Search: optionally add an extra linear layer
+        self.output_emb_size = output_emb_size
+        if output_emb_size is not None and output_emb_size > 0:
+            weight_attr = paddle.ParamAttr(
+                initializer=paddle.nn.initializer.TruncatedNormal(std=0.02))
+            self.emb_reduce_linear = paddle.nn.Linear(
+                self.ernie.config["hidden_size"],
+                output_emb_size,
+                weight_attr=weight_attr)
         self.apply(self.init_weights)

     def init_weights(self, layer):
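For context, a minimal sketch of what the new `output_emb_size` path does when it is enabled. This is not part of the diff: the "ernie-1.0" checkpoint name is an assumption, and the forward pass goes through the wrapped backbone directly since `ErnieEncoder` defines no `forward` here.

import paddle
from paddlenlp.transformers import ErnieModel

# Assumption: "ernie-1.0" is a stand-in checkpoint name.
ernie = ErnieModel.from_pretrained("ernie-1.0")
encoder = ErnieEncoder(ernie, dropout=0.1, output_emb_size=256)

input_ids = paddle.to_tensor([[1, 647, 986, 2]])    # [batch, seq_len]
sequence_output, _ = encoder.ernie(input_ids)       # [batch, seq_len, hidden_size]
cls_embedding = encoder.emb_reduce_linear(sequence_output[:, 0])  # [batch, 256]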
@@ -79,21 +92,23 @@ def __init__(self,
                  query_model_name_or_path=None,
                  title_model_name_or_path=None,
                  share_parameters=False,
+                 output_emb_size=None,
                  dropout=None,
                  reinitialize=False,
                  use_cross_batch=False):

         super().__init__()
         self.query_ernie, self.title_ernie = None, None
         self.use_cross_batch = use_cross_batch
+        self.output_emb_size = output_emb_size
         if query_model_name_or_path is not None:
             self.query_ernie = ErnieEncoder.from_pretrained(
-                query_model_name_or_path)
+                query_model_name_or_path, output_emb_size=output_emb_size)
         if share_parameters:
             self.title_ernie = self.query_ernie
         elif title_model_name_or_path is not None:
             self.title_ernie = ErnieEncoder.from_pretrained(
-                title_model_name_or_path)
+                title_model_name_or_path, output_emb_size=output_emb_size)
         assert (self.query_ernie is not None) or (self.title_ernie is not None), \
             "At least one of query_ernie and title_ernie should not be None"
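With this change, `output_emb_size` is threaded into both towers at construction time. A hypothetical construction sketch; the enclosing class is called `DualEncoder` below only because the assert message names it that way, and the checkpoint name is illustrative.

# Hypothetical: a parameter-sharing dual encoder whose single shared tower
# carries the 256-d reduction layer added in this diff.
dual_encoder = DualEncoder(
    query_model_name_or_path="ernie-1.0",  # assumed checkpoint name
    share_parameters=True,                 # title tower aliases the query tower
    output_emb_size=256,
    dropout=0.1)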
@@ -125,16 +140,27 @@ def get_pooled_embedding(self,
                              position_ids=None,
                              attention_mask=None,
                              is_query=True):
+        """Get the first feature of each sequence for classification."""
         assert (is_query and self.query_ernie is not None) or (not is_query and self.title_ernie), \
             "Please check whether your parameter for `is_query` is consistent with DualEncoder initialization."
         if is_query:
             sequence_output, _ = self.query_ernie(input_ids, token_type_ids,
                                                   position_ids, attention_mask)
+            if self.output_emb_size is not None and self.output_emb_size > 0:
+                cls_embedding = self.query_ernie.emb_reduce_linear(
+                    sequence_output[:, 0])
+            else:
+                cls_embedding = sequence_output[:, 0]

         else:
             sequence_output, _ = self.title_ernie(input_ids, token_type_ids,
                                                   position_ids, attention_mask)
-        return sequence_output[:, 0]
+            if self.output_emb_size is not None and self.output_emb_size > 0:
+                cls_embedding = self.title_ernie.emb_reduce_linear(
+                    sequence_output[:, 0])
+            else:
+                cls_embedding = sequence_output[:, 0]
+        return cls_embedding

     def cosine_sim(self,
                    query_input_ids,
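To see the reduction branch end-to-end, a hedged usage sketch; the tokenizer and checkpoint names are assumptions, and `dual_encoder` is the instance from the earlier sketch.

import paddle
from paddlenlp.transformers import ErnieTokenizer

tokenizer = ErnieTokenizer.from_pretrained("ernie-1.0")  # assumed checkpoint
encoded = tokenizer("what is paddle?")
input_ids = paddle.to_tensor([encoded["input_ids"]])
token_type_ids = paddle.to_tensor([encoded["token_type_ids"]])

# With output_emb_size=256 this is [batch, 256]; with output_emb_size=None it
# falls through to the raw [CLS] vector, i.e. [batch, hidden_size].
query_emb = dual_encoder.get_pooled_embedding(input_ids,
                                              token_type_ids,
                                              is_query=True)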
@@ -272,6 +298,7 @@ def matching(self,
                  position_ids=None,
                  attention_mask=None,
                  return_prob_distributation=False):
+        """Use the pooled_output as the feature for pointwise prediction, e.g. RocketQAv1."""
         _, pooled_output = self.ernie(input_ids,
                                       token_type_ids=token_type_ids,
                                       position_ids=position_ids,
@@ -288,6 +315,7 @@ def matching_v2(self,
                     token_type_ids=None,
                     position_ids=None,
                     attention_mask=None):
+        """Use the [CLS] token embedding as the feature for listwise prediction, e.g. RocketQAv2."""
         sequence_output, _ = self.ernie(input_ids,
                                         token_type_ids=token_type_ids,
                                         position_ids=position_ids,
@@ -296,6 +324,21 @@ def matching_v2(self,
         probs = self.ernie.classifier(pooled_output)
         return probs

+    def matching_v3(self,
+                    input_ids,
+                    token_type_ids=None,
+                    position_ids=None,
+                    attention_mask=None):
+        """Use the pooled_output as the feature for listwise prediction, e.g. ERNIE-Search."""
+        sequence_output, pooled_output = self.ernie(
+            input_ids,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            attention_mask=attention_mask)
+        pooled_output = self.ernie.dropout(pooled_output)
+        probs = self.ernie.classifier(pooled_output)
+        return probs
+
     def forward(self,
                 input_ids,
                 token_type_ids=None,
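Taken together, the three heads differ only in which feature feeds the classifier. A hedged comparison on a hypothetical cross-encoder instance; `cross_encoder` and the query/passage pair-encoding call pattern are assumptions, reusing the tokenizer from the earlier sketch.

encoded = tokenizer("what is paddle?",
                    text_pair="PaddlePaddle is a deep learning framework")
input_ids = paddle.to_tensor([encoded["input_ids"]])
token_type_ids = paddle.to_tensor([encoded["token_type_ids"]])

probs_v1 = cross_encoder.matching(input_ids, token_type_ids=token_type_ids)     # pooled_output, pointwise (RocketQAv1)
probs_v2 = cross_encoder.matching_v2(input_ids, token_type_ids=token_type_ids)  # [CLS] embedding, listwise (RocketQAv2)
probs_v3 = cross_encoder.matching_v3(input_ids, token_type_ids=token_type_ids)  # pooled_output, listwise (ERNIE-Search)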