@@ -16,7 +16,6 @@
 # under the License.
 """
 Bert Model
-
 @article{devlin2018bert,
   title={BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding},
   author={Devlin, Jacob and Chang, Ming-Wei and Lee, Kenton and Toutanova, Kristina},
@@ -30,6 +29,7 @@
 
 import os
 from typing import Tuple
+import json
 
 import mxnet as mx
 from mxnet import use_np, np, npx
@@ -69,6 +69,8 @@ def google_en_uncased_bert_base():
     cfg.MODEL.dtype = 'float32'
     cfg.MODEL.layout = 'NT'
     cfg.MODEL.compute_layout = 'auto'
+    cfg.MODEL.use_adapter = False
+    cfg.MODEL.adapter_config = None
     # Hyper-parameters of the Initializers
     cfg.INITIALIZER = CN()
     cfg.INITIALIZER.embed = ['truncnorm', 0, 0.02]
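
Note: with the two new config keys above, adapter tuning can be switched on entirely through the YACS config. A minimal sketch is shown below; the `adapter_units`/`adapter_activation` keys are illustrative assumptions, since this diff does not show the schema that `TransformerEncoderLayer` actually consumes:

    cfg = google_en_uncased_bert_base()
    cfg.defrost()
    cfg.MODEL.use_adapter = True
    # adapter_config is parsed later with json.loads(), so it is stored as a JSON string.
    cfg.MODEL.adapter_config = '{"adapter_units": 64, "adapter_activation": "gelu"}'
    cfg.freeze()
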
@@ -81,6 +83,7 @@ def google_en_uncased_bert_base():
     return cfg
 
 
+
 @bert_cfg_reg.register()
 def google_en_uncased_bert_large():
     cfg = google_en_uncased_bert_base()
@@ -161,6 +164,7 @@ def google_multi_cased_bert_large():
         'mlm_params': 'google_en_cased_bert_large/model_mlm-59ff3f6a.params',
         'lowercase': False,
     },
+
     'google_en_uncased_bert_large': {
         'cfg': google_en_uncased_bert_large(),
         'vocab': 'google_en_uncased_bert_large/vocab-e6d2b21d.json',
@@ -224,7 +228,9 @@ def __init__(self, units: int = 512,
                  weight_initializer: InitializerType = TruncNorm(stdev=0.02),
                  bias_initializer: InitializerType = 'zeros',
                  activation='gelu',
-                 layout='NT'):
+                 layout='NT',
+                 use_adapter=False,
+                 adapter_config={}):
         super().__init__()
         assert units % num_heads == 0,\
             'In BertTransformer, The units should be divided exactly ' \
@@ -236,8 +242,11 @@ def __init__(self, units: int = 512,
         self._output_attention = output_attention
         self._output_all_encodings = output_all_encodings
         self._layout = layout
+        self._use_adapter = use_adapter
+        self._adapter_config = adapter_config
 
         self.all_layers = nn.HybridSequential()
+
         for layer_idx in range(num_layers):
             self.all_layers.add(
                 TransformerEncoderLayer(units=units,
@@ -250,7 +259,9 @@ def __init__(self, units: int = 512,
                                         bias_initializer=bias_initializer,
                                         activation=activation,
                                         layout=layout,
-                                        dtype=dtype))
+                                        dtype=dtype,
+                                        use_adapter=use_adapter,
+                                        adapter_config=adapter_config))
 
     @property
     def layout(self):
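
For context, the `use_adapter`/`adapter_config` arguments threaded into `TransformerEncoderLayer` above are expected to insert small bottleneck modules in the style of Houlsby et al. (2019). The modified encoder layer is outside this diff, so the following is only a sketch of such a module under stated assumptions (the class name, the `adapter_units` default, and the use of `relu` instead of a configurable activation are all placeholders):

    from mxnet import use_np
    from mxnet.gluon import nn

    @use_np
    class BottleneckAdapter(nn.HybridBlock):
        """Sketch of a Houlsby-style adapter: project down to a small bottleneck,
        apply a nonlinearity, project back up, and add a residual connection."""
        def __init__(self, units, adapter_units=64):
            super().__init__()
            # Down-projection with a built-in Dense activation; the real layer would
            # presumably pick the activation from adapter_config instead.
            self.down_proj = nn.Dense(adapter_units, in_units=units,
                                      flatten=False, activation='relu')
            self.up_proj = nn.Dense(units, in_units=adapter_units, flatten=False)

        def forward(self, x):
            # The residual connection keeps the pretrained representation intact
            # at initialization, so training only perturbs it gradually.
            return x + self.up_proj(self.down_proj(x))
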
@@ -259,28 +270,23 @@ def layout(self):
     def forward(self, data, valid_length):
         """
         Generate the representation given the inputs.
-
         This is used in training or fine-tuning a bert model.
-
         Parameters
         ----------
         data
             - layout = 'NT'
                 Shape (batch_size, seq_length, C)
             - layout = 'TN'
                 Shape (seq_length, batch_size, C)
-
         valid_length
             Shape (batch_size,)
-
         Returns
         -------
         out
             - layout = 'NT'
                 Shape (batch_size, seq_length, C_out)
             - layout = 'TN'
                 Shape (seq_length, batch_size, C_out)
-
         """
         if self.layout == 'NT':
             time_axis, batch_axis = 1, 0
@@ -336,7 +342,9 @@ def __init__(self,
                  dtype='float32',
                  use_pooler=True,
                  layout='NT',
-                 compute_layout='auto'):
+                 compute_layout='auto',
+                 use_adapter=False,
+                 adapter_config={}):
         super().__init__()
         self._dtype = dtype
         self.use_pooler = use_pooler
@@ -351,6 +359,9 @@ def __init__(self,
         self.bias_initializer = bias_initializer
         self.layer_norm_eps = layer_norm_eps
         self._layout = layout
+        self._use_adapter = use_adapter
+        # Keep None when adapters are disabled so the attribute always exists.
+        self._adapter_config = json.loads(adapter_config) if use_adapter else None
         if compute_layout is None or compute_layout == 'auto':
             self._compute_layout = layout
         else:
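
The config carries `adapter_config` as a JSON string (presumably because arbitrary dicts do not round-trip cleanly through a YACS `CN`), and the line above parses it exactly once so that `BertTransformer` receives an ordinary dict. For example, with the same hypothetical key as before:

    import json

    adapter_config = '{"adapter_units": 64}'   # as stored in cfg.MODEL.adapter_config
    parsed = json.loads(adapter_config)        # -> {'adapter_units': 64}
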
@@ -370,7 +381,9 @@ def __init__(self,
             weight_initializer=weight_initializer,
             bias_initializer=bias_initializer,
             dtype=dtype,
-            layout=self._compute_layout
+            layout=self._compute_layout,
+            use_adapter=self._use_adapter,
+            adapter_config=self._adapter_config
         )
         # Construct word embedding
         self.word_embed = nn.Embedding(input_dim=vocab_size,
@@ -404,38 +417,31 @@ def layout(self):
     def forward(self, inputs, token_types, valid_length):
         # pylint: disable=arguments-differ
         """Generate the representation given the inputs.
-
         This is used in training or fine-tuning a bert model.
-
         Parameters
         ----------
         inputs
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         token_types
             If the inputs contain two sequences, we will set different token types for the first
             sentence and the second sentence.
-
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         valid_length :
             The valid length of each sequence
             Shape (batch_size,)
-
         Returns
         -------
         contextual_embedding
             - layout = 'NT'
                 Shape (batch_size, seq_length, units).
             - layout = 'TN'
                 Shape (seq_length, batch_size, units).
-
         pooled_output
             This is optional. Shape (batch_size, units)
         """
@@ -457,33 +463,27 @@ def forward(self, inputs, token_types, valid_length):
 
     def get_initial_embedding(self, inputs, token_types=None):
         """Get the initial token embeddings that considers the token type and positional embeddings
-
         Parameters
         ----------
         inputs
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         token_types
             The type of tokens. If None, it will be initialized as all zero.
-
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         Returns
         -------
         embedding
             The initial embedding that will be fed into the encoder
-
             - layout = 'NT'
                 Shape (batch_size, seq_length, C_emb)
             - layout = 'TN'
                 Shape (seq_length, batch_size, C_emb)
-
         """
         if self.layout == 'NT':
             time_axis, batch_axis = 1, 0
@@ -505,18 +505,15 @@ def get_initial_embedding(self, inputs, token_types=None):
 
     def apply_pooling(self, sequence):
         """Generate the representation given the inputs.
-
         This is used for pre-training or fine-tuning a bert model.
         Get the first token of the whole sequence which is [CLS].
-
         Parameters
         ----------
         sequence
             - layout = 'NT'
                 Shape (batch_size, sequence_length, units)
             - layout = 'TN'
                 Shape (sequence_length, batch_size, units)
-
         Returns
         -------
         outputs
@@ -538,7 +535,6 @@ def get_cfg(key=None):
     @classmethod
     def from_cfg(cls, cfg, use_pooler=True, dtype=None) -> 'BertModel':
         """
-
         Parameters
         ----------
         cfg
@@ -547,7 +543,6 @@ def from_cfg(cls, cfg, use_pooler=True, dtype=None) -> 'BertModel':
             Whether to output the pooled feature
         dtype
             data type of the model
-
         Returns
         -------
         ret
@@ -578,7 +573,9 @@ def from_cfg(cls, cfg, use_pooler=True, dtype=None) -> 'BertModel':
                    bias_initializer=bias_initializer,
                    use_pooler=use_pooler,
                    layout=cfg.MODEL.layout,
-                   compute_layout=cfg.MODEL.compute_layout)
+                   compute_layout=cfg.MODEL.compute_layout,
+                   use_adapter=cfg.MODEL.use_adapter,
+                   adapter_config=cfg.MODEL.adapter_config)
 
 
 @use_np
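
With `from_cfg` now forwarding the adapter options, a typical adapter-tuning setup would build the model from a modified config and train only the adapter parameters. A hedged sketch follows; the `'adapter'` name filter assumes the adapter layers register parameters containing that substring, which this diff does not show:

    cfg = BertModel.get_cfg().clone()
    cfg.defrost()
    cfg.MODEL.use_adapter = True
    cfg.MODEL.adapter_config = '{"adapter_units": 64}'  # hypothetical schema
    cfg.freeze()

    model = BertModel.from_cfg(cfg)
    model.initialize()
    # Freeze everything except the adapter parameters.
    for name, param in model.collect_params().items():
        if 'adapter' not in name:
            param.grad_req = 'null'
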
@@ -587,7 +584,6 @@ def __init__(self, backbone_cfg,
                  weight_initializer=None,
                  bias_initializer=None):
         """
-
         Parameters
         ----------
         backbone_cfg
@@ -626,39 +622,33 @@ def layout(self):
     def forward(self, inputs, token_types, valid_length,
                 masked_positions):
         """Getting the scores of the masked positions.
-
         Parameters
         ----------
         inputs
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         token_types
             If the inputs contain two sequences, we will set different token types for the first
             sentence and the second sentence.
-
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         valid_length :
             The valid length of each sequence
             Shape (batch_size,)
         masked_positions :
             The masked position of the sequence
             Shape (batch_size, num_masked_positions).
-
         Returns
         -------
         contextual_embedding
             - layout = 'NT'
                 Shape (batch_size, seq_length, units).
             - layout = 'TN'
                 Shape (seq_length, batch_size, units)
-
         pooled_out
             Shape (batch_size, units)
         mlm_scores :
@@ -680,7 +670,6 @@ def __init__(self, backbone_cfg,
                  weight_initializer=None,
                  bias_initializer=None):
         """
-
         Parameters
         ----------
         backbone_cfg
@@ -724,41 +713,34 @@ def layout(self):
     def forward(self, inputs, token_types, valid_length,
                 masked_positions):
         """Generate the representation given the inputs.
-
         This is used in training or fine-tuning a bert model.
-
         Parameters
         ----------
         inputs
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         token_types
             If the inputs contain two sequences, we will set different token types for the first
             sentence and the second sentence.
-
             - layout = 'NT'
                 Shape (batch_size, seq_length)
             - layout = 'TN'
                 Shape (seq_length, batch_size)
-
         valid_length
             The valid length of each sequence
             Shape (batch_size,)
         masked_positions
             The masked position of the sequence
             Shape (batch_size, num_masked_positions).
-
         Returns
         -------
         contextual_embedding
             - layout = 'NT'
                 Shape (batch_size, seq_length, units).
             - layout = 'TN'
                 Shape (seq_length, batch_size, units).
-
         pooled_out
             Shape (batch_size, units)
         nsp_score :
@@ -787,7 +769,6 @@ def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base',
                         load_mlm: str = False)\
         -> Tuple[CN, HuggingFaceWordPieceTokenizer, str, str]:
     """Get the pretrained bert weights
-
     Parameters
     ----------
     model_name
@@ -798,7 +779,6 @@ def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base',
         Whether to load the weights of the backbone network
     load_mlm
         Whether to load the weights of MLM
-
     Returns
     -------
     cfg
@@ -857,4 +837,4 @@ def get_pretrained_bert(model_name: str = 'google_en_cased_bert_base',
 
 BACKBONE_REGISTRY.register('bert', [BertModel,
                                     get_pretrained_bert,
-                                    list_pretrained_bert])
+                                    list_pretrained_bert])
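
Finally, a sketch of how pretrained Google BERT weights could be loaded into an adapter-enabled model. Since the released checkpoints predate the adapter layers, `allow_missing=True` leaves the newly added adapter parameters at their random initialization; the `adapter_units` key is again an assumption:

    cfg, tokenizer, backbone_params_path, _ = get_pretrained_bert('google_en_uncased_bert_base')
    cfg.defrost()
    cfg.MODEL.use_adapter = True
    cfg.MODEL.adapter_config = '{"adapter_units": 64}'
    cfg.freeze()

    model = BertModel.from_cfg(cfg)
    model.initialize()
    model.load_parameters(backbone_params_path, allow_missing=True)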