Skip to content

Commit 0334963

Browse files
w5688414ZeyuChentianxin
authored
Add ErnieForMultipleChoice (#1597)
* add recall inference similarity * update examples * updatea readme * update dir name * update neural search readme * update milvus readme * update domain adaptive pretraining readme * fix the mistakes * update readme * add recall Paddle Serving Support * update readme * update readme and format the code * reformat the files * move the files * reformat the code * remove redundant code Co-authored-by: Zeyu Chen <[email protected]> Co-authored-by: tianxin <[email protected]>
1 parent 22cee26 commit 0334963

File tree

1 file changed

+81
-8
lines changed

1 file changed

+81
-8
lines changed

paddlenlp/transformers/ernie/modeling.py

Lines changed: 81 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,10 @@
1919
from .. import PretrainedModel, register_base_model
2020

2121
__all__ = [
22-
'ErnieModel',
23-
'ErniePretrainedModel',
24-
'ErnieForSequenceClassification',
25-
'ErnieForTokenClassification',
26-
'ErnieForQuestionAnswering',
27-
'ErnieForPretraining',
28-
'ErniePretrainingCriterion',
29-
'ErnieForMaskedLM',
22+
'ErnieModel', 'ErniePretrainedModel', 'ErnieForSequenceClassification',
23+
'ErnieForTokenClassification', 'ErnieForQuestionAnswering',
24+
'ErnieForPretraining', 'ErniePretrainingCriterion', 'ErnieForMaskedLM',
25+
'ErnieForMultipleChoice'
3026
]
3127

3228

@@ -859,3 +855,80 @@ def forward(self,
859855
sequence_output = outputs[0]
860856
prediction_scores = self.cls(sequence_output, masked_positions=None)
861857
return prediction_scores
858+
859+
860+
class ErnieForMultipleChoice(ErniePretrainedModel):
861+
"""
862+
Ernie Model with a linear layer on top of the hidden-states output layer,
863+
designed for multiple choice tasks like RocStories/SWAG tasks.
864+
865+
Args:
866+
ernie (:class:`ErnieModel`):
867+
An instance of ErnieModel.
868+
num_choices (int, optional):
869+
The number of choices. Defaults to `2`.
870+
dropout (float, optional):
871+
The dropout probability for output of Ernie.
872+
If None, use the same value as `hidden_dropout_prob` of `ErnieModel`
873+
instance `ernie`. Defaults to None.
874+
"""
875+
876+
def __init__(self, ernie, num_choices=2, dropout=None):
877+
super(ErnieForMultipleChoice, self).__init__()
878+
self.num_choices = num_choices
879+
self.ernie = ernie
880+
self.dropout = nn.Dropout(dropout if dropout is not None else
881+
self.ernie.config["hidden_dropout_prob"])
882+
self.classifier = nn.Linear(self.ernie.config["hidden_size"], 1)
883+
self.apply(self.init_weights)
884+
885+
def forward(self,
886+
input_ids,
887+
token_type_ids=None,
888+
position_ids=None,
889+
attention_mask=None):
890+
r"""
891+
The ErnieForMultipleChoice forward method, overrides the __call__() special method.
892+
893+
Args:
894+
input_ids (Tensor):
895+
See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
896+
token_type_ids(Tensor, optional):
897+
See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
898+
position_ids(Tensor, optional):
899+
See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
900+
attention_mask (list, optional):
901+
See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
902+
903+
Returns:
904+
Tensor: Returns tensor `reshaped_logits`, a tensor of the multiple choice classification logits.
905+
Shape as `[batch_size, num_choice]` and dtype as `float32`.
906+
907+
"""
908+
# input_ids: [bs, num_choice, seq_l]
909+
input_ids = input_ids.reshape(shape=(
910+
-1, input_ids.shape[-1])) # flat_input_ids: [bs*num_choice,seq_l]
911+
912+
if position_ids is not None:
913+
position_ids = position_ids.reshape(shape=(-1,
914+
position_ids.shape[-1]))
915+
if token_type_ids is not None:
916+
token_type_ids = token_type_ids.reshape(shape=(
917+
-1, token_type_ids.shape[-1]))
918+
919+
if attention_mask is not None:
920+
attention_mask = attention_mask.reshape(
921+
shape=(-1, attention_mask.shape[-1]))
922+
923+
_, pooled_output = self.ernie(
924+
input_ids,
925+
token_type_ids=token_type_ids,
926+
position_ids=position_ids,
927+
attention_mask=attention_mask)
928+
pooled_output = self.dropout(pooled_output)
929+
930+
logits = self.classifier(pooled_output) # logits: (bs*num_choice,1)
931+
reshaped_logits = logits.reshape(
932+
shape=(-1, self.num_choices)) # logits: (bs, num_choice)
933+
934+
return reshaped_logits

0 commit comments

Comments
 (0)