|
19 | 19 | from .. import PretrainedModel, register_base_model
|
20 | 20 |
|
# Public API of this module.
__all__ = [
    'ErnieModel',
    'ErniePretrainedModel',
    'ErnieForSequenceClassification',
    'ErnieForTokenClassification',
    'ErnieForQuestionAnswering',
    'ErnieForPretraining',
    'ErniePretrainingCriterion',
    'ErnieForMaskedLM',
    'ErnieForMultipleChoice',
]
|
31 | 27 |
|
32 | 28 |
|
@@ -859,3 +855,80 @@ def forward(self,
|
859 | 855 | sequence_output = outputs[0]
|
860 | 856 | prediction_scores = self.cls(sequence_output, masked_positions=None)
|
861 | 857 | return prediction_scores
|
| 858 | + |
| 859 | + |
class ErnieForMultipleChoice(ErniePretrainedModel):
    """
    Ernie Model with a linear layer on top of the hidden-states output layer,
    designed for multiple choice tasks like RocStories/SWAG tasks.

    Args:
        ernie (:class:`ErnieModel`):
            An instance of ErnieModel.
        num_choices (int, optional):
            The number of choices. Defaults to `2`.
        dropout (float, optional):
            The dropout probability for output of Ernie.
            If None, use the same value as `hidden_dropout_prob` of `ErnieModel`
            instance `ernie`. Defaults to None.
    """

    def __init__(self, ernie, num_choices=2, dropout=None):
        super(ErnieForMultipleChoice, self).__init__()
        self.num_choices = num_choices
        self.ernie = ernie
        # Fall back to the backbone's hidden dropout probability when no
        # explicit dropout rate is given.
        drop_prob = (dropout if dropout is not None else
                     self.ernie.config["hidden_dropout_prob"])
        self.dropout = nn.Dropout(drop_prob)
        # One score per (example, choice) pair; choices are ranked against
        # each other after reshaping in `forward`.
        self.classifier = nn.Linear(self.ernie.config["hidden_size"], 1)
        self.apply(self.init_weights)

    def forward(self,
                input_ids,
                token_type_ids=None,
                position_ids=None,
                attention_mask=None):
        r"""
        The ErnieForMultipleChoice forward method, overrides the __call__() special method.

        Args:
            input_ids (Tensor):
                See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
            token_type_ids(Tensor, optional):
                See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
            position_ids(Tensor, optional):
                See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].
            attention_mask (list, optional):
                See :class:`ErnieModel` and shape as [batch_size, num_choice, sequence_length].

        Returns:
            Tensor: Returns tensor `reshaped_logits`, a tensor of the multiple choice classification logits.
            Shape as `[batch_size, num_choice]` and dtype as `float32`.

        """
        # Collapse the choice axis so each (example, choice) pair becomes one
        # row the backbone can encode:
        # [batch_size, num_choice, seq_len] -> [batch_size*num_choice, seq_len]
        input_ids = input_ids.reshape(shape=(-1, input_ids.shape[-1]))

        if token_type_ids is not None:
            token_type_ids = token_type_ids.reshape(
                shape=(-1, token_type_ids.shape[-1]))
        if position_ids is not None:
            position_ids = position_ids.reshape(
                shape=(-1, position_ids.shape[-1]))
        if attention_mask is not None:
            attention_mask = attention_mask.reshape(
                shape=(-1, attention_mask.shape[-1]))

        _, pooled_output = self.ernie(
            input_ids,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            attention_mask=attention_mask)

        # Score each flattened choice: [batch_size*num_choice, 1], then fold
        # the choice axis back out: [batch_size, num_choice].
        choice_scores = self.classifier(self.dropout(pooled_output))
        reshaped_logits = choice_scores.reshape(shape=(-1, self.num_choices))

        return reshaped_logits
0 commit comments