 
 __all__ = [
     'LayoutXLMModel', "LayoutXLMPretrainedModel",
-    "LayoutXLMForTokenClassification", "LayoutXLMForPretraining",
-    "LayoutXLMForRelationExtraction"
+    "LayoutXLMForTokenClassification", "LayoutXLMForSequenceClassification",
+    "LayoutXLMForPretraining", "LayoutXLMForRelationExtraction"
 ]
 
 
@@ -63,6 +63,41 @@ def relative_position_bucket(relative_position,
     return ret
 
 
+def token_feature_to_sequence_feature(input_ids, seq_length, sequence_output):
+    """
+    Transform token features into sequence features in place, by averaging
+    all the token features that belong to the same sub-sequence.
+    """
+    batches = input_ids.shape[0]
+    for batch_id in range(batches):
+        start_idx = -1
+        for i in range(0, seq_length):
+            # Token id 6 starts a new sub-sequence: pool the block that just
+            # ended into its leading position, then remember the new start.
+            if input_ids[batch_id, i] == 6:
+                if start_idx > -1:
+                    feature_block = sequence_output[batch_id, start_idx + 1:i]
+                    sequence_output[batch_id, start_idx] = paddle.mean(
+                        feature_block, axis=0)
+                start_idx = i
+
+            # Token id 1 ends the valid input: pool the final block and stop.
+            if input_ids[batch_id, i] == 1:
+                feature_block = sequence_output[batch_id, start_idx + 1:i]
+                sequence_output[batch_id, start_idx] = paddle.mean(
+                    feature_block, axis=0)
+                break
+
+            # No end marker was found: pool the remaining tokens rather than
+            # reusing a stale (or unbound) feature_block.
+            if i == seq_length - 1:
+                feature_block = sequence_output[batch_id, start_idx + 1:
+                                                seq_length]
+                sequence_output[batch_id, start_idx] = paddle.mean(
+                    feature_block, axis=0)
+    return
+
+
 class LayoutXLMPooler(Layer):
     def __init__(self, hidden_size, with_pool):
         super(LayoutXLMPooler, self).__init__()
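A quick sanity check of the intended pooling behavior, for reviewers: the sketch below is illustrative only, with made-up token ids and features, and assumes (as the code above does) that id 6 delimits sub-sequences and id 1 ends the valid input.

```python
import paddle

# Made-up batch: one sequence of length 8. Ids 10-13 stand in for ordinary
# tokens, 6 for the sub-sequence delimiter, 1 for the end marker, 0 for pad.
input_ids = paddle.to_tensor([[6, 10, 11, 6, 12, 13, 1, 0]])
sequence_output = paddle.arange(8 * 4, dtype="float32").reshape([1, 8, 4])

token_feature_to_sequence_feature(input_ids, 8, sequence_output)

# Position 0 now holds the mean of the features at positions 1-2, and
# position 3 the mean of positions 4-5; positions past the end marker
# are left untouched.
print(sequence_output[0, 0], sequence_output[0, 3])
```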
@@ -911,6 +939,73 @@ def forward(
 
         return outputs
 
 
+class LayoutXLMForSequenceClassification(LayoutXLMPretrainedModel):
+    def __init__(self, layoutxlm, num_classes=2, dropout=None):
+        super(LayoutXLMForSequenceClassification, self).__init__()
+        self.num_classes = num_classes
+        if isinstance(layoutxlm, dict):
+            self.layoutxlm = LayoutXLMModel(**layoutxlm)
+        else:
+            self.layoutxlm = layoutxlm
+        self.dropout = nn.Dropout(dropout if dropout is not None else
+                                  self.layoutxlm.config["hidden_dropout_prob"])
+        self.classifier = nn.Linear(self.layoutxlm.config["hidden_size"],
+                                    num_classes)
+        self.classifier.apply(self.init_weights)
+
+    def get_input_embeddings(self):
+        return self.layoutxlm.embeddings.word_embeddings
+
+    def forward(
+            self,
+            input_ids=None,
+            bbox=None,
+            image=None,
+            attention_mask=None,
+            token_type_ids=None,
+            position_ids=None,
+            head_mask=None,
+            labels=None, ):
+        outputs = self.layoutxlm(
+            input_ids=input_ids,
+            bbox=bbox,
+            image=image,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask, )
+        seq_length = input_ids.shape[1]
+        # sequence out and image out
+        sequence_output, image_output = outputs[0][:, :seq_length], outputs[
+            0][:, seq_length:]
+
+        # token feature to sequence feature
+        token_feature_to_sequence_feature(input_ids, seq_length, sequence_output)
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        outputs = (logits, )
+
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+
+            if attention_mask is not None:
+                active_loss = attention_mask.reshape([-1, ]) == 1
+                active_logits = logits.reshape(
+                    [-1, self.num_classes])[active_loss]
+                active_labels = labels.reshape([-1, ])[active_loss]
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(
+                    logits.reshape([-1, self.num_classes]),
+                    labels.reshape([-1, ]))
+
+            outputs = (loss, ) + outputs
+
+        return outputs
+
+
 class LayoutXLMPredictionHead(Layer):
     """
     Bert Model with a `language modeling` head on top for CLM fine-tuning.
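For reviewers who want to exercise the new head end to end, here is a minimal usage sketch. The checkpoint name is illustrative, the inputs are random dummies, and it assumes the usual PaddleNLP `from_pretrained` flow; note the labels are per-token, matching the token-level reshape in the loss above.

```python
import paddle

# Illustrative checkpoint name; substitute whatever weights you actually have.
model = LayoutXLMForSequenceClassification.from_pretrained(
    "layoutxlm-base-uncased", num_classes=2)

batch_size, seq_len = 2, 512
input_ids = paddle.randint(low=0, high=100, shape=[batch_size, seq_len])
bbox = paddle.zeros([batch_size, seq_len, 4], dtype="int64")
image = paddle.rand([batch_size, 3, 224, 224])
attention_mask = paddle.ones([batch_size, seq_len], dtype="int64")
# Per-token labels, matching the reshape([-1, ...]) in the loss above.
labels = paddle.zeros([batch_size, seq_len], dtype="int64")

loss, logits = model(
    input_ids=input_ids,
    bbox=bbox,
    image=image,
    attention_mask=attention_mask,
    labels=labels)
print(loss.item(), logits.shape)  # logits: [batch_size, seq_len, num_classes]
```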
@@ -1036,12 +1131,10 @@ def build_relation(self, relations, entities):
 
         for b in range(batch_size):
             if len(entities[b]["start"]) <= 2:
                 entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
-            all_possible_relations = set([
-                (i, j)
-                for i in range(len(entities[b]["label"]))
-                for j in range(len(entities[b]["label"]))
-                if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
-            ])
+            all_possible_relations = set(
+                [(i, j) for i in range(len(entities[b]["label"]))
+                 for j in range(len(entities[b]["label"])) if
+                 entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2])
             if len(all_possible_relations) == 0:
                 all_possible_relations = {(0, 1)}
             positive_relations = set(
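As context for the reformatted comprehension: it pairs every entity whose label is 1 with every entity whose label is 2, which in this XFUND-style relation-extraction setup appear to be the head and tail roles. A toy run with made-up data:

```python
# Made-up entities for one batch item: labels 1 and 2 mark the two roles
# that build_relation pairs up as candidate (head, tail) relations.
entities_b = {"start": [0, 4, 9], "end": [3, 8, 12], "label": [1, 2, 2]}

all_possible_relations = set(
    [(i, j) for i in range(len(entities_b["label"]))
     for j in range(len(entities_b["label"]))
     if entities_b["label"][i] == 1 and entities_b["label"][j] == 2])

print(sorted(all_possible_relations))  # [(0, 1), (0, 2)]
```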