@@ -42,6 +42,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         sliding_window=0,
+        rotary_interleave=True,
     ):
         """
         Args:
@@ -60,6 +61,9 @@ def __init__(
             max_relative_positions (int):
                 Max distance between inputs in relative positions
                 representations
+            relative_positions_buckets (int):
+                Number of buckets when using relative position bias, see
+                https://github.com/google-research/text-to-text-transfer-transformer
             aan_useffn (bool): Turn on the FFN layer in the AAN decoder
             full_context_alignment (bool):
                 whether enable an extra full context decoder forward for
@@ -69,9 +73,19 @@ def __init__(
             pos_ffn_activation_fn (ActivationFunction):
                 activation function choice for PositionwiseFeedForward layer
             add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
+            num_kv (int): number of heads for KV when different vs Q (multiquery)
+            add_ffnbias (bool): whether to add bias to the FF nn.Linear
+            parallel_residual (bool): Use parallel residual connections in each
+                layer block, as used by the GPT-J and GPT-NeoX models
+            shared_layer_norm (bool): When using parallel residual, share the
+                input and post-attention layer norms.
             layer_norm (string): type of layer normalization standard/rms
             norm_eps (float): layer norm epsilon
-
+            use_ckpting (List): layers for which we checkpoint for backward
+            parallel_gpu (int): Number of GPUs for tensor parallelism
+            sliding_window (int): Width of the band mask and KV cache (cf. Mistral model)
+            rotary_interleave (bool): Interleave the head dimensions when rotary
+                embeddings are applied
         """
         super(TransformerDecoderLayerBase, self).__init__()
 
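Editorial note on `rotary_interleave` (documented in the hunk above): the flag only selects how the head dimensions are paired when rotary position embeddings are applied. The sketch below is a minimal, self-contained illustration of the two layouts under that assumption; the function and argument names are illustrative and not the repo's `MultiHeadedAttention` internals.

import torch


def rope_rotate(x, theta_base=10000.0, interleave=True):
    """Apply rotary position embeddings to x of shape
    (batch, heads, seq_len, head_dim). Illustration only."""
    _, _, seq_len, dim = x.shape
    # one rotation angle per pair of head dimensions, per position
    inv_freq = 1.0 / (theta_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
    angles = torch.outer(torch.arange(seq_len, dtype=torch.float32), inv_freq)
    cos, sin = angles.cos(), angles.sin()  # (seq_len, dim // 2)
    if interleave:
        # pair adjacent dimensions: (x0, x1), (x2, x3), ...
        x1, x2 = x[..., 0::2], x[..., 1::2]
        out = torch.empty_like(x)
        out[..., 0::2] = x1 * cos - x2 * sin
        out[..., 1::2] = x1 * sin + x2 * cos
        return out
    # pair the first half with the second half: (x0, x_{d/2}), (x1, x_{d/2+1}), ...
    x1, x2 = x[..., : dim // 2], x[..., dim // 2 :]
    return torch.cat((x1 * cos - x2 * sin, x1 * sin + x2 * cos), dim=-1)

The interleaved layout pairs adjacent dimensions as in the original RoPE formulation, while the non-interleaved layout rotates the two halves of each head; a checkpoint converted between the two conventions generally needs the flag set to match its weight layout.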
@@ -83,6 +97,7 @@ def __init__(
             dropout=attention_dropout,
             max_relative_positions=max_relative_positions,
             relative_positions_buckets=relative_positions_buckets,
+            rotary_interleave=rotary_interleave,
             attn_type="self",
             add_qkvbias=add_qkvbias,
             num_kv=num_kv,
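`num_kv`, passed to `MultiHeadedAttention` above, enables multi-query / grouped-query attention: the projection produces fewer key/value heads than query heads, and each KV head is shared by a group of query heads. A hedged sketch of the head-expansion step follows; the helper name and tensor shapes are assumptions for illustration, not the library's internals.

import torch


def expand_kv_heads(key, value, num_q_heads):
    """key / value: (batch, num_kv_heads, seq_len, head_dim).
    Repeat each KV head so that every query head gets a matching KV head."""
    num_kv_heads = key.shape[1]
    assert num_q_heads % num_kv_heads == 0, "query heads must be a multiple of KV heads"
    group_size = num_q_heads // num_kv_heads
    key = key.repeat_interleave(group_size, dim=1)    # (batch, num_q_heads, seq_len, head_dim)
    value = value.repeat_interleave(group_size, dim=1)
    return key, value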
@@ -238,6 +253,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         sliding_window=0,
+        rotary_interleave=True,
     ):
         """
         Args:
@@ -266,6 +282,7 @@ def __init__(
             use_ckpting=use_ckpting,
             parallel_gpu=parallel_gpu,
             sliding_window=sliding_window,
+            rotary_interleave=rotary_interleave,
         )
         self.context_attn = MultiHeadedAttention(
             heads,
@@ -424,6 +441,7 @@ def from_opt(cls, opt, embeddings):
             if opt.parallel_mode == "tensor_parallel"
             else 1,
             sliding_window=opt.sliding_window,
+            rotary_interleave=opt.rotary_interleave,
         )
 
     def init_state(self, src, enc_out, enc_final_hs):
@@ -486,8 +504,21 @@ class TransformerDecoder(TransformerDecoderBase):
         alignment_layer (int): N° Layer to supervise with for alignment guiding
         alignment_heads (int):
             N. of cross attention heads to use for alignment guiding
+        pos_ffn_activation_fn (ActivationFunction):
+            activation function choice for PositionwiseFeedForward layer
         add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
+        num_kv (int): number of heads for KV when different vs Q (multiquery)
+        add_ffnbias (bool): whether to add bias to the FF nn.Linear
+        parallel_residual (bool): Use parallel residual connections in each
+            layer block, as used by the GPT-J and GPT-NeoX models
+        shared_layer_norm (bool): When using parallel residual, share the
+            input and post-attention layer norms.
         layer_norm (string): type of layer normalization standard/rms
+        norm_eps (float): layer norm epsilon
+        use_ckpting (List): layers for which we checkpoint for backward
+        parallel_gpu (int): Number of GPUs for tensor parallelism
+        sliding_window (int): Width of the band mask and KV cache (cf. Mistral model)
+        rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
     """
 
     def __init__(
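For `sliding_window` (documented in the docstring above), the width controls a causal band: each position attends only to itself and the most recent `window - 1` tokens, which also bounds the KV cache, as in Mistral. A minimal illustrative sketch of such a mask, not the repo's masking code:

import torch


def band_causal_mask(seq_len, window):
    """Boolean mask, True where attention is allowed: position i may attend
    to position j iff 0 <= i - j < window (itself plus window - 1 past tokens)."""
    i = torch.arange(seq_len).unsqueeze(1)
    j = torch.arange(seq_len).unsqueeze(0)
    return (i >= j) & (i - j < window)


# band_causal_mask(6, 3) keeps the main diagonal and the two sub-diagonals below it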
@@ -518,6 +549,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         sliding_window=0,
+        rotary_interleave=True,
     ):
         super(TransformerDecoder, self).__init__(
             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -548,6 +580,7 @@ def __init__(
                     use_ckpting=use_ckpting,
                     parallel_gpu=parallel_gpu,
                     sliding_window=sliding_window,
+                    rotary_interleave=rotary_interleave,
                 )
                 for i in range(num_layers)
             ]
@@ -716,22 +749,41 @@ def _forward(
 class TransformerLMDecoder(TransformerDecoderBase):
     """The Transformer decoder from GPT-2
     Args:
-        num_layers (int): number of decoder layers.
-        d_model (int): size of the model
-        heads (int): number of heads
-        d_ff (int): size of the inner FF layer
-        copy_attn (bool): if using a separate copy attention
-        self_attn_type (str): type of self-attention scaled-dot, average
-        dropout (float): dropout in residual, self-attn(dot) and feed-forward
-        attention_dropout (float): dropout in context_attn (and self-attn(avg))
-        embeddings (onmt.modules.Embeddings):
-            embeddings to use, should have positional encodings
-        max_relative_positions (int):
-            Max distance between inputs in relative positions representations
-        relative_positions_buckets (int):
-            Number of buckets when using Relative positions bias
-        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
-        add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
+        num_layers (int): number of decoder layers.
+        d_model (int): size of the model
+        heads (int): number of heads
+        d_ff (int): size of the inner FF layer
+        copy_attn (bool): if using a separate copy attention
+        self_attn_type (str): type of self-attention scaled-dot, average
+        dropout (float): dropout in residual, self-attn(dot) and feed-forward
+        attention_dropout (float): dropout in context_attn (and self-attn(avg))
+        embeddings (onmt.modules.Embeddings):
+            embeddings to use, should have positional encodings
+        max_relative_positions (int):
+            Max distance between inputs in relative positions representations
+        relative_positions_buckets (int):
+            Number of buckets when using Relative positions bias
+        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+        full_context_alignment (bool):
+            whether to enable an extra full context decoder forward for alignment
+        alignment_layer (int): N° Layer to supervise with for alignment guiding
+        alignment_heads (int):
+            N. of cross attention heads to use for alignment guiding
+        pos_ffn_activation_fn (ActivationFunction):
+            activation function choice for PositionwiseFeedForward layer
+        add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
+        num_kv (int): number of heads for KV when different vs Q (multiquery)
+        add_ffnbias (bool): whether to add bias to the FF nn.Linear
+        parallel_residual (bool): Use parallel residual connections in each
+            layer block, as used by the GPT-J and GPT-NeoX models
+        shared_layer_norm (bool): When using parallel residual, share the
+            input and post-attention layer norms.
+        layer_norm (string): type of layer normalization standard/rms
+        norm_eps (float): layer norm epsilon
+        use_ckpting (List): layers for which we checkpoint for backward
+        parallel_gpu (int): Number of GPUs for tensor parallelism
+        sliding_window (int): Width of the band mask and KV cache (cf. Mistral model)
+        rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
     """
 
     def __init__(
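As a reading aid for `parallel_residual` and `shared_layer_norm` in the docstring above: the flags switch a layer from the usual sequential pre-norm block to the GPT-J / GPT-NeoX style where the attention and feed-forward branches both read the residual stream and their outputs are summed. The sketch below is a schematic with placeholder modules, not the layer's actual forward pass.

def sequential_block(x, norm1, attn, norm2, ffn):
    # standard pre-norm block: the feed-forward branch sees the attention output
    x = x + attn(norm1(x))
    return x + ffn(norm2(x))


def parallel_block(x, norm1, attn, norm2, ffn, shared_layer_norm=False):
    # GPT-J / GPT-NeoX style: both branches read the same residual stream
    h_attn = norm1(x)
    h_ffn = h_attn if shared_layer_norm else norm2(x)
    return x + attn(h_attn) + ffn(h_ffn)

Roughly, `shared_layer_norm=True` corresponds to the GPT-J arrangement (one norm feeds both branches) and `shared_layer_norm=False` to GPT-NeoX (separate input and post-attention norms).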
@@ -762,6 +814,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         sliding_window=0,
+        rotary_interleave=True,
     ):
         super(TransformerLMDecoder, self).__init__(
             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -791,6 +844,7 @@ def __init__(
                     use_ckpting=use_ckpting,
                     parallel_gpu=parallel_gpu,
                     sliding_window=sliding_window,
+                    rotary_interleave=rotary_interleave,
                 )
                 for i in range(num_layers)
             ]