Commit 212141b

Add rotary_interleave=false (Hugging face models) (#2507)
1 parent 3743fa3 commit 212141b

File tree: 4 files changed, +117 -21 lines

onmt/decoders/transformer.py

Lines changed: 71 additions & 17 deletions
@@ -42,6 +42,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         sliding_window=0,
+        rotary_interleave=True,
     ):
         """
         Args:
@@ -60,6 +61,9 @@ def __init__(
             max_relative_positions (int):
                 Max distance between inputs in relative positions
                 representations
+            relative_positions_buckets (int):
+                relative position bias see
+                https://github.com/google-research/text-to-text-transfer-transformer
             aan_useffn (bool): Turn on the FFN layer in the AAN decoder
             full_context_alignment (bool):
                 whether enable an extra full context decoder forward for
@@ -69,9 +73,19 @@ def __init__(
             pos_ffn_activation_fn (ActivationFunction):
                 activation function choice for PositionwiseFeedForward layer
             add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
+            num_kv (int): number of heads for KV when different vs Q (multiquery)
+            add_ffnbias (bool): whether to add bias to the FF nn.Linear
+            parallel_residual (bool): Use parallel residual connections in each layer block, as used
+                by the GPT-J and GPT-NeoX models
+            shared_layer_norm (bool): When using parallel residual, share the input and post
+                attention layer norms.
             layer_norm (string): type of layer normalization standard/rms
             norm_eps (float): layer norm epsilon
-
+            use_ckpting (List): layers for which we checkpoint for backward
+            parallel_gpu (int): Number of gpu for tensor parallelism
+            sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
+            rotary_interleave (bool): Interleave the head dimensions when rotary
+                embeddings are applied
         """
         super(TransformerDecoderLayerBase, self).__init__()

@@ -83,6 +97,7 @@ def __init__(
             dropout=attention_dropout,
             max_relative_positions=max_relative_positions,
             relative_positions_buckets=relative_positions_buckets,
+            rotary_interleave=rotary_interleave,
             attn_type="self",
             add_qkvbias=add_qkvbias,
             num_kv=num_kv,
@@ -238,6 +253,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         sliding_window=0,
+        rotary_interleave=True,
     ):
         """
         Args:
@@ -266,6 +282,7 @@ def __init__(
             use_ckpting=use_ckpting,
             parallel_gpu=parallel_gpu,
             sliding_window=sliding_window,
+            rotary_interleave=rotary_interleave,
         )
         self.context_attn = MultiHeadedAttention(
             heads,
@@ -424,6 +441,7 @@ def from_opt(cls, opt, embeddings):
             if opt.parallel_mode == "tensor_parallel"
             else 1,
             sliding_window=opt.sliding_window,
+            rotary_interleave=opt.rotary_interleave,
         )

     def init_state(self, src, enc_out, enc_final_hs):
@@ -486,8 +504,21 @@ class TransformerDecoder(TransformerDecoderBase):
         alignment_layer (int): N° Layer to supervise with for alignment guiding
         alignment_heads (int):
             N. of cross attention heads to use for alignment guiding
+        pos_ffn_activation_fn (ActivationFunction):
+            activation function choice for PositionwiseFeedForward layer
         add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
+        num_kv (int): number of heads for KV when different vs Q (multiquery)
+        add_ffnbias (bool): whether to add bias to the FF nn.Linear
+        parallel_residual (bool): Use parallel residual connections in each layer block, as used
+            by the GPT-J and GPT-NeoX models
+        shared_layer_norm (bool): When using parallel residual, share the input and post
+            attention layer norms.
         layer_norm (string): type of layer normalization standard/rms
+        norm_eps (float): layer norm epsilon
+        use_ckpting (List): layers for which we checkpoint for backward
+        parallel_gpu (int): Number of gpu for tensor parallelism
+        sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
+        rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
     """

     def __init__(
@@ -518,6 +549,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         sliding_window=0,
+        rotary_interleave=True,
     ):
         super(TransformerDecoder, self).__init__(
             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -548,6 +580,7 @@ def __init__(
                 use_ckpting=use_ckpting,
                 parallel_gpu=parallel_gpu,
                 sliding_window=sliding_window,
+                rotary_interleave=rotary_interleave,
             )
             for i in range(num_layers)
         ]
@@ -716,22 +749,41 @@ def _forward(
 class TransformerLMDecoder(TransformerDecoderBase):
     """The Transformer decoder from GPT-2
     Args:
-       num_layers (int): number of decoder layers.
-       d_model (int): size of the model
-       heads (int): number of heads
-       d_ff (int): size of the inner FF layer
-       copy_attn (bool): if using a separate copy attention
-       self_attn_type (str): type of self-attention scaled-dot, average
-       dropout (float): dropout in residual, self-attn(dot) and feed-forward
-       attention_dropout (float): dropout in context_attn (and self-attn(avg))
-       embeddings (onmt.modules.Embeddings):
-          embeddings to use, should have positional encodings
-       max_relative_positions (int):
-          Max distance between inputs in relative positions representations
-       relative_positions_buckets (int):
-          Number of buckets when using Relative positions bias
-       aan_useffn (bool): Turn on the FFN layer in the AAN decoder
-       add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
+        num_layers (int): number of decoder layers.
+        d_model (int): size of the model
+        heads (int): number of heads
+        d_ff (int): size of the inner FF layer
+        copy_attn (bool): if using a separate copy attention
+        self_attn_type (str): type of self-attention scaled-dot, average
+        dropout (float): dropout in residual, self-attn(dot) and feed-forward
+        attention_dropout (float): dropout in context_attn (and self-attn(avg))
+        embeddings (onmt.modules.Embeddings):
+            embeddings to use, should have positional encodings
+        max_relative_positions (int):
+            Max distance between inputs in relative positions representations
+        relative_positions_buckets (int):
+            Number of buckets when using Relative positions bias
+        aan_useffn (bool): Turn on the FFN layer in the AAN decoder
+        full_context_alignment (bool):
+            whether enable an extra full context decoder forward for alignment
+        alignment_layer (int): N° Layer to supervise with for alignment guiding
+        alignment_heads (int):
+            N. of cross attention heads to use for alignment guiding
+        pos_ffn_activation_fn (ActivationFunction):
+            activation function choice for PositionwiseFeedForward layer
+        add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
+        num_kv (int): number of heads for KV when different vs Q (multiquery)
+        add_ffnbias (bool): whether to add bias to the FF nn.Linear
+        parallel_residual (bool): Use parallel residual connections in each layer block, as used
+            by the GPT-J and GPT-NeoX models
+        shared_layer_norm (bool): When using parallel residual, share the input and post
+            attention layer norms.
+        layer_norm (string): type of layer normalization standard/rms
+        norm_eps (float): layer norm epsilon
+        use_ckpting (List): layers for which we checkpoint for backward
+        parallel_gpu (int): Number of gpu for tensor parallelism
+        sliding_window (int): Width of the band mask and KV cache (cf Mistral Model)
+        rotary_interleave (bool): Interleave the head dimensions when rotary embeddings are applied
     """

     def __init__(
@@ -762,6 +814,7 @@ def __init__(
         use_ckpting=[],
         parallel_gpu=1,
         sliding_window=0,
+        rotary_interleave=True,
     ):
         super(TransformerLMDecoder, self).__init__(
             d_model, copy_attn, embeddings, alignment_layer, layer_norm, norm_eps
@@ -791,6 +844,7 @@ def __init__(
                 use_ckpting=use_ckpting,
                 parallel_gpu=parallel_gpu,
                 sliding_window=sliding_window,
+                rotary_interleave=rotary_interleave,
             )
             for i in range(num_layers)
         ]
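The decoder-side changes are plumbing: every decoder constructor now accepts rotary_interleave and forwards it to each layer's self-attention. As a rough sketch of what each layer ends up doing (the keyword names are taken from this diff; the positional heads/model-dim order and the example values are assumptions, not part of the commit):

# Sketch only: keyword arguments mirror the diff above; the positional order
# (heads, model dim) and the numbers are illustrative assumptions.
from onmt.modules.multi_headed_attn import MultiHeadedAttention

self_attn = MultiHeadedAttention(
    8,                          # heads
    512,                        # model dimension
    dropout=0.1,
    max_relative_positions=-1,  # -1 selects rotary embeddings
    rotary_interleave=False,    # Hugging Face head-dimension layout
    attn_type="self",
    add_qkvbias=False,
    num_kv=0,
)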

onmt/encoders/transformer.py

Lines changed: 16 additions & 0 deletions
@@ -29,6 +29,17 @@ class TransformerEncoderLayer(nn.Module):
         dropout (float): dropout probability(0-1.0).
         pos_ffn_activation_fn (ActivationFunction):
             activation function choice for PositionwiseFeedForward layer
+        add_qkvbias (bool): whether to add bias to the Key/Value nn.Linear
+        num_kv (int): number of heads for KV when different vs Q (multiquery)
+        add_ffnbias (bool): whether to add bias to the FF nn.Linear
+        parallel_residual (bool): Use parallel residual connections in each layer block, as used
+            by the GPT-J and GPT-NeoX models
+        layer_norm (string): type of layer normalization standard/rms
+        norm_eps (float): layer norm epsilon
+        use_ckpting (List): layers for which we checkpoint for backward
+        parallel_gpu (int): Number of gpu for tensor parallelism
+        rotary_interleave (bool): Interleave the head dimensions when rotary
+            embeddings are applied
     """

     def __init__(
@@ -49,6 +60,7 @@ def __init__(
         norm_eps=1e-6,
         use_ckpting=[],
         parallel_gpu=1,
+        rotary_interleave=True,
     ):
         super(TransformerEncoderLayer, self).__init__()

@@ -59,6 +71,7 @@ def __init__(
             is_decoder=False,
             max_relative_positions=max_relative_positions,
             relative_positions_buckets=relative_positions_buckets,
+            rotary_interleave=rotary_interleave,
             attn_type="self",
             add_qkvbias=add_qkvbias,
             num_kv=num_kv,
@@ -163,6 +176,7 @@ def __init__(
         norm_eps=1e-6,
         use_ckpting=[],
         parallel_gpu=1,
+        rotary_interleave=True,
     ):
         super(TransformerEncoder, self).__init__()

@@ -186,6 +200,7 @@ def __init__(
                 norm_eps=norm_eps,
                 use_ckpting=use_ckpting,
                 parallel_gpu=parallel_gpu,
+                rotary_interleave=rotary_interleave,
             )
             for i in range(num_layers)
         ]
@@ -223,6 +238,7 @@ def from_opt(cls, opt, embeddings):
             parallel_gpu=opt.world_size
             if opt.parallel_mode == "tensor_parallel"
             else 1,
+            rotary_interleave=opt.rotary_interleave,
         )

     def forward(self, src, src_len=None):

onmt/modules/multi_headed_attn.py

Lines changed: 17 additions & 3 deletions
@@ -28,9 +28,17 @@ def rotaryembeddings(dim: int, maxseqlen=8192, base=10000):
     return rope


-def apply_rotary_emb(query, key, rope):
+def apply_rotary_emb(query, key, rope, interleave=True):
     query = query.transpose(1, 2)
     key = key.transpose(1, 2)
+    if not interleave:
+        query = torch.cat(
+            (-query[..., query.shape[-1] // 2 :], query[..., : query.shape[-1] // 2]),
+            dim=-1,
+        )
+        key = torch.cat(
+            (-key[..., key.shape[-1] // 2 :], key[..., : key.shape[-1] // 2]), dim=-1
+        )
     query_ = query.float().reshape(*query.shape[:-1], -1, 2)
     query_ = torch.view_as_complex(query_)
     key_ = key.float().reshape(*key.shape[:-1], -1, 2)
@@ -243,6 +251,7 @@ def __init__(
         is_decoder: bool = True,
         max_relative_positions: int = 0,
         relative_positions_buckets: int = 0,
+        rotary_interleave: bool = True,
         attn_type: str = None,
         add_qkvbias=False,
         num_kv=0,
@@ -336,6 +345,7 @@ def __init__(

         if max_relative_positions == -1:  # rotary embeddings
             self.rope = rotaryembeddings(self.dim_per_head)
+            self.rotary_interleave = rotary_interleave

         if max_relative_positions == -2:  # alibi positional bias
             self.alibi = AlibiPositionalBias(head_count)
@@ -403,7 +413,9 @@ def forward(
                     start_pos = step
                     seqlen = query.size(2)
                     rope = self.rope[start_pos : start_pos + seqlen]
-                    query, key = apply_rotary_emb(query, key, rope=rope)
+                    query, key = apply_rotary_emb(
+                        query, key, rope, interleave=self.rotary_interleave
+                    )

                 if self.layer_cache[1]["keys"].numel() != 0:
                     key = torch.cat((self.layer_cache[1]["keys"], key), dim=2)
@@ -440,7 +452,9 @@ def forward(
                 start_pos = 0
                 seqlen = query.size(2)
                 rope = self.rope[start_pos : start_pos + seqlen].to(query.device)
-                query, key = apply_rotary_emb(query, key, rope=rope)
+                query, key = apply_rotary_emb(
+                    query, key, rope, interleave=self.rotary_interleave
+                )

             b, h, l, d = key.size()
             if self.num_kv > 0:
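For orientation, here is how the updated apply_rotary_emb can be exercised on its own, mirroring the call pattern in forward() above; tensor shapes follow MultiHeadedAttention (batch, heads, sequence length, head dimension), and the sizes are arbitrary example values:

# Illustrative call of the changed helper; rotaryembeddings/apply_rotary_emb
# are the functions from this file, the shapes/sizes are example assumptions.
import torch
from onmt.modules.multi_headed_attn import rotaryembeddings, apply_rotary_emb

batch, heads, seqlen, dim_per_head = 2, 8, 16, 64
query = torch.rand(batch, heads, seqlen, dim_per_head)
key = torch.rand(batch, heads, seqlen, dim_per_head)
rope = rotaryembeddings(dim_per_head)[:seqlen]  # same slicing as in forward()

# interleave=True  keeps the original (Meta Llama) interleaved pairs,
# interleave=False applies the Hugging Face half-split (rotate_half) layout.
q_rot, k_rot = apply_rotary_emb(query, key, rope, interleave=False)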

onmt/opts.py

Lines changed: 13 additions & 1 deletion
@@ -862,6 +862,17 @@ def model_opts(parser):
         help="This setting enable relative position bias"
         "more info: https://github.com/google-research/text-to-text-transfer-transformer",
     )
+    group.add(
+        "--rotary_interleave",
+        "-rotary_interleave",
+        type=bool,
+        default=True,
+        help="Interleave the head dimensions when rotary"
+        " embeddings are applied."
+        " Otherwise the head dimensions are sliced in half."
+        "True = default Llama from Meta (original)"
+        "False = used by all Hugging face models",
+    )
     group.add(
         "--heads",
         "-heads",
@@ -927,7 +938,8 @@ def model_opts(parser):
         "-shared_layer_norm",
         action="store_true",
         help="Use a shared layer_norm in parallel residual attention"
-        "Note: must be true for Falcon 7B / false for Falcon 40B",
+        "Note: must be true for Falcon 7B / false for Falcon 40B"
+        "same for GPT-J and GPT-NeoX models",
     )
     # Alignement options
     group = parser.add_argument_group("Model - Alignement")
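The help text above is the heart of the commit: True keeps Meta's original Llama convention of rotating interleaved dimension pairs, while False applies the rotate-half convention used by Hugging Face checkpoints, which store the head dimensions in a permuted layout (even indices first, then odd); that permutation detail is our gloss, not stated in the commit. A self-contained sketch in plain PyTorch (not OpenNMT code) showing that the two conventions agree once that permutation is accounted for:

# Plain-PyTorch illustration (not library code): interleaved rotation
# (rotary_interleave=True) equals half-split rotation (rotary_interleave=False)
# once the head dimensions are permuted (even indices first, then odd).
import torch

d = 8                           # head dimension (must be even)
theta = torch.rand(d // 2)      # one rotation angle per dimension pair
x = torch.rand(d)               # one head vector of one token
cos, sin = torch.cos(theta), torch.sin(theta)

# Interleaved (True): rotate pairs (x[0], x[1]), (x[2], x[3]), ...
out_interleaved = torch.empty(d)
out_interleaved[0::2] = x[0::2] * cos - x[1::2] * sin
out_interleaved[1::2] = x[0::2] * sin + x[1::2] * cos

# Half-split / rotate_half (False): rotate pairs (y[i], y[i + d//2])
def rotate_half_style(y):
    y1, y2 = y[: d // 2], y[d // 2 :]
    return torch.cat((y1 * cos - y2 * sin, y1 * sin + y2 * cos))

# Same rotation once input and output are permuted consistently.
perm = torch.cat((x[0::2], x[1::2]))
assert torch.allclose(
    rotate_half_style(perm),
    torch.cat((out_interleaved[0::2], out_interleaved[1::2])),
)
print("half-split on permuted dims == interleaved on original dims")

In practice this means the default (True) matches checkpoints in Meta's original Llama layout, while models in the Hugging Face layout need rotary_interleave set to False, as the help string states.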
