This repository was archived by the owner on Jan 15, 2024. It is now read-only.

Commit 1e51262

add adapter

Authored and committed by Ubuntu
1 parent b1a2bed · commit 1e51262

File tree: 1 file changed (+27, -47 lines)


src/gluonnlp/models/bert.py

Lines changed: 27 additions & 47 deletions
@@ -16,7 +16,6 @@
 # under the License.
 """
 Bert Model
-
 @article{devlin2018bert,
   title={BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding},
   author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
@@ -30,6 +29,7 @@
 
 import os
 from typing import Tuple
+import json
 
 import mxnet as mx
 from mxnet import use_np, np, npx
@@ -69,6 +69,8 @@ def google_en_uncased_bert_base():
     cfg.MODEL.dtype = 'float32'
     cfg.MODEL.layout = 'NT'
     cfg.MODEL.compute_layout = 'auto'
+    cfg.MODEL.use_adapter = False
+    cfg.MODEL.adapter_config = None
     # Hyper-parameters of the Initializers
     cfg.INITIALIZER = CN()
     cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02]
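The two new keys default adapters off; the rest of this commit threads them from the registered configs through BertModel.from_cfg, BertModel, and BertTransformer down to each TransformerEncoderLayer. A minimal usage sketch, assuming adapter_config is a JSON string (BertModel parses it with json.loads below); the key name 'bottleneck_units' is an illustration, not something this commit defines:

import json

from gluonnlp.models.bert import google_en_uncased_bert_base

cfg = google_en_uncased_bert_base()
cfg.defrost()  # registered cfgs may be returned frozen
cfg.MODEL.use_adapter = True
cfg.MODEL.adapter_config = json.dumps({'bottleneck_units': 64})  # assumed schema
cfg.freeze()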
@@ -81,6 +83,7 @@ def google_en_uncased_bert_base():
     return cfg
 
 
+
 @bert_cfg_reg.register()
 def google_en_uncased_bert_large():
     cfg = google_en_uncased_bert_base()
@@ -161,6 +164,7 @@ def google_multi_cased_bert_large():
         'mlm_params': 'google_en_cased_bert_large/model_mlm-59ff3f6a.params',
         'lowercase': False,
     },
+
     'google_en_uncased_bert_large': {
         'cfg': google_en_uncased_bert_large(),
         'vocab': 'google_en_uncased_bert_large/vocab-e6d2b21d.json',
@@ -224,7 +228,9 @@ def __init__(self, units: int = 512,
                  weight_initializer: InitializerType = TruncNorm(stdev=0.02),
                  bias_initializer: InitializerType = 'zeros',
                  activation='gelu',
-                 layout='NT'):
+                 layout='NT',
+                 use_adapter=False,
+                 adapter_config={}):
         super().__init__()
         assert units % num_heads == 0,\
             'In BertTransformer, the units should be divided exactly ' \
@@ -236,8 +242,11 @@ def __init__(self, units: int = 512,
         self._output_attention = output_attention
         self._output_all_encodings = output_all_encodings
         self._layout = layout
+        self._use_adapter = use_adapter
+        self._adapter_config = adapter_config
 
         self.all_layers = nn.HybridSequential()
+
         for layer_idx in range(num_layers):
             self.all_layers.add(
                 TransformerEncoderLayer(units=units,
@@ -250,7 +259,9 @@ def __init__(self, units: int = 512,
                                         bias_initializer=bias_initializer,
                                         activation=activation,
                                         layout=layout,
-                                        dtype=dtype))
+                                        dtype=dtype,
+                                        use_adapter=use_adapter,
+                                        adapter_config=adapter_config))
 
     @property
     def layout(self):
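The adapter layer itself is implemented inside TransformerEncoderLayer and does not appear in this diff. For orientation only, a minimal sketch of the residual bottleneck adapter design that such a flag typically enables (Houlsby et al., 2019); the class name, argument names, and ReLU activation are assumptions, not this repository's code:

from mxnet import use_np
from mxnet.gluon import nn

@use_np
class BottleneckAdapter(nn.HybridBlock):
    """Hypothetical residual bottleneck adapter (illustration only)."""

    def __init__(self, units: int, bottleneck_units: int):
        super().__init__()
        # Project down to a small bottleneck, apply a non-linearity,
        # then project back up to the model dimension.
        self.down_proj = nn.Dense(bottleneck_units, in_units=units,
                                  activation='relu', flatten=False)
        self.up_proj = nn.Dense(units, in_units=bottleneck_units,
                                flatten=False)

    def forward(self, data):
        # The residual connection keeps the block near identity at
        # initialization, so only the small adapter weights need tuning.
        return data + self.up_proj(self.down_proj(data))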
@@ -259,28 +270,23 @@ def layout(self):
     def forward(self, data, valid_length):
         """
         Generate the representation given the inputs.
-
         This is used in training or fine-tuning a bert model.
-
         Parameters
         ----------
         data
             - layout = 'NT'
                 Shape (batch_size, seq_length, C)
             - layout = 'TN'
                 Shape (seq_length, batch_size, C)
-
         valid_length
             Shape (batch_size,)
-
         Returns
         -------
         out
             - layout = 'NT'
                 Shape (batch_size, seq_length, C_out)
             - layout = 'TN'
                 Shape (seq_length, batch_size, C_out)
-
         """
         if self.layout == 'NT':
             time_axis, batch_axis = 1, 0
@@ -336,7 +342,9 @@ def __init__(self,
                  dtype='float32',
                  use_pooler=True,
                  layout='NT',
-                 compute_layout='auto'):
+                 compute_layout='auto',
+                 use_adapter=False,
+                 adapter_config={}):
         super().__init__()
         self._dtype = dtype
         self.use_pooler = use_pooler
@@ -351,6 +359,9 @@ def __init__(self,
         self.bias_initializer = bias_initializer
         self.layer_norm_eps = layer_norm_eps
         self._layout = layout
+        self._use_adapter = use_adapter
+        self._adapter_config = (json.loads(adapter_config)
+                                if self._use_adapter else None)
         if compute_layout is None or compute_layout == 'auto':
             self._compute_layout = layout
         else:
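Note the contract this creates: adapter_config reaches BertModel as a JSON string and is parsed exactly once here, so the encoder stack receives a plain dict (or None when adapters are disabled). A small illustration, with an assumed key:

import json

raw = '{"bottleneck_units": 64}'         # hypothetical schema
parsed = json.loads(raw)
assert parsed['bottleneck_units'] == 64  # downstream blocks see a dict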
@@ -370,7 +381,9 @@ def __init__(self,
             weight_initializer=weight_initializer,
             bias_initializer=bias_initializer,
             dtype=dtype,
-            layout=self._compute_layout
+            layout=self._compute_layout,
+            use_adapter=self._use_adapter,
+            adapter_config=self._adapter_config
         )
         # Construct word embedding
         self.word_embed = nn.Embedding(input_dim=vocab_size,
@@ -404,38 +417,31 @@ def layout(self):
     def forward(self, inputs, token_types, valid_length):
         # pylint: disable=arguments-differ
         """Generate the representation given the inputs.
-
         This is used in training or fine-tuning a bert model.
-
         Parameters
         ----------
         inputs
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         token_types
             If the inputs contain two sequences, we will set different token types for the first
             sentence and the second sentence.
-
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         valid_length :
             The valid length of each sequence
             Shape (batch_size,)
-
         Returns
         -------
         contextual_embedding
             - layout = 'NT'
                 Shape (batch_size, seq_length, units).
             - layout = 'TN'
                 Shape (seq_length, batch_size, units).
-
         pooled_output
             This is optional. Shape (batch_size, units)
         """
@@ -457,33 +463,27 @@ def forward(self, inputs, token_types, valid_length):
 
     def get_initial_embedding(self, inputs, token_types=None):
         """Get the initial token embeddings that consider the token type and positional embeddings
-
         Parameters
         ----------
         inputs
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         token_types
             The type of tokens. If None, it will be initialized as all zeros.
-
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         Returns
         -------
         embedding
             The initial embedding that will be fed into the encoder
-
             - layout = 'NT'
                 Shape (batch_size, seq_length, C_emb)
             - layout = 'TN'
                 Shape (seq_length, batch_size, C_emb)
-
         """
         if self.layout == 'NT':
             time_axis, batch_axis = 1, 0
@@ -505,18 +505,15 @@ def get_initial_embedding(self, inputs, token_types=None):
 
     def apply_pooling(self, sequence):
         """Generate the representation given the inputs.
-
         This is used for pre-training or fine-tuning a bert model.
         Get the first token of the whole sequence, which is [CLS].
-
         Parameters
         ----------
         sequence
             - layout = 'NT'
                 Shape (batch_size, sequence_length, units)
             - layout = 'TN'
                 Shape (sequence_length, batch_size, units)
-
         Returns
         -------
         outputs
@@ -538,7 +535,6 @@ def get_cfg(key=None):
     @classmethod
     def from_cfg(cls, cfg, use_pooler=True, dtype=None) -> 'BertModel':
         """
-
         Parameters
         ----------
         cfg
@@ -547,7 +543,6 @@ def from_cfg(cls, cfg, use_pooler=True, dtype=None) -> 'BertModel':
             Whether to output the pooled feature
         dtype
             data type of the model
-
         Returns
         -------
         ret
@@ -578,7 +573,9 @@ def from_cfg(cls, cfg, use_pooler=True, dtype=None) -> 'BertModel':
                    bias_initializer=bias_initializer,
                    use_pooler=use_pooler,
                    layout=cfg.MODEL.layout,
-                   compute_layout=cfg.MODEL.compute_layout)
+                   compute_layout=cfg.MODEL.compute_layout,
+                   use_adapter=cfg.MODEL.use_adapter,
+                   adapter_config=cfg.MODEL.adapter_config)
 
 
 @use_np
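Putting the pieces together, a sketch of building a backbone with adapters enabled end to end via from_cfg; the adapter_config schema is assumed, and the model is randomly initialized rather than loaded from pretrained weights:

import json

from gluonnlp.models.bert import BertModel, google_en_uncased_bert_base

cfg = google_en_uncased_bert_base()
cfg.defrost()
cfg.MODEL.use_adapter = True
cfg.MODEL.adapter_config = json.dumps({'bottleneck_units': 64})  # assumed schema
cfg.freeze()

model = BertModel.from_cfg(cfg)  # the new flags flow in from cfg.MODEL
model.initialize()               # random init; load params separately if needed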
@@ -587,7 +584,6 @@ def __init__(self, backbone_cfg,
                  weight_initializer=None,
                  bias_initializer=None):
         """
-
         Parameters
         ----------
         backbone_cfg
@@ -626,39 +622,33 @@ def layout(self):
     def forward(self, inputs, token_types, valid_length,
                 masked_positions):
         """Get the scores of the masked positions.
-
         Parameters
         ----------
         inputs
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         token_types
             If the inputs contain two sequences, we will set different token types for the first
             sentence and the second sentence.
-
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         valid_length :
             The valid length of each sequence
             Shape (batch_size,)
         masked_positions :
             The masked position of the sequence
             Shape (batch_size, num_masked_positions).
-
         Returns
         -------
         contextual_embedding
             - layout = 'NT'
                 Shape (batch_size, seq_length, units).
             - layout = 'TN'
                 Shape (seq_length, batch_size, units)
-
         pooled_out
             Shape (batch_size, units)
         mlm_scores :
@@ -680,7 +670,6 @@ def __init__(self, backbone_cfg,
                  weight_initializer=None,
                  bias_initializer=None):
         """
-
         Parameters
         ----------
         backbone_cfg
@@ -724,41 +713,34 @@ def layout(self):
     def forward(self, inputs, token_types, valid_length,
                 masked_positions):
         """Generate the representation given the inputs.
-
         This is used in training or fine-tuning a bert model.
-
         Parameters
         ----------
         inputs
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         token_types
             If the inputs contain two sequences, we will set different token types for the first
             sentence and the second sentence.
-
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         valid_length
             The valid length of each sequence
             Shape (batch_size,)
         masked_positions
             The masked position of the sequence
             Shape (batch_size, num_masked_positions).
-
         Returns
         -------
         contextual_embedding
             - layout = 'NT'
                 Shape (batch_size, seq_length, units).
             - layout = 'TN'
                 Shape (seq_length, batch_size, units).
-
         pooled_out
             Shape (batch_size, units)
         nsp_score :
@@ -787,7 +769,6 @@ def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base',
                         load_mlm: bool = False)\
         -> Tuple[CN, HuggingFaceWordPieceTokenizer, str, str]:
     """Get the pretrained bert weights
-
     Parameters
     ----------
     model_name
@@ -798,7 +779,6 @@ def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base',
         Whether to load the weights of the backbone network
     load_mlm
         Whether to load the weights of MLM
-
     Returns
     -------
     cfg
@@ -857,4 +837,4 @@ def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base',
 
 BACKBONE_REGISTRY.register('bert', [BertModel,
                                     get_pretrained_bert,
-                                    list_pretrained_bert])
+                                    list_pretrained_bert])
