Commit e1014ee

update layoutxlm/modeling.py (#2285)
Co-authored-by: yingyibiao <[email protected]>
1 parent 1a8c099 commit e1014ee

1 file changed, +101 −8 lines:

paddlenlp/transformers/layoutxlm/modeling.py

@@ -29,8 +29,8 @@
 
 __all__ = [
     'LayoutXLMModel', "LayoutXLMPretrainedModel",
-    "LayoutXLMForTokenClassification", "LayoutXLMForPretraining",
-    "LayoutXLMForRelationExtraction"
+    "LayoutXLMForTokenClassification", "LayoutXLMForSequenceClassification",
+    "LayoutXLMForPretraining", "LayoutXLMForRelationExtraction"
 ]
 
 
@@ -63,6 +63,34 @@ def relative_position_bucket(relative_position,
     return ret
 
 
+def token_featue_to_sequence_feature(input_ids, seq_length, sequence_output):
+    """
+    used to transform token feature into sequence feature by
+    averaging all the token features of certain sequence
+    """
+    batches = input_ids.shape[0]
+    for batch_id in range(batches):
+        start_idx = -1
+        for i in range(0, seq_length):
+            if input_ids[batch_id, i] == 6:
+                if start_idx > -1:
+                    feature_block = sequence_output[batch_id, start_idx + 1:i]
+                    sequence_output[batch_id, start_idx] = paddle.mean(
+                        feature_block, axis=0)
+                start_idx = i
+
+            if input_ids[batch_id, i] == 1:
+                feature_block = sequence_output[batch_id, start_idx + 1:i]
+                sequence_output[batch_id, start_idx] = paddle.mean(
+                    feature_block, axis=0)
+                break
+
+            if i == seq_length - 1:
+                sequence_output[batch_id, start_idx] = paddle.mean(
+                    feature_block, axis=0)
+    return
+
+
 class LayoutXLMPooler(Layer):
     def __init__(self, hidden_size, with_pool):
         super(LayoutXLMPooler, self).__init__()
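
Note: the helper added above pools token features in place and returns nothing. It scans each row of `input_ids` for delimiter ids (`6` appears to open a text segment and `1` to end the sequence; the exact meaning of those ids is an assumption here) and overwrites each delimiter position with the mean of the features that follow it, up to the next delimiter. A minimal sketch of that behavior, assuming the function above is in scope, with toy ids and random features:

```python
import paddle

# Toy batch: one sequence of length 8 with hidden size 4.
# Positions 0 and 3 hold the assumed segment marker (id 6);
# position 6 holds the assumed end-of-sequence marker (id 1).
input_ids = paddle.to_tensor([[6, 11, 12, 6, 13, 14, 1, 0]])
sequence_output = paddle.rand([1, 8, 4])

token_featue_to_sequence_feature(input_ids, 8, sequence_output)

# sequence_output[0, 0] now holds the mean of positions 1-2 and
# sequence_output[0, 3] the mean of positions 4-5; the remaining
# positions are left untouched.
```
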
@@ -911,6 +939,73 @@ def forward(
         return outputs
 
 
+class LayoutXLMForSequenceClassification(LayoutXLMPretrainedModel):
+    def __init__(self, layoutxlm, num_classes=2, dropout=None):
+        super(LayoutXLMForSequenceClassification, self).__init__()
+        self.num_classes = num_classes
+        if isinstance(layoutxlm, dict):
+            self.layoutxlm = LayoutXLMModel(**layoutxlm)
+        else:
+            self.layoutxlm = layoutxlm
+        self.dropout = nn.Dropout(dropout if dropout is not None else
+                                  self.layoutxlm.config["hidden_dropout_prob"])
+        self.classifier = nn.Linear(self.layoutxlm.config["hidden_size"],
+                                    num_classes)
+        self.classifier.apply(self.init_weights)
+
+    def get_input_embeddings(self):
+        return self.layoutxlm.embeddings.word_embeddings
+
+    def forward(
+            self,
+            input_ids=None,
+            bbox=None,
+            image=None,
+            attention_mask=None,
+            token_type_ids=None,
+            position_ids=None,
+            head_mask=None,
+            labels=None, ):
+        outputs = self.layoutxlm(
+            input_ids=input_ids,
+            bbox=bbox,
+            image=image,
+            attention_mask=attention_mask,
+            token_type_ids=token_type_ids,
+            position_ids=position_ids,
+            head_mask=head_mask, )
+        seq_length = input_ids.shape[1]
+        # sequence out and image out
+        sequence_output, image_output = outputs[0][:, :seq_length], outputs[
+            0][:, seq_length:]
+
+        # token feature to sequence feature
+        token_featue_to_sequence_feature(input_ids, seq_length,
+                                         sequence_output)
+
+        sequence_output = self.dropout(sequence_output)
+        logits = self.classifier(sequence_output)
+
+        outputs = logits,
+
+        if labels is not None:
+            loss_fct = nn.CrossEntropyLoss()
+
+            if attention_mask is not None:
+                active_loss = attention_mask.reshape([-1, ]) == 1
+                active_logits = logits.reshape(
+                    [-1, self.num_classes])[active_loss]
+                active_labels = labels.reshape([-1, ])[active_loss]
+                loss = loss_fct(active_logits, active_labels)
+            else:
+                loss = loss_fct(
+                    logits.reshape([-1, self.num_classes]),
+                    labels.reshape([-1, ]))
+
+            outputs = (loss, ) + outputs
+
+        return outputs
+
+
 class LayoutXLMPredictionHead(Layer):
     """
     Bert Model with a `language modeling` head on top for CLM fine-tuning.
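
The new head reuses the backbone output: it splits off the image-patch features, segment-pools the token features with `token_featue_to_sequence_feature`, then classifies every position and masks the loss with `attention_mask`. A hedged usage sketch; the checkpoint name `layoutxlm-base-uncased` and all input shapes are assumptions, not part of this diff:

```python
import paddle
from paddlenlp.transformers import LayoutXLMForSequenceClassification

# Assumed checkpoint name; num_classes defaults to 2 as in __init__ above.
model = LayoutXLMForSequenceClassification.from_pretrained(
    "layoutxlm-base-uncased", num_classes=2)
model.eval()

# Dummy inputs. Ids 6 and 1 are placed explicitly because the pooling
# helper expects those delimiters; bbox carries one 0-1000 box per token
# and image a 3x224x224 page crop (assumed visual-backbone input size).
seq_len = 512
ids = [6] + [100] * 10 + [6] + [200] * 10 + [1] + [0] * (seq_len - 23)
input_ids = paddle.to_tensor([ids])
bbox = paddle.zeros([1, seq_len, 4], dtype="int64")
image = paddle.rand([1, 3, 224, 224])
attention_mask = (input_ids != 0).astype("int64")
labels = paddle.zeros([1, seq_len], dtype="int64")

with paddle.no_grad():
    # With labels the output tuple is (loss, logits); without labels
    # it is just (logits,).
    loss, logits = model(
        input_ids=input_ids,
        bbox=bbox,
        image=image,
        attention_mask=attention_mask,
        labels=labels)
print(float(loss), logits.shape)  # logits: [1, 512, num_classes]
```

Despite the name, the loss is computed per token (labels of shape `[batch, seq_len]`), so callers that want one label per text segment presumably read the logits at the pooled delimiter positions.
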
@@ -1036,12 +1131,10 @@ def build_relation(self, relations, entities):
         for b in range(batch_size):
             if len(entities[b]["start"]) <= 2:
                 entities[b] = {"end": [1, 1], "label": [0, 0], "start": [0, 0]}
-            all_possible_relations = set([
-                (i, j)
-                for i in range(len(entities[b]["label"]))
-                for j in range(len(entities[b]["label"]))
-                if entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2
-            ])
+            all_possible_relations = set(
+                [(i, j) for i in range(len(entities[b]["label"]))
+                 for j in range(len(entities[b]["label"])) if
+                 entities[b]["label"][i] == 1 and entities[b]["label"][j] == 2])
             if len(all_possible_relations) == 0:
                 all_possible_relations = {(0, 1)}
             positive_relations = set(
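
The restyled comprehension above is behavior-preserving: it enumerates every (head, tail) index pair where the head entity has label 1 and the tail entity label 2 (assumed to be the two roles that can be linked, e.g. question/answer in form parsing). A toy illustration of just that set construction, with hypothetical entity spans:

```python
# Hypothetical entities for one batch element: label 1 marks link heads,
# label 2 marks link tails; 0 is anything else.
entities_b = {"start": [0, 4, 9], "end": [3, 8, 12], "label": [1, 2, 2]}

all_possible_relations = set(
    [(i, j) for i in range(len(entities_b["label"]))
     for j in range(len(entities_b["label"]))
     if entities_b["label"][i] == 1 and entities_b["label"][j] == 2])

print(sorted(all_possible_relations))  # [(0, 1), (0, 2)]
```
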
