Commit 79dbe25

1. Fix issues with pre-training the BERT model
2. Add scripts for DeBERTa v3 fine-tuning
1 parent: 6e4d932

23 files changed: +286 −50 lines

DeBERTa/apps/models/masked_language_model.py

Lines changed: 16 additions & 13 deletions

@@ -31,6 +31,7 @@ class EnhancedMaskDecoder(torch.nn.Module):
   def __init__(self, config, vocab_size):
     super().__init__()
     self.config = config
+    self.position_biased_input = getattr(config, 'position_biased_input', True)
     self.lm_head = BertLMPredictionHead(config, vocab_size)
 
   def forward(self, ctx_layers, ebd_weight, target_ids, input_ids, input_mask, z_states, attention_mask, encoder, relative_pos=None):
@@ -56,19 +57,21 @@ def emd_context_layer(self, encoder_layers, z_states, attention_mask, encoder, t
     attention_mask = attention_mask.unsqueeze(1)
     target_mask = target_ids>0
     hidden_states = encoder_layers[-2]
-    layers = [encoder.layer[-1] for _ in range(2)]
-
-    z_states += hidden_states
-    query_mask = attention_mask
-    query_states = z_states
-    outputs = []
-    rel_embeddings = encoder.get_rel_embedding()
-
-    for layer in layers:
-      # TODO: pass relative pos ids
-      output = layer(hidden_states, query_mask, return_att=False, query_states = query_states, relative_pos=relative_pos, rel_embeddings = rel_embeddings)
-      query_states = output
-      outputs.append(query_states)
+    if not self.position_biased_input:
+      layers = [encoder.layer[-1] for _ in range(2)]
+      z_states += hidden_states
+      query_states = z_states
+      query_mask = attention_mask
+      outputs = []
+      rel_embeddings = encoder.get_rel_embedding()
+
+      for layer in layers:
+        # TODO: pass relative pos ids
+        output = layer(hidden_states, query_mask, return_att=False, query_states = query_states, relative_pos=relative_pos, rel_embeddings = rel_embeddings)
+        query_states = output
+        outputs.append(query_states)
+    else:
+      outputs = [encoder_layers[-1]]
 
     _mask_index = (target_ids>0).view(-1).nonzero().view(-1)
     def flatten_states(q_states):
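In short, the decoder now distinguishes two cases: when the backbone keeps absolute positions out of its input (`position_biased_input` is False), the enhanced mask decoder re-runs the last encoder layer twice with position-enhanced query states; otherwise it simply reuses the last encoder layer's output. A minimal sketch of that control flow, with `last_layer` standing in for any attention layer that accepts `query_states` (names and shapes are illustrative, not the repository's API):

```python
def emd_outputs(encoder_layers, z_states, last_layer, attention_mask,
                position_biased_input):
    """Sketch of the decoder's branching; not the repository's exact code."""
    if not position_biased_input:
        # Positions were not mixed into the input, so inject them now: the
        # second-to-last hidden states act as keys/values while the
        # position-enhanced states drive the query, applied twice.
        hidden_states = encoder_layers[-2]
        query_states = z_states + hidden_states
        outputs = []
        for _ in range(2):
            query_states = last_layer(hidden_states, attention_mask,
                                      query_states=query_states)
            outputs.append(query_states)
    else:
        # Positions are already in the input embeddings; the plain output of
        # the last encoder layer is enough.
        outputs = [encoder_layers[-1]]
    return outputs
```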

DeBERTa/apps/models/ner.py

Lines changed: 15 additions & 2 deletions

@@ -22,7 +22,8 @@
 class NERModel(NNModule):
   def __init__(self, config, num_labels = 2, drop_out=None, **kwargs):
     super().__init__(config)
-    self.bert = DeBERTa(config)
+    self._register_load_state_dict_pre_hook(self._pre_load_hook)
+    self.deberta = DeBERTa(config)
     self.num_labels = num_labels
     self.proj = nn.Linear(config.hidden_size, config.hidden_size)
     self.classifier = nn.Linear(config.hidden_size, self.num_labels)
@@ -31,7 +32,7 @@ def __init__(self, config, num_labels = 2, drop_out=None, **kwargs):
     self.apply(self.init_weights)
 
   def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, position_ids=None, **kwargs):
-    outputs = self.bert(input_ids, token_type_ids=type_ids, attention_mask=input_mask, \
+    outputs = self.deberta(input_ids, token_type_ids=type_ids, attention_mask=input_mask, \
       position_ids=position_ids, output_all_encoded_layers=True)
     encoder_layers = outputs['hidden_states']
     cls = encoder_layers[-1]
@@ -52,3 +53,15 @@ def forward(self, input_ids, type_ids=None, input_mask=None, labels=None, positi
       'logits' : logits,
       'loss' : loss
     }
+
+  def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
+      missing_keys, unexpected_keys, error_msgs):
+    new_state = dict()
+    bert_prefix = prefix + 'bert.'
+    deberta_prefix = prefix + 'deberta.'
+    for k in list(state_dict.keys()):
+      if k.startswith(bert_prefix):
+        nk = deberta_prefix + k[len(bert_prefix):]
+        value = state_dict[k]
+        del state_dict[k]
+        state_dict[nk] = value
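Renaming the encoder attribute from `self.bert` to `self.deberta` would break checkpoints saved under the old name, so the pre-load hook rewrites legacy keys before `load_state_dict` checks them. The same pattern in a self-contained toy (the `Toy` module and its tensors are invented for illustration):

```python
import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        # Hook runs on every load_state_dict call, before strict key checking.
        self._register_load_state_dict_pre_hook(self._pre_load_hook)
        self.deberta = nn.Linear(4, 4)   # new attribute name

    def _pre_load_hook(self, state_dict, prefix, local_metadata, strict,
                       missing_keys, unexpected_keys, error_msgs):
        bert_prefix, deberta_prefix = prefix + 'bert.', prefix + 'deberta.'
        for k in list(state_dict.keys()):
            if k.startswith(bert_prefix):
                # rename bert.* -> deberta.* in place
                state_dict[deberta_prefix + k[len(bert_prefix):]] = state_dict.pop(k)

legacy = {'bert.weight': torch.zeros(4, 4), 'bert.bias': torch.zeros(4)}
Toy().load_state_dict(legacy)   # loads cleanly: keys are remapped by the hook
```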

DeBERTa/apps/run.py

Lines changed: 7 additions & 3 deletions

@@ -46,7 +46,7 @@ def create_model(args, num_labels, model_class_fn):
   logger.info(f'Total parameters: {sum([p.numel() for p in model.parameters()])}')
   return model
 
-def train_model(args, model, device, train_data, eval_data):
+def train_model(args, model, device, train_data, eval_data, run_eval_fn):
   total_examples = len(train_data)
   num_train_steps = int(len(train_data)*args.num_train_epochs / args.train_batch_size)
   logger.info(" Training batch size = %d", args.train_batch_size)
@@ -56,7 +56,7 @@ def data_fn(trainer):
     return train_data, num_train_steps, None
 
   def eval_fn(trainer, model, device, tag):
-    results = run_eval(trainer.args, model, device, eval_data, tag, steps=trainer.trainer_state.steps)
+    results = run_eval_fn(trainer.args, model, device, eval_data, tag, steps=trainer.trainer_state.steps)
     eval_metric = np.mean([v[0] for k,v in results.items() if 'train' not in k])
     return eval_metric
 
@@ -285,11 +285,15 @@ def main(args):
   if not isinstance(device, torch.device):
     return 0
   model.to(device)
+  run_eval_fn = task.run_eval_fn()
+  if run_eval_fn is None:
+    run_eval_fn = run_eval
+
   if args.do_eval:
     run_eval(args, model, device, eval_data, prefix=args.tag)
 
   if args.do_train:
-    train_model(args, model, device, train_data, eval_data)
+    train_model(args, model, device, train_data, eval_data, run_eval_fn)
 
   if args.do_predict:
     run_predict(args, model, device, test_data, prefix=args.tag)

DeBERTa/apps/tasks/mlm_task.py

Lines changed: 15 additions & 1 deletion

@@ -48,7 +48,7 @@ def __init__(self, tokenizer, mask_lm_prob=0.15, max_seq_len=512, max_preds_per_
     if max_preds_per_seq is None:
       self.max_preds_per_seq = math.ceil(max_seq_len*mask_lm_prob /10)*10
 
-    self.max_gram = max_gram
+    self.max_gram = max(max_gram, 1)
     self.mask_window = int(1/mask_lm_prob) # make ngrams per window sized context
     self.vocab_words = list(tokenizer.vocab.keys())
 
@@ -168,6 +168,20 @@ def metrics_fn(logits, labels):
       preds = np.argmax(logits, axis=-1)
       acc = (preds==labels).sum()/len(labels)
       metrics = OrderedDict(accuracy= acc)
+
+      logits = torch.tensor(logits).cuda()
+      labels = torch.tensor(labels).cuda().long()
+      chk = 1024
+      off = 0
+      loss_fn = torch.nn.CrossEntropyLoss(reduction="none")
+      losses = []
+      while off<labels.size(0):
+        loss = loss_fn(logits[off:off+chk, :], labels[off:off+chk])
+        losses.append(loss)
+        off += chk
+      loss = torch.cat(losses).mean()
+      ppl = loss.exp().cpu().item()
+      metrics['PPL'] = ppl
       return metrics
     return metrics_fn
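The new metric is masked-LM perplexity: the exponential of the mean token-level cross-entropy over all predicted positions, accumulated in fixed-size chunks so a large evaluation set never materializes one huge loss tensor (the other change in this file simply clamps `max_gram` to at least 1). A CPU-only restatement of the same computation; the function name and toy shapes are mine:

```python
import torch

def masked_lm_ppl(logits, labels, chunk=1024):
    """Perplexity = exp(mean cross-entropy), computed chunk by chunk."""
    logits = torch.as_tensor(logits)
    labels = torch.as_tensor(labels).long()
    loss_fn = torch.nn.CrossEntropyLoss(reduction='none')
    losses, off = [], 0
    while off < labels.size(0):
        losses.append(loss_fn(logits[off:off + chunk], labels[off:off + chunk]))
        off += chunk
    return torch.cat(losses).mean().exp().item()

# e.g. 3000 masked positions over a toy 128-word vocabulary
print(masked_lm_ppl(torch.randn(3000, 128), torch.randint(0, 128, (3000,))))
```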

DeBERTa/apps/tasks/superglue_tasks.py

Lines changed: 1 addition & 1 deletion

@@ -733,7 +733,7 @@ def example_to_feature(self, tokenizer, example, max_seq_len=512, rng=None, mask
     # Max Enities spans 87 # 90
     max_entities = 110
     #max_entity_span = 110
-    max_entity_span = 90
+    max_entity_span = 180
     entities = example.entity_spans
     assert len(entities)<=max_entities, f'Entities number {len(entities)} exceeds the maxium allowed entities {max_entities}'
     entity_indice = []

DeBERTa/apps/tasks/task.py

Lines changed: 6 additions & 0 deletions

@@ -64,6 +64,12 @@ def label2id(self, labelstr):
     label_dict = {l:i for i,l in enumerate(self.get_labels())}
     return label_dict[labelstr] if labelstr in label_dict else -1
 
+  def run_eval_fn(self):
+    return None
+
+  def run_pred_fn(self):
+    return None
+
   def get_metrics_fn(self):
     """Calcuate metrics based on prediction results"""
     def metrics_fn(logits, labels):
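These two hooks let a task supply its own evaluation and prediction routines; run.py (above) falls back to the generic `run_eval` when `run_eval_fn()` returns None. A hedged sketch of what an override could return, inferred only from how run.py consumes the result (it averages `v[0]` over the non-'train' entries of the returned dict); all names below are illustrative:

```python
def custom_eval(args, model, device, eval_data, tag=None, steps=None):
    """Hypothetical task-specific evaluation; compute a real metric here."""
    metric = 0.0
    return {'my_eval_set': (metric,)}   # each value starts with a scalar metric

# A task subclass would expose it via the new hook:
#   def run_eval_fn(self):
#       return custom_eval
```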

DeBERTa/deberta/bert.py

Lines changed: 3 additions & 0 deletions

@@ -257,6 +257,9 @@ def forward(self, input_ids, token_type_ids=None, position_ids=None, mask = None
       token_type_embeddings = self.token_type_embeddings(token_type_ids)
       embeddings += token_type_embeddings
 
+    if self.position_biased_input:
+      embeddings += position_embeddings
+
     if self.embedding_size != self.config.hidden_size:
       embeddings = self.embed_proj(embeddings)
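With this change the embedding layer folds absolute position embeddings into the input sum only when `position_biased_input` is set; configurations built around relative positions typically leave the input position-free and let the attention layers (and the enhanced mask decoder above) handle position information. A toy restatement of the sum, with random stand-in tensors:

```python
import torch

batch, seq, hidden = 2, 4, 8
position_biased_input = False        # relative-position configs usually keep this off

word_embeddings = torch.randn(batch, seq, hidden)
token_type_embeddings = torch.randn(batch, seq, hidden)
position_embeddings = torch.randn(batch, seq, hidden)

embeddings = word_embeddings + token_type_embeddings
if position_biased_input:            # absolute positions enter the input only when configured
    embeddings = embeddings + position_embeddings
```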

DeBERTa/deberta/cache_utils.py

Lines changed: 8 additions & 4 deletions

@@ -37,10 +37,14 @@ def __init__(self, name, vocab, vocab_type, model='pytorch_model.bin', config='c
     'base-mnli': PretrainedModel('deberta-base-mnli', 'bpe_encoder.bin', 'gpt2'),
     'large-mnli': PretrainedModel('deberta-large-mnli', 'bpe_encoder.bin', 'gpt2'),
     'xlarge-mnli': PretrainedModel('deberta-xlarge-mnli', 'bpe_encoder.bin', 'gpt2'),
-    'xlarge-v2': PretrainedModel('deberta-xlarge-v2', 'spm.model', 'spm'),
-    'xxlarge-v2': PretrainedModel('deberta-xxlarge-v2', 'spm.model', 'spm'),
-    'xlarge-v2-mnli': PretrainedModel('deberta-xlarge-v2-mnli', 'spm.model', 'spm'),
-    'xxlarge-v2-mnli': PretrainedModel('deberta-xxlarge-v2-mnli', 'spm.model', 'spm')
+    'xlarge-v2': PretrainedModel('deberta-v2-xlarge', 'spm.model', 'spm'),
+    'xxlarge-v2': PretrainedModel('deberta-v2-xxlarge', 'spm.model', 'spm'),
+    'xlarge-v2-mnli': PretrainedModel('deberta-v2-xlarge-mnli', 'spm.model', 'spm'),
+    'xxlarge-v2-mnli': PretrainedModel('deberta-v2-xxlarge-mnli', 'spm.model', 'spm'),
+    'deberta-v3-small': PretrainedModel('deberta-v3-small', 'spm.model', 'spm'),
+    'deberta-v3-base': PretrainedModel('deberta-v3-base', 'spm.model', 'spm'),
+    'deberta-v3-large': PretrainedModel('deberta-v3-large', 'spm.model', 'spm'),
+    'mdeberta-v3-base': PretrainedModel('mdeberta-v3-base', 'spm.model', 'spm'),
   }
 
 def download_asset(url, name, tag=None, no_cache=False, cache_dir=None):
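Besides renaming the v2 assets to their published names (e.g. `deberta-xlarge-v2` → `deberta-v2-xlarge`), this registers cache entries for the DeBERTa-V3 and mDeBERTa-V3 checkpoints, all using the SentencePiece (`spm`) vocabulary. Assuming the `pre_trained` shortcut shown in the README for earlier models also accepts these new keys (an assumption on my part), usage would look roughly like:

```python
from DeBERTa import deberta

# Hedged sketch: 'deberta-v3-base' is one of the cache keys added above.
model = deberta.DeBERTa(pre_trained='deberta-v3-base')
model.apply_state()   # load the downloaded pre-trained weights, per the README pattern
```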

DeBERTa/deberta/disentangled_attention.py

Lines changed: 8 additions & 8 deletions

@@ -69,8 +69,8 @@ def transpose_for_scores(self, x, attention_heads):
   def forward(self, hidden_states, attention_mask, return_att=False, query_states=None, relative_pos=None, rel_embeddings=None):
     if query_states is None:
       query_states = hidden_states
-    query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads)
-    key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads)
+    query_layer = self.transpose_for_scores(self.query_proj(query_states), self.num_attention_heads).float()
+    key_layer = self.transpose_for_scores(self.key_proj(hidden_states), self.num_attention_heads).float()
     value_layer = self.transpose_for_scores(self.value_proj(hidden_states), self.num_attention_heads)
 
     rel_att = None
@@ -83,14 +83,14 @@ def forward(self, hidden_states, attention_mask, return_att=False, query_states=
     if 'p2p' in self.pos_att_type:
       scale_factor += 1
     scale = math.sqrt(query_layer.size(-1)*scale_factor)
-    attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2))/scale
+    attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)/scale)
     if self.relative_attention:
       rel_embeddings = self.pos_dropout(rel_embeddings)
       rel_att = self.disentangled_attention_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
 
     if rel_att is not None:
       attention_scores = (attention_scores + rel_att)
-    attention_scores = attention_scores
+    attention_scores = (attention_scores - attention_scores.max(dim=-1, keepdim=True).values.detach()).to(hidden_states)
     attention_scores = attention_scores.view(-1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1))
 
     # bxhxlxd
@@ -140,10 +140,10 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
     # content->position
     if 'c2p' in self.pos_att_type:
       scale = math.sqrt(pos_key_layer.size(-1)*scale_factor)
-      c2p_att = torch.bmm(query_layer, pos_key_layer.transpose(-1, -2))
+      c2p_att = torch.bmm(query_layer/scale, pos_key_layer.transpose(-1, -2).to(query_layer))
       c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span*2-1)
       c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_pos.squeeze(0).expand([query_layer.size(0), query_layer.size(1), relative_pos.size(-1)]))
-      score += c2p_att/scale
+      score += c2p_att
 
     # position->content
     if 'p2c' in self.pos_att_type or 'p2p' in self.pos_att_type:
@@ -159,11 +159,11 @@ def disentangled_attention_bias(self, query_layer, key_layer, relative_pos, rel_
       pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
 
     if 'p2c' in self.pos_att_type:
-      p2c_att = torch.bmm(key_layer, pos_query_layer.transpose(-1, -2))
+      p2c_att = torch.bmm(key_layer/scale, pos_query_layer.transpose(-1, -2).to(key_layer))
       p2c_att = torch.gather(p2c_att, dim=-1, index=p2c_pos.squeeze(0).expand([query_layer.size(0), key_layer.size(-2), key_layer.size(-2)])).transpose(-1,-2)
       if query_layer.size(-2) != key_layer.size(-2):
        p2c_att = torch.gather(p2c_att, dim=-2, index=pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2))))
-      score += p2c_att/scale
+      score += p2c_att
 
     # position->position
     if 'p2p' in self.pos_att_type:
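All of these edits target fp16 numerical stability rather than the attention math itself: the query/key projections are promoted to float32, the `1/scale` factor is folded into one `bmm` operand instead of dividing the much larger score matrix, and the row-wise maximum is subtracted (a softmax no-op) before casting scores back to the compute dtype so they cannot overflow. A standalone illustration of the same tricks, using toy shapes and a plain `sqrt(d)` scale rather than the repository's `scale_factor`:

```python
import math
import torch

q = torch.randn(2, 8, 16, dtype=torch.float16)   # (batch*heads, seq_len, head_dim)
k = torch.randn(2, 8, 16, dtype=torch.float16)
scale = math.sqrt(q.size(-1))

# 1) do the matmul in float32 and scale one operand, not the score matrix
scores = torch.bmm(q.float(), k.float().transpose(-1, -2) / scale)
# 2) subtracting the per-row max leaves softmax unchanged but keeps the
#    values small enough to cast back to fp16 without overflow
scores = (scores - scores.max(dim=-1, keepdim=True).values.detach()).to(q.dtype)
probs = torch.softmax(scores, dim=-1)
```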

README.md

Lines changed: 42 additions & 13 deletions

@@ -3,6 +3,10 @@
 This repository is the official implementation of [ **DeBERTa**: **D**ecoding-**e**nhanced **BERT** with Disentangled **A**ttention ](https://arxiv.org/abs/2006.03654)
 
 ## News
+### 11/16/2021
+- The models of our new work [DeBERTa V3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing](https://arxiv.org/abs/2111.09543) are publicly available at [huggingface model hub](https://huggingface.co/models?other=deberta-v3) now. The new models are based on DeBERTa-V2 models by replacing MLM with ELECTRA-style objective plus gradient-disentangled embedding sharing which further improves the model efficiency.
+- Scripts for DeBERTa V3 model fine-tuning are added
+
 ### 3/31/2021
 - Masked language model task is added
 - SuperGLUE tasks is added
@@ -24,11 +28,6 @@ With DeBERTa 1.5B model, we surpass T5 11B model and human performance on SuperG
 ### 06/13/2020
 We released the pre-trained models, source code, and fine-tuning scripts to reproduce some of the experimental results in the paper. You can follow similar scripts to apply DeBERTa to your own experiments or applications. Pre-training scripts will be released in the next step.
 
-## TODOs
-- [x] Add SuperGLUE tasks
-- [x] Add SiFT code
-- [x] Add Pretraining code
-
 
 ## Introduction to DeBERTa
 DeBERTa (Decoding-enhanced BERT with disentangled attention) improves the BERT and RoBERTa models using two novel techniques. The first is the disentangled attention mechanism, where each word is represented using two vectors that encode its content and position, respectively, and the attention weights among words are computed using disentangled matrices on their contents and relative positions. Second, an enhanced mask decoder is used to replace the output softmax layer to predict the masked tokens for model pretraining. We show that these two techniques significantly improve the efficiency of model pre-training and performance of downstream tasks.
@@ -49,9 +48,15 @@ Our pre-trained models are packaged into zipped files. You can download them fro
 |[XLarge-MNLI](https://huggingface.co/microsoft/deberta-xlarge-mnli)|750M|1024|48|Fine-turned with MNLI|
 |[Large-MNLI](https://huggingface.co/microsoft/deberta-large-mnli)|400M|1024|24|Fine-turned with MNLI|
 |[Base-MNLI](https://huggingface.co/microsoft/deberta-base-mnli)|140M|768|12|Fine-turned with MNLI|
+|[DeBERTa-V3-Large](https://huggingface.co/microsoft/deberta-v3-large)<sup>2</sup>|418M|1024| 24| 128K new SPM vocab|
+|[DeBERTa-V3-Base](https://huggingface.co/microsoft/deberta-v3-base)<sup>2</sup>|183M|768| 12| 128K new SPM vocab|
+|[DeBERTa-V3-Small](https://huggingface.co/microsoft/deberta-v3-small)<sup>2</sup>|143M|768| 6| 128K new SPM vocab|
+|[mDeBERTa-V3-Base](https://huggingface.co/microsoft/mdeberta)<sup>2</sup>|280M|768| 12| 250K new SPM vocab, multi-lingual model with 102 languages|
 
 ## Note
 - 1 This is the model(89.9) that surpassed **T5 11B(89.3) and human performance(89.8)** on **SuperGLUE** for the first time. 128K new SPM vocab.
+- 2 These V3 DeBERTa models are deberta models pre-trained with ELECTRA-style objective plus gradient-disentangled embedding sharing which significantly improves the model efficiency.
+
 
 # Try the model
 
@@ -209,7 +214,20 @@ We present the dev results on SQuAD 1.1/2.0 and several GLUE benchmark tasks.
 | [DeBERTa-Large](https://huggingface.co/microsoft/deberta-large)<sup>1</sup> | 95.5/90.1 | 90.7/88.0 | 91.3/91.1| 96.5|95.3| 69.5| 91.0| 92.6/94.6| 92.3/- |92.8/92.5 |
 | [DeBERTa-XLarge](https://huggingface.co/microsoft/deberta-xlarge)<sup>1</sup> | -/- | -/- | 91.5/91.2| 97.0 | - | - | 93.1 | 92.1/94.3 | - |92.9/92.7|
 | [DeBERTa-V2-XLarge](https://huggingface.co/microsoft/deberta-v2-xlarge)<sup>1</sup>|95.8/90.8| 91.4/88.9|91.7/91.6| **97.5**| 95.8|71.1|**93.9**|92.0/94.2|92.3/89.8|92.9/92.9|
-|**[DeBERTa-V2-XXLarge](https://huggingface.co/microsoft/deberta-v2-xxlarge)<sup>1,2</sup>**|**96.1/91.4**|**92.2/89.7**|**91.7/91.9**|97.2|**96.0**|**72.0**| 93.5| **93.1/94.9**|**92.7/90.3** |**93.2/93.1** |
+|**[DeBERTa-V2-XXLarge](https://huggingface.co/microsoft/deberta-v2-xxlarge)<sup>1,2</sup>**|**96.1/91.4**|**92.2/89.7**|**91.7/91.9**|97.2|**96.0**|72.0| 93.5| **93.1/94.9**|**92.7/90.3** |**93.2/93.1** |
+|**[DeBERTa-V3-Large](https://huggingface.co/microsoft/deberta-v3-large)**|-/-|91.5/89.0|**91.8/91.9**|96.9|**96.0**|**75.3**| 92.7| 92.2/-|**93.0/-** |93.0/- |
+|[DeBERTa-V3-Base](https://huggingface.co/microsoft/deberta-v3-base)|-/-|88.4/85.4|90.6/90.7|-|-|-| -| -|- |- |
+|[DeBERTa-V3-Small](https://huggingface.co/microsoft/deberta-v3-base)|-/-|82.9/80.4|88.2/87.9|-|-|-| -| -|- |- |
+
+#### Fine-tuning on XNLI
+
+We present the dev results on XNLI with zero-shot crosslingual transfer setting, i.e. training with english data only, test on other languages.
+
+| Model |avg | en | fr| es | de | el | bg | ru |tr |ar |vi | th | zh | hi | sw | ur |
+|--------------| ----|----|----|---- |-- |-- |-- | -- |-- |-- |-- | -- | -- | -- | -- | -- |
+| XLM-R-base |76.2 |85.8|79.7|80.7 |78.7 |77.5 |79.6 |78.1 |74.2 |73.8 |76.5 |74.6 |76.7| 72.4| 66.5| 68.3|
+| [mDeBERTa-V3-Base](https://huggingface.co/microsoft/mdeberta-v3-base)|**79.8**+/-0.2|**88.2**|**82.6**|**84.4** |**82.7** |**82.3** |**82.4** |**80.8** |**79.5** |**78.5** |**78.1** |**76.4** |**79.5**| **75.9**| **73.9**| **72.4**|
+
 --------
 #### Notes.
 - <sup>1</sup> Following RoBERTa, for RTE, MRPC, STS-B, we fine-tune the tasks based on [DeBERTa-Large-MNLI](https://huggingface.co/microsoft/deberta-large-mnli), [DeBERTa-XLarge-MNLI](https://huggingface.co/microsoft/deberta-xlarge-mnli), [DeBERTa-V2-XLarge-MNLI](https://huggingface.co/microsoft/deberta-v2-xlarge-mnli), [DeBERTa-V2-XXLarge-MNLI](https://huggingface.co/microsoft/deberta-v2-xxlarge-mnli). The results of SST-2/QQP/QNLI/SQuADv2 will also be slightly improved when start from MNLI fine-tuned models, however, we only report the numbers fine-tuned from pretrained base models for those 4 tasks.
@@ -220,14 +238,25 @@ We present the dev results on SQuAD 1.1/2.0 and several GLUE benchmark tasks.
 Pengcheng He([email protected]), Xiaodong Liu([email protected]), Jianfeng Gao([email protected]), Weizhu Chen([email protected])
 
 # Citation
+``` latex
+@misc{he2021debertav3,
+title={DeBERTaV3: Improving DeBERTa using ELECTRA-Style Pre-Training with Gradient-Disentangled Embedding Sharing},
+author={Pengcheng He and Jianfeng Gao and Weizhu Chen},
+year={2021},
+eprint={2111.09543},
+archivePrefix={arXiv},
+primaryClass={cs.CL}
+}
 ```
-@misc{he2020deberta,
-title={DeBERTa: Decoding-enhanced BERT with Disentangled Attention},
-author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen},
-year={2020},
-eprint={2006.03654},
-archivePrefix={arXiv},
-primaryClass={cs.CL}
+
+``` latex
+@inproceedings{
+he2021deberta,
+title={DEBERTA: DECODING-ENHANCED BERT WITH DISENTANGLED ATTENTION},
+author={Pengcheng He and Xiaodong Liu and Jianfeng Gao and Weizhu Chen},
+booktitle={International Conference on Learning Representations},
+year={2021},
+url={https://openreview.net/forum?id=XPZIaotutsD}
 }
 ```