
Commit f89fa2b

support minimum gemma 2 (#1772)
* support minimum gemma 2
* fix ci
* fix ci
1 parent 6647945 commit f89fa2b

File tree

6 files changed: +169 -2 lines changed

include/ctranslate2/layers/transformer.h
python/ctranslate2/converters/transformers.py
python/ctranslate2/specs/transformer_spec.py
src/layers/attention.cc
src/layers/transformer.cc
third_party/googletest

include/ctranslate2/layers/transformer.h

Lines changed: 2 additions & 0 deletions
@@ -119,6 +119,8 @@ namespace ctranslate2 {
       const std::unique_ptr<const LayerNorm> _shared_layer_norm;
       const std::unique_ptr<const LayerNorm> _input_layer_norm;
       const std::unique_ptr<const LayerNorm> _post_attention_layer_norm;
+      const std::unique_ptr<const LayerNorm> _pre_feedforward_layer_norm;
+      const std::unique_ptr<const LayerNorm> _post_feedforward_layer_norm;
       const std::unique_ptr<const AttentionLayer> _encoder_attention;
       const FeedForwardNetwork _ff;
     };

python/ctranslate2/converters/transformers.py

Lines changed: 104 additions & 0 deletions
@@ -1421,6 +1421,110 @@ def set_decoder(self, spec, module):
             gc.collect()
 
 
+@register_loader("Gemma2Config")
+class Gemma2Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "Gemma2ForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        activation_config = getattr(
+            model.config, "hidden_activation", "gelu_pytorch_tanh"
+        )
+
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=(
+                common_spec.Activation.GELU
+                if activation_config == "gelu"
+                else common_spec.Activation.GELUTanh
+            ),
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=0,
+            rotary_interleave=False,
+            rotary_base=getattr(model.config, "rope_theta", 10000),
+            num_heads_kv=num_heads_kv,
+            head_dim=model.config.head_dim,
+            pre_post_layer_norm=True,
+        )
+
+        self.set_decoder(spec.decoder, model.model)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        spec.decoder.embeddings.multiply_by_sqrt_depth = model.config.hidden_size**0.5
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        if model.config.vocab_size < len(tokens):
+            tokens = tokens[: model.config.vocab_size]
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = tokenizer.unk_token
+        config.layer_norm_epsilon = model.config.rms_norm_eps
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight
+        spec.layer_norm_use_residual = True
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = True
+        spec.start_from_zero_embedding = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
+
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            wq = layer.self_attn.q_proj.weight
+            wk = layer.self_attn.k_proj.weight
+            wv = layer.self_attn.v_proj.weight
+            wo = layer.self_attn.o_proj.weight
+
+            layer_spec.self_attention.linear[0].weight = torch.cat([wq, wk, wv])
+            layer_spec.self_attention.linear[1].weight = wo
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
+            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
+            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("LlamaConfig")
 class LlamaLoader(ModelLoader):
     @property

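With the loader above registered, a Gemma 2 checkpoint can be converted like any other Transformers model. A minimal usage sketch (not part of this commit), assuming a transformers version that provides Gemma2Config and using google/gemma-2-9b-it as an illustrative model ID:

import ctranslate2

# Hypothetical conversion example: load the Hugging Face checkpoint through the
# new Gemma2Loader and write a CTranslate2 model directory, quantized to int8.
converter = ctranslate2.converters.TransformersConverter("google/gemma-2-9b-it")
converter.convert("gemma-2-9b-it-ct2", quantization="int8")

The ct2-transformers-converter command line tool should accept the same model through its --model and --output_dir options.
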
python/ctranslate2/specs/transformer_spec.py

Lines changed: 22 additions & 0 deletions
@@ -101,6 +101,7 @@ def __init__(
         max_position_embeddings: int = 0,
         parallel_residual: bool = False,
         shared_layer_norm: bool = False,
+        pre_post_layer_norm: bool = False,
         multi_query_attention: bool = False,
         num_heads_kv: Optional[int] = None,
         head_dim: Optional[int] = None,
@@ -147,6 +148,7 @@ def __init__(
             by the GPT-J and GPT-NeoX models.
           shared_layer_norm: When using parallel residual, share the input and post
             attention layer norms.
+          pre_post_layer_norm: Add post layer norm for each pre norm layer
           multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
           num_heads_kv: Number of attention heads for the key and value.
           sliding_window: Max sequence length to retain in KV Cache.
@@ -216,6 +218,7 @@ def __init__(
                 max_position_embeddings=max_position_embeddings,
                 parallel_residual=parallel_residual,
                 shared_layer_norm=shared_layer_norm,
+                pre_post_layer_norm=pre_post_layer_norm,
                 num_heads_kv=num_heads_kv,
                 head_dim=head_dim,
                 sliding_window=sliding_window,
@@ -279,6 +282,7 @@ def __init__(
         max_position_embeddings=0,
         parallel_residual=False,
         shared_layer_norm=False,
+        pre_post_layer_norm=False,
         num_heads_kv=None,
         head_dim=None,
         sliding_window=None,
@@ -319,6 +323,21 @@ def __init__(
             delattr(self.self_attention, "layer_norm")
             delattr(self.ffn, "layer_norm")
 
+        if pre_post_layer_norm:
+            self.input_layer_norm = common_spec.LayerNormSpec(rms_norm=rms_norm)
+            self.post_attention_layer_norm = common_spec.LayerNormSpec(
+                rms_norm=rms_norm
+            )
+            self.pre_feedforward_layer_norm = common_spec.LayerNormSpec(
+                rms_norm=rms_norm
+            )
+            self.post_feedforward_layer_norm = common_spec.LayerNormSpec(
+                rms_norm=rms_norm
+            )
+
+            delattr(self.self_attention, "layer_norm")
+            delattr(self.ffn, "layer_norm")
+
 
 class FeedForwardSpec(model_spec.LayerSpec):
     def __init__(self, glu=False, rms_norm=False):
@@ -530,6 +549,7 @@ def from_config(
         max_position_embeddings: int = 0,
         parallel_residual: bool = False,
         shared_layer_norm: bool = False,
+        pre_post_layer_norm: bool = False,
         multi_query_attention: bool = False,
         num_heads_kv: Optional[int] = None,
         head_dim: Optional[int] = None,
@@ -570,6 +590,7 @@ def from_config(
             by the GPT-J and GPT-NeoX models.
           shared_layer_norm: When using parallel residual, share the input and post
             attention layer norms.
+          pre_post_layer_norm: add post layer norm for each pre norm layer
           multi_query_attention: Use multi-query attention (alias for num_heads_kv=1).
           num_heads_kv: Number of attention heads for the key and value.
           head_dim: Number of head
@@ -602,6 +623,7 @@ def from_config(
             max_position_embeddings=max_position_embeddings,
             parallel_residual=parallel_residual,
             shared_layer_norm=shared_layer_norm,
+            pre_post_layer_norm=pre_post_layer_norm,
             multi_query_attention=multi_query_attention,
             num_heads_kv=num_heads_kv,
             head_dim=head_dim,

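For reference, a minimal sketch (hypothetical usage, not from this commit) of what the new pre_post_layer_norm flag changes on each decoder layer spec: the four per-layer norms are attached, and the per-sublayer norms inside self_attention and ffn are removed.

from ctranslate2.specs import transformer_spec

# Build a tiny decoder spec with the new flag enabled; the sizes are arbitrary.
spec = transformer_spec.TransformerDecoderModelSpec.from_config(
    num_layers=2,
    num_heads=4,
    pre_norm=True,
    ffn_glu=True,
    rms_norm=True,
    pre_post_layer_norm=True,
)

layer = spec.decoder.layer[0]
print(hasattr(layer, "pre_feedforward_layer_norm"))   # expected: True
print(hasattr(layer, "post_feedforward_layer_norm"))  # expected: True
print(hasattr(layer.self_attention, "layer_norm"))    # expected: False (removed by the flag)
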
src/layers/attention.cc

Lines changed: 1 addition & 1 deletion
@@ -363,7 +363,7 @@ namespace ctranslate2 {
       if (queries_padder)
         queries_padder->add_padding(fused_proj);
 
-      const ops::Split split_op(2, {_d_model, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
+      const ops::Split split_op(2, {_num_heads * _d_head, _num_heads_kv * _d_head, _num_heads_kv * _d_head});
       split_op(fused_proj, queries_proj, keys_proj, values_proj);
 
       if (_merge_time_and_head_dims) {

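The one-line change above matters because the query slice of the fused QKV projection is num_heads * head_dim wide, which only equals _d_model when the head dimension is derived from the model dimension. Gemma 2 configures head_dim explicitly, so the old split boundaries were wrong. A small arithmetic sketch with illustrative values (assumed for illustration, not taken from the commit or any specific checkpoint):

# Illustrative shape arithmetic for a grouped-query attention layer.
d_model = 3584      # model (hidden) size
num_heads = 16      # query heads
num_heads_kv = 8    # key/value heads
head_dim = 256      # explicitly configured per-head size

q_width = num_heads * head_dim        # 4096: query slice of the fused projection
kv_width = num_heads_kv * head_dim    # 2048: each of the key and value slices

# The old split assumed the query slice was d_model (3584) wide; the fused
# projection actually packs q_width + 2 * kv_width columns.
print(q_width, kv_width, q_width + 2 * kv_width)  # 4096 2048 8192
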
src/layers/transformer.cc

Lines changed: 39 additions & 0 deletions
@@ -120,6 +120,10 @@ namespace ctranslate2 {
       , _input_layer_norm(build_optional_layer<LayerNorm>(model, scope + "/input_layer_norm"))
       , _post_attention_layer_norm(build_optional_layer<LayerNorm>(
             model, scope + "/post_attention_layer_norm"))
+      , _pre_feedforward_layer_norm(build_optional_layer<LayerNorm>(
+            model, scope + "/pre_feedforward_layer_norm"))
+      , _post_feedforward_layer_norm(build_optional_layer<LayerNorm>(
+            model, scope + "/post_feedforward_layer_norm"))
       , _encoder_attention(build_optional_layer<MultiHeadAttention>(model,
                                                                     scope + "/attention",
                                                                     num_heads,
@@ -149,6 +153,41 @@ namespace ctranslate2 {
       const DataType dtype = input.dtype();
       const Device device = input.device();
 
+      const bool pre_post_layer_norm = _post_feedforward_layer_norm && _pre_feedforward_layer_norm;
+      if (pre_post_layer_norm) {
+        StorageView hidden(dtype, device);
+        StorageView context(dtype, device);
+        (*_input_layer_norm)(input, hidden);
+
+        if (_self_attention)
+          (*_self_attention)(hidden,
+                             hidden,
+                             input_length,
+                             context,
+                             cached_self_attn_keys,
+                             cached_self_attn_values,
+                             nullptr,
+                             input_padder,
+                             input_padder,
+                             true,
+                             position_bias,
+                             offset);
+
+        (*_post_attention_layer_norm)(context, output);
+        ops::Add()(output, input, output);
+
+        context = std::move(output);
+        (*_pre_feedforward_layer_norm)(context, output);
+        hidden = std::move(output);
+
+        _ff(hidden, output);
+
+        hidden = std::move(output);
+        (*_post_feedforward_layer_norm)(hidden, output);
+        ops::Add()(output, context, output);
+        return;
+      }
+
       const bool use_parallel_residual = _shared_layer_norm || _input_layer_norm;
 
       if (use_parallel_residual) {

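The new branch implements the pre/post ("sandwich") normalization used by Gemma 2: each sub-layer is wrapped in a norm before and after it, and only then added to the residual stream. A minimal Python sketch of the data flow the C++ code above follows (hypothetical helper names, for illustration only):

def gemma2_decoder_layer(x, attention, ffn, norms):
    # norms is assumed to map the four layer norm names added in this commit
    # to callables; attention and ffn stand in for the real sub-layers.

    # Attention block: pre-norm, attention, post-norm, then residual add.
    hidden = norms["input_layer_norm"](x)
    context = attention(hidden)
    x = x + norms["post_attention_layer_norm"](context)

    # Feed-forward block: pre-norm, FFN, post-norm, then residual add.
    hidden = norms["pre_feedforward_layer_norm"](x)
    hidden = ffn(hidden)
    x = x + norms["post_feedforward_layer_norm"](hidden)
    return x
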
third_party/googletest

Submodule googletest updated 245 files
