
Commit a1eaa71

support phi-3 128k (#1700)
* support su rotary embedding
* fix black
* fix test rope embeddings
* fix flake
* fix tests
* small fix
* fix phi3 8k
1 parent 580c685 commit a1eaa71

6 files changed (+150 −16 lines)

include/ctranslate2/layers/attention_layer.h

Lines changed: 9 additions & 0 deletions
@@ -72,6 +72,7 @@ namespace ctranslate2 {
     enum class RotaryScalingType {
       None = -1,
       Linear,
+      Su,
     };

     class RotaryEmbeddings {
@@ -82,6 +83,10 @@ namespace ctranslate2 {
                        const float scaling_factor = 1,
                        const float base = 10000,
                        const dim_t num_initial_positions = 2048,
+                       const StorageView* long_scaling_factor = nullptr,
+                       const StorageView* short_scaling_factor = nullptr,
+                       const dim_t original_max_position_embeddings = 0,
+                       const dim_t max_position_embeddings = 0,
                        const bool transpose = true);

       void apply(StorageView& x, const dim_t offset = 0, bool apply = true);
@@ -110,6 +115,10 @@ namespace ctranslate2 {
       const float _scaling_factor;
       const float _base;
       const dim_t _num_initial_positions;
+      std::unique_ptr<StorageView> _rotary_scaling_long_factor;
+      std::unique_ptr<StorageView> _rotary_scaling_short_factor;
+      const dim_t _original_max_position_embeddings;
+      const dim_t _max_position_embeddings;
       const ops::Rotary _rotary_op;
       const bool _transpose;
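Note on the new constructor arguments: "su" scaling keeps two sets of per-dimension factors and switches between them based on the current sequence length, then rescales cos/sin for the extended window, which is why the layer needs both the original and the extended maximum positions. The sketch below is a rough Python illustration of that selection logic, modeled on the upstream Phi-3 rotary embedding rather than on the CTranslate2 C++ code; the function name and the example shapes are made up.

import math
import numpy as np

def su_rotary_parameters(dim, base, seq_len, short_factor, long_factor,
                         original_max_position_embeddings, max_position_embeddings):
    # Pick the long factors once the sequence exceeds the original training length.
    factors = np.asarray(
        long_factor if seq_len > original_max_position_embeddings else short_factor,
        dtype=np.float32,
    )
    # One factor per rotary dimension pair rescales the inverse frequencies.
    inv_freq = 1.0 / (factors * base ** (np.arange(0, dim, 2, dtype=np.float32) / dim))
    # Magnitude correction applied to cos/sin when the context window is extended.
    scale = max_position_embeddings / original_max_position_embeddings
    attention_scale = (
        1.0
        if scale <= 1.0
        else math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))
    )
    return inv_freq, attention_scale

# Example: 96-dim heads, 4k original vs. 128k extended context (factor values made up).
inv_freq, attn_scale = su_rotary_parameters(
    dim=96, base=10000.0, seq_len=8192,
    short_factor=[1.05] * 48, long_factor=[1.30] * 48,
    original_max_position_embeddings=4096, max_position_embeddings=131072,
)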

python/ctranslate2/converters/transformers.py

Lines changed: 42 additions & 10 deletions
@@ -40,6 +40,7 @@

 _SUPPORTED_ROPE_SCALING = {
     "linear": attention_spec.RotaryScalingType.Linear,
+    "su": attention_spec.RotaryScalingType.Su,
 }

 _MODEL_LOADERS = {}
@@ -346,9 +347,11 @@ def set_common_layers(self, spec, module):
         spec.scale_embeddings = module.embed_scale
         self.set_position_encodings(spec.position_encodings, module.embed_positions)
         self.set_embeddings(
-            spec.embeddings[0]
-            if isinstance(spec.embeddings, list)
-            else spec.embeddings,
+            (
+                spec.embeddings[0]
+                if isinstance(spec.embeddings, list)
+                else spec.embeddings
+            ),
             module.embed_tokens,
         )

@@ -1066,9 +1069,11 @@ def set_config(self, config, model, tokenizer):
     def set_stack(self, spec, module, is_decoder=False):
         self.set_layer_norm(spec.layer_norm, module.final_layer_norm)
         self.set_embeddings(
-            spec.embeddings[0]
-            if isinstance(spec.embeddings, list)
-            else spec.embeddings,
+            (
+                spec.embeddings[0]
+                if isinstance(spec.embeddings, list)
+                else spec.embeddings
+            ),
             module.embed_tokens,
         )

@@ -1298,9 +1303,11 @@ def get_model_spec(self, model):
         spec = transformer_spec.TransformerDecoderModelSpec.from_config(
             num_layers,
             num_heads,
-            activation=common_spec.Activation.GELU
-            if activation_config == "gelu"
-            else common_spec.Activation.GELUTanh,
+            activation=(
+                common_spec.Activation.GELU
+                if activation_config == "gelu"
+                else common_spec.Activation.GELUTanh
+            ),
             pre_norm=True,
             ffn_glu=True,
             rms_norm=True,
@@ -1694,10 +1701,14 @@ def get_model_spec(self, model):
         if num_heads_kv == num_heads:
             num_heads_kv = None

+        original_max_position_embeddings = getattr(
+            model.config, "original_max_position_embeddings", 0
+        )
+        max_position_embeddings = getattr(model.config, "max_position_embeddings", 0)
         rope_scaling = getattr(model.config, "rope_scaling", None)
         if rope_scaling:
             rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_scaling["type"])
-            rotary_scaling_factor = rope_scaling["factor"]
+            rotary_scaling_factor = rope_scaling.get("factor", 1)

             if rotary_scaling_type is None:
                 raise NotImplementedError(
@@ -1721,6 +1732,8 @@ def get_model_spec(self, model):
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=getattr(model.config, "rope_theta", 10000),
+            original_max_position_embeddings=original_max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             num_heads_kv=num_heads_kv,
         )

@@ -1748,6 +1761,16 @@ def set_config(self, config, model, tokenizer):
     def set_layer_norm(self, spec, layer_norm):
         spec.gamma = layer_norm.weight

+    def set_rotary_embeddings(
+        self, spec, rotary_scaling_long_factor, rotary_scaling_short_factor
+    ):
+        spec.rotary_scaling_long_factor = torch.tensor(
+            rotary_scaling_long_factor, dtype=torch.float32
+        )
+        spec.rotary_scaling_short_factor = torch.tensor(
+            rotary_scaling_short_factor, dtype=torch.float32
+        )
+
     def set_decoder(self, spec, module):
         spec.scale_embeddings = False
         self.set_embeddings(spec.embeddings, module.embed_tokens)
@@ -1765,6 +1788,15 @@ def set_decoder(self, spec, module):
                 layer_spec.self_attention.linear[0], layer.self_attn.qkv_proj
             )
             self.set_linear(layer_spec.self_attention.linear[1], layer.self_attn.o_proj)
+            if (
+                layer.self_attn.rotary_emb.long_factor is not None
+                and layer.self_attn.rotary_emb.short_factor is not None
+            ):
+                self.set_rotary_embeddings(
+                    layer_spec.self_attention,
+                    layer.self_attn.rotary_emb.long_factor,
+                    layer.self_attn.rotary_emb.short_factor,
+                )

             gate_proj, up_proj = layer.mlp.gate_up_proj.weight.chunk(2, dim=0)
             layer_spec.ffn.linear_0.weight = gate_proj
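With the converter changes above, a Phi-3 128k checkpoint is handled like any other Transformers model: the loader reads rope_scaling (accepting type "su" and a missing "factor"), the two position limits, and the per-layer long/short factors. A minimal conversion sketch, assuming the usual TransformersConverter entry point; the model name, output directory, and quantization choice are illustrative.

from ctranslate2.converters import TransformersConverter

# The 128k config declares rope_scaling={"type": "su", "long_factor": [...], "short_factor": [...]}.
converter = TransformersConverter("microsoft/Phi-3-mini-128k-instruct")
converter.convert("phi3-mini-128k-ct2", quantization="float16")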

python/ctranslate2/specs/attention_spec.py

Lines changed: 16 additions & 0 deletions
@@ -10,6 +10,7 @@ class RotaryScalingType(enum.IntEnum):
     """RoPE scaling type."""

     Linear = 0
+    Su = 1


 class MultiHeadAttentionSpec(model_spec.LayerSpec):
@@ -24,6 +25,8 @@ def __init__(
         rotary_scaling_type=None,
         rotary_scaling_factor=1,
         rotary_base=10000,
+        original_max_position_embeddings=0,
+        max_position_embeddings=0,
         num_heads_kv=None,
         head_dim=None,
         sliding_window=None,
@@ -43,16 +46,29 @@ def __init__(
         self.relative_attention_bias = None
         self.relative_attention_max_distance = None

+        if original_max_position_embeddings != 0:
+            self.original_max_position_embeddings = np.dtype("int32").type(
+                original_max_position_embeddings
+            )
+        if max_position_embeddings != 0:
+            self.max_position_embeddings = np.dtype("int32").type(
+                max_position_embeddings
+            )
+
         if rotary_dim is not None:
             self.rotary_dim = np.dtype("int32").type(rotary_dim)
             self.rotary_interleave = rotary_interleave
             self.rotary_base = np.dtype("float32").type(rotary_base)

         if rotary_scaling_type is not None:
             self.rotary_scaling_type = np.dtype("int8").type(rotary_scaling_type)
+        if rotary_scaling_type is RotaryScalingType.Linear:
             self.rotary_scaling_factor = np.dtype("float32").type(
                 rotary_scaling_factor
             )
+        elif rotary_scaling_type is RotaryScalingType.Su:
+            self.rotary_scaling_long_factor = None
+            self.rotary_scaling_short_factor = None

         if num_heads_kv is not None:
             self.num_heads_kv = np.dtype("int32").type(num_heads_kv)
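For reference, the spec serializes the scaling type as an int8 and the two position limits as int32 scalars, while rotary_scaling_long_factor and rotary_scaling_short_factor start out as None placeholders that the converter later replaces with float32 tensors. A small illustrative check (the numeric values are made up):

import numpy as np
from ctranslate2.specs import attention_spec

scaling_type = attention_spec.RotaryScalingType.Su
assert int(scaling_type) == 1

stored_type = np.dtype("int8").type(scaling_type)   # rotary_scaling_type
stored_orig = np.dtype("int32").type(4096)          # original_max_position_embeddings
stored_max = np.dtype("int32").type(131072)         # max_position_embeddings
print(stored_type, stored_orig, stored_max)         # 1 4096 131072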

python/ctranslate2/specs/model_spec.py

Lines changed: 8 additions & 2 deletions
@@ -35,6 +35,8 @@
     "float32",
 )

+SKIP_CREATING_ALIAS = ("rotary_scaling_long_factor", "rotary_scaling_short_factor")
+

 def _join_scope(scope, name):
     if not scope:
@@ -175,9 +177,13 @@ def _alias_variables(self):
                     break
                 # Because variables can be transformed on load (e.g. transposed),
                 # we use an element-wise equality check.
-                if not value.is_scalar() and value.equal(other_value):
+                scope, attr_name = _parent_scope(name)
+                if (
+                    not value.is_scalar()
+                    and value.equal(other_value)
+                    and attr_name not in SKIP_CREATING_ALIAS
+                ):
                     # Replace variable value by the alias name.
-                    scope, attr_name = _parent_scope(name)
                     spec = index_spec(self, scope)
                     setattr(spec, attr_name, other_name)
                     break
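The skip list appears to be needed because the Phi-3 long/short scaling factors come from the model config and are therefore identical in every layer, so the element-wise equality check would otherwise replace all but one of them with an alias string. A toy reproduction of the check (not the real saver):

import numpy as np

SKIP_CREATING_ALIAS = ("rotary_scaling_long_factor", "rotary_scaling_short_factor")

def would_alias(name, value, other_value):
    attr_name = name.rsplit("/", 1)[-1]
    return (
        value.size > 1                        # stands in for "not value.is_scalar()"
        and np.array_equal(value, other_value)
        and attr_name not in SKIP_CREATING_ALIAS
    )

long_factor = np.full(48, 1.19, dtype=np.float32)    # same values in every layer
print(would_alias("decoder/layer_1/self_attention/rotary_scaling_long_factor",
                  long_factor, long_factor))         # False: kept as a real tensor
print(would_alias("decoder/layer_1/self_attention/linear_0/weight",
                  long_factor, long_factor))         # True: would become an alias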

python/ctranslate2/specs/transformer_spec.py

Lines changed: 18 additions & 0 deletions
@@ -97,6 +97,8 @@ def __init__(
         rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
         rotary_scaling_factor: float = 1,
         rotary_base: float = 10000,
+        original_max_position_embeddings: int = 0,
+        max_position_embeddings: int = 0,
         parallel_residual: bool = False,
         shared_layer_norm: bool = False,
         multi_query_attention: bool = False,
@@ -135,6 +137,9 @@ def __init__(
           rotary_scaling_type: Type of RoPE scaling.
           rotary_scaling_factor: Factor used in the RoPE scaling.
           rotary_base: The base period of the rotary embeddings.
+          original_max_position_embeddings: The original max position embeddings
+            for Su rope embeddings
+          max_position_embeddings: The max position embeddings for Su rope embeddings
           parallel_residual: Use parallel residual connections in each layer block, as used
             by the GPT-J and GPT-NeoX models.
           shared_layer_norm: When using parallel residual, share the input and post
@@ -199,6 +204,8 @@ def __init__(
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=rotary_base,
+            original_max_position_embeddings=original_max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             parallel_residual=parallel_residual,
             shared_layer_norm=shared_layer_norm,
             num_heads_kv=num_heads_kv,
@@ -251,6 +258,8 @@ def __init__(
         rotary_scaling_type=None,
         rotary_scaling_factor=1,
         rotary_base=10000,
+        original_max_position_embeddings=0,
+        max_position_embeddings=0,
         parallel_residual=False,
         shared_layer_norm=False,
         num_heads_kv=None,
@@ -267,6 +276,8 @@ def __init__(
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=rotary_base,
+            original_max_position_embeddings=original_max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             num_heads_kv=num_heads_kv,
             head_dim=head_dim,
             sliding_window=sliding_window,
@@ -499,6 +510,8 @@ def from_config(
         rotary_scaling_type: Optional[attention_spec.RotaryScalingType] = None,
         rotary_scaling_factor: float = 1,
         rotary_base: float = 10000,
+        original_max_position_embeddings: int = 0,
+        max_position_embeddings: int = 0,
         parallel_residual: bool = False,
         shared_layer_norm: bool = False,
         multi_query_attention: bool = False,
@@ -531,6 +544,9 @@ def from_config(
           rotary_scaling_type: Type of RoPE scaling.
           rotary_scaling_factor: Factor used in the RoPE scaling.
           rotary_base: The base period of the rotary embeddings.
+          original_max_position_embeddings: The original max position embeddings
+            for Su rope embeddings
+          max_position_embeddings: The max position embeddings for Su rope embeddings
           parallel_residual: Use parallel residual connections in each layer block, as used
             by the GPT-J and GPT-NeoX models.
           shared_layer_norm: When using parallel residual, share the input and post
@@ -559,6 +575,8 @@ def from_config(
             rotary_scaling_type=rotary_scaling_type,
             rotary_scaling_factor=rotary_scaling_factor,
             rotary_base=rotary_base,
+            original_max_position_embeddings=original_max_position_embeddings,
+            max_position_embeddings=max_position_embeddings,
             parallel_residual=parallel_residual,
             shared_layer_norm=shared_layer_norm,
             multi_query_attention=multi_query_attention,
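The two new values are simply threaded from TransformerDecoderModelSpec.from_config down to each layer's attention spec. Below is an abbreviated, illustrative call mirroring what the Phi-3 loader passes; the layer/head counts and position limits are placeholders, and the real loader supplies more arguments (activation, KV head count, etc.).

from ctranslate2.specs import attention_spec, transformer_spec

spec = transformer_spec.TransformerDecoderModelSpec.from_config(
    num_layers=32,
    num_heads=32,
    pre_norm=True,
    ffn_glu=True,
    rms_norm=True,
    rotary_dim=0,               # 0 applies rotary embeddings to the full head dimension
    rotary_interleave=False,
    rotary_scaling_type=attention_spec.RotaryScalingType.Su,
    rotary_base=10000,
    original_max_position_embeddings=4096,
    max_position_embeddings=131072,
)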
