Skip to content

Commit 864f2f7

Browse files
committed
Create multihead BERT, required by MultiMatch.
1 parent 93457be commit 864f2f7

File tree

4 files changed

+90
-4
lines changed

4 files changed

+90
-4
lines changed

semilearn/datasets/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,10 +137,10 @@ def make_imbalance_data(max_num_labels, num_classes, gamma):
137137

138138

139139
def get_collactor(args, net):
140-
if net == 'bert_base_uncased':
140+
if net in ['bert_base_uncased', 'bert_base_uncased_multihead']:
141141
from semilearn.datasets.collactors import get_bert_base_uncased_collactor
142142
collact_fn = get_bert_base_uncased_collactor(args.max_length)
143-
elif net == 'bert_base_cased':
143+
elif net in ['bert_base_cased', 'bert_base_cased_multihead']:
144144
from semilearn.datasets.collactors import get_bert_base_cased_collactor
145145
collact_fn = get_bert_base_cased_collactor(args.max_length)
146146
elif net == 'wave2vecv2_base':

semilearn/nets/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,6 @@
44
from .resnet import resnet50
55
from .wrn import wrn_28_2, wrn_28_8, wrn_var_37_2
66
from .vit import vit_base_patch16_224, vit_small_patch16_224, vit_small_patch2_32, vit_tiny_patch2_32, vit_base_patch16_96
7-
from .bert import bert_base_cased, bert_base_uncased
7+
from .bert import bert_base_cased, bert_base_uncased, bert_base_cased_multihead, bert_base_uncased_multihead
88
from .wave2vecv2 import wave2vecv2_base
99
from .hubert import hubert_base

semilearn/nets/bert/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,5 @@
22
# Licensed under the MIT License.
33

44
from .bert import bert_base_cased, bert_base_uncased
5-
# from .bert import ClassificationBert
5+
# from .bert import ClassificationBert
6+
from .bert_multihead import bert_base_cased_multihead, bert_base_uncased_multihead
Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,85 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import torch
5+
import torch.nn as nn
6+
from transformers import BertModel
7+
import os
8+
9+
class ClassificationBertMultihead(nn.Module):
    """BERT encoder topped with multiple independent classification heads.

    Variant of the single-head ``ClassificationBert`` used by MultiMatch-style
    algorithms: the pooled BERT features are fed through ``num_heads`` parallel
    two-layer MLP heads, and ``forward`` returns one logits tensor per head.

    Args:
        name: Hugging Face model identifier passed to
            ``BertModel.from_pretrained`` (e.g. ``'bert-base-uncased'``).
        num_classes: output dimension of every classification head.
        num_heads: number of parallel classification heads.
        adjust_clf_size: if True, shrink each head's hidden width to
            ``768 // num_heads`` so the combined head capacity stays roughly
            comparable to a single full-width head.
    """

    def __init__(self, name, num_classes=2, num_heads=3, adjust_clf_size=False):
        super(ClassificationBertMultihead, self).__init__()

        self.num_heads = num_heads

        # Load pre-trained BERT backbone (downloads weights on first use).
        self.bert = BertModel.from_pretrained(name)
        self.dropout = nn.Dropout(p=0.1, inplace=False)
        self.num_features = 768  # hidden size of bert-base
        # Optionally narrow each head's hidden layer when many heads are used.
        self.num_features_h = self.num_features // self.num_heads if adjust_clf_size else self.num_features

        def _classifier_fn():
            # One small MLP head: 768 -> num_features_h -> num_classes.
            return nn.Sequential(
                nn.Linear(self.num_features, self.num_features_h),
                nn.GELU(),
                nn.Linear(self.num_features_h, num_classes),
            )

        self.classifier = self.multihead_constructor(_classifier_fn)

    def multihead_constructor(self, constructor):
        """Call ``constructor`` once per head; register results as submodules."""
        return nn.ModuleList([constructor() for _ in range(self.num_heads)])

    def forward(self, x, only_fc=False, only_feat=False, return_embed=False, **kwargs):
        """
        Args:
            x: input; depends on the flags. Normally a dict of tokenizer
                outputs forwarded to the BERT backbone; with ``only_fc`` it
                must be pooled features of shape (batch, 768).
            only_fc: only apply the classification heads to ``x``.
            only_feat: only return pooled features (skip the heads).
            return_embed: also return word embeddings, used for VAT.

        Returns:
            A list of per-head logits (``only_fc``), the pooled feature
            tensor (``only_feat``), or a dict with ``'logits'`` (list of
            per-head logits), ``'feat'`` and optionally ``'embed'``.
        """
        if only_fc:
            # BUGFIX: self.classifier is an nn.ModuleList, which is not
            # callable; apply each head separately, consistent with the
            # full forward path below.
            return [head(x) for head in self.classifier]

        out_dict = self.bert(**x, output_hidden_states=True, return_dict=True)
        last_hidden = out_dict['last_hidden_state']
        drop_hidden = self.dropout(last_hidden)
        # Mean-pool over the sequence (token) dimension.
        pooled_output = torch.mean(drop_hidden, 1)

        if only_feat:
            return pooled_output

        logits = [head(pooled_output) for head in self.classifier]

        result_dict = {'logits': logits, 'feat': pooled_output}

        if return_embed:
            # hidden_states[0] is the embedding-layer output (pre-encoder).
            result_dict['embed'] = out_dict['hidden_states'][0]

        return result_dict

    def extract(self, x):
        """Return mean-pooled (dropout-regularized) last-hidden-state features."""
        out_dict = self.bert(**x, output_hidden_states=True, return_dict=True)
        last_hidden = out_dict['last_hidden_state']
        drop_hidden = self.dropout(last_hidden)
        pooled_output = torch.mean(drop_hidden, 1)
        return pooled_output

    def group_matcher(self, coarse=False, prefix=''):
        """Regex groups for parameter grouping / layer-wise lr decay (timm-style)."""
        matcher = dict(stem=r'^{}bert.embeddings'.format(prefix), blocks=r'^{}bert.encoder.layer.(\d+)'.format(prefix))
        return matcher

    def no_weight_decay(self):
        # No parameters are exempted from weight decay.
        return []
75+
76+
77+
78+
def bert_base_cased_multihead(args, **kwargs):
    """Build a multi-head classifier on the pretrained ``bert-base-cased`` backbone.

    Reads ``num_classes``, ``num_heads`` and ``adjust_clf_size`` from ``args``;
    extra keyword arguments are forwarded to the model constructor.
    """
    return ClassificationBertMultihead(
        'bert-base-cased',
        args.num_classes,
        args.num_heads,
        args.adjust_clf_size,
        **kwargs,
    )
81+
82+
83+
def bert_base_uncased_multihead(args, **kwargs):
    """Build a multi-head classifier on the pretrained ``bert-base-uncased`` backbone.

    Reads ``num_classes``, ``num_heads`` and ``adjust_clf_size`` from ``args``;
    extra keyword arguments are forwarded to the model constructor.
    """
    return ClassificationBertMultihead(
        'bert-base-uncased',
        args.num_classes,
        args.num_heads,
        args.adjust_clf_size,
        **kwargs,
    )

0 commit comments

Comments
 (0)