diff --git a/docs/guides/transformers.md b/docs/guides/transformers.md
index 39f3172a0..ffeed2fbc 100644
--- a/docs/guides/transformers.md
+++ b/docs/guides/transformers.md
@@ -8,6 +8,8 @@ CTranslate2 supports selected models from Hugging Face's [Transformers](https://
 * CodeGen
 * DistilBERT
 * Falcon
+* Gemma 2
+* Gemma 3 (text only)
 * Llama
 * M2M100
 * MarianMT
@@ -80,7 +82,7 @@ print(tokenizer.decode(tokenizer.convert_tokens_to_ids(target), skip_special_tok
 
 ## BERT
 
-[BERT](https://huggingface.co/docs/transformers/model_doc/bert) is pretrained model on English language using a masked language modeling objective.
+[BERT](https://huggingface.co/docs/transformers/model_doc/bert) is a pretrained model on the English language using a masked language modeling objective.
 
 CTranslate2 only implements the `BertModel` class from Transformers which includes the Transformer encoder and the pooling layer. Task-specific layers should be run with PyTorch as shown in the example below.
@@ -183,6 +185,43 @@ output = tokenizer.decode(results[0].sequences_ids[0])
 print(output)
 ```
 
+## Gemma 3 (text only)
+
+[Gemma 3](https://ai.google.dev/gemma/docs/core) is Google's latest family of lightweight, open-weight AI models, built on the same technology as Gemini.
+
+Gemma models come in two flavors: instruction-tuned (`it`) models and base models.
+
+Instruction-tuned models expect a specific [prompt template format](https://ai.google.dev/gemma/docs/core/prompt-structure), which you should follow.
+
+When converting an instruction-tuned model, CTranslate2 sets `<end_of_turn>` as the default end-of-sequence token.
+
+To convert a model:
+
+```bash
+ct2-transformers-converter --model google/gemma-3-1b-it --output_dir gemma-3-1b-it
+```
+
+Example usage:
+
+```python
+from transformers import AutoTokenizer
+import ctranslate2
+
+tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")
+gen = ctranslate2.Generator("gemma-3-1b-it")
+
+prompt = "<start_of_turn>user\nGenerate a 200 word text talking about George Orwell.<end_of_turn>\n<start_of_turn>model\n"
+tokens = tok.convert_ids_to_tokens(tok.encode(prompt))
+
+res = gen.generate_batch([tokens], max_length=2048, sampling_temperature=0.1, include_prompt_in_result=False)
+print(tok.convert_tokens_to_string(res[0].sequences[0]))
+```
+
 ## Llama 2
 
 [Llama 2](https://ai.meta.com/llama/) is a collection of pretrained and fine-tuned generative text models ranging in scale from 7 billion to 70 billion parameters.
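A note on the usage sample above: instead of hard-coding the `<start_of_turn>`/`<end_of_turn>` control tokens, the prompt can also be built from the chat template bundled with the tokenizer. A minimal sketch, not part of this patch, assuming the checkpoint ships the documented template:

```python
# Build the Gemma 3 prompt from the tokenizer's chat template instead of
# writing the control tokens by hand (sketch; assumes the bundled template
# matches the documented prompt format).
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/gemma-3-1b-it")

messages = [
    {"role": "user", "content": "Generate a 200 word text talking about George Orwell."}
]

# add_generation_prompt=True appends the trailing "<start_of_turn>model\n".
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# The template already inserts <bos>, so skip the tokenizer's special tokens here.
tokens = tok.convert_ids_to_tokens(tok.encode(prompt, add_special_tokens=False))
```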
diff --git a/include/ctranslate2/layers/attention.h b/include/ctranslate2/layers/attention.h
index 5778a028c..570de73f1 100644
--- a/include/ctranslate2/layers/attention.h
+++ b/include/ctranslate2/layers/attention.h
@@ -2,6 +2,7 @@
 
 #include "ctranslate2/layers/attention_layer.h"
 #include "ctranslate2/padder.h"
+#include "ctranslate2/layers/transformer.h"
 
 namespace ctranslate2 {
   namespace layers {
@@ -65,6 +66,8 @@ namespace ctranslate2 {
       dim_t _relative_right_max_position;
       const bool _merge_time_and_head_dims;
       const dim_t _cache_time_dim;
+      std::unique_ptr<LayerNorm> _q_norm;  // Query normalization
+      std::unique_ptr<LayerNorm> _k_norm;  // Key normalization
     };
   }
 }
diff --git a/python/ctranslate2/converters/transformers.py b/python/ctranslate2/converters/transformers.py
index 2684dd2c7..b1cc54ba4 100644
--- a/python/ctranslate2/converters/transformers.py
+++ b/python/ctranslate2/converters/transformers.py
@@ -1819,6 +1819,192 @@ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
         gc.collect()
 
 
+@register_loader("Gemma3TextConfig")
+@register_loader("Gemma3Config")
+class Gemma3Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "Gemma3ForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        head_dim = model.config.head_dim
+
+        activation_config = getattr(
+            model.config, "hidden_activation", "gelu_pytorch_tanh"
+        )
+
+        # Get RoPE parameters
+        rope_theta = getattr(model.config, "rope_theta", 1_000_000)  # Global: 1M
+        rope_local_base_freq = getattr(
+            model.config, "rope_local_base_freq", 10_000
+        )  # Local: 10k
+
+        # Get sliding window configuration
+        sliding_window = getattr(model.config, "sliding_window", 1024)
+        layer_types = getattr(model.config, "layer_types", None)
+
+        quantization_config = getattr(model.config, "quantization_config", None)
+        if quantization_config:
+            if quantization_config.quant_method == "awq":
+                quant_type = _SUPPORTED_QUANTIZATION.get(quantization_config.version)
+                if quant_type is None:
+                    raise NotImplementedError(
+                        "Quantization type '%s' is not yet implemented."
+                        % quantization_config.quant_method
+                    )
+        else:
+            quant_type = common_spec.Quantization.CT2
+
+        # Create base spec using from_config
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=(
+                common_spec.Activation.GELU
+                if activation_config == "gelu"
+                else common_spec.Activation.GELUTanh
+            ),
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=head_dim,
+            rotary_interleave=False,
+            rotary_base=rope_local_base_freq,  # Default to local base freq
+            num_heads_kv=num_heads_kv,
+            head_dim=head_dim,
+            sliding_window=sliding_window,  # Default to local sliding window
+            pre_post_layer_norm=True,
+            qk_norm=True,
+        )
+
+        # Keep layer_types around for reference
+        self._layer_types = layer_types
+
+        # Override per-layer settings for global vs local attention
+        for i, layer_type in enumerate(layer_types or []):
+            layer = spec.decoder.layer[i]
+            if layer_type == "full_attention":
+                layer.self_attention.rotary_base = np.dtype("float32").type(rope_theta)
+                layer.self_attention.sliding_window = np.dtype("int32").type(0)
+            elif layer_type == "sliding_attention":
+                layer.self_attention.rotary_base = np.dtype("float32").type(
+                    rope_local_base_freq
+                )
+                layer.self_attention.sliding_window = np.dtype("int32").type(
+                    sliding_window
+                )
+
+        self.set_decoder(spec.decoder, model.model, quant_type)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        return spec
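For reference, `get_model_spec` above only touches a handful of fields of the Hugging Face config; the `getattr` fallbacks cover `transformers` releases where some of them are missing. A quick way to inspect them (illustrative sketch, not part of the patch; the values quoted in comments are typical, not guaranteed):

```python
# Print the Gemma3TextConfig fields the loader reads; the actual values come
# from the checkpoint's config.json.
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/gemma-3-1b-it")
for name in (
    "num_hidden_layers",
    "num_attention_heads",
    "num_key_value_heads",
    "head_dim",
    "hidden_activation",
    "rope_theta",            # full-attention layers, typically 1,000,000
    "rope_local_base_freq",  # sliding-window layers, typically 10,000
    "sliding_window",
    "layer_types",           # e.g. mostly "sliding_attention" with periodic "full_attention"
):
    print(name, "=", getattr(cfg, name, None))
```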
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        if model.config.vocab_size < len(tokens):
+            tokens = tokens[: model.config.vocab_size]
+
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = tokenizer.bos_token
+        config.unk_token = tokenizer.unk_token
+
+        if (
+            hasattr(tokenizer, "chat_template")
+            and isinstance(tokenizer.chat_template, str)
+            and tokenizer.chat_template.strip()
+        ):
+            config.eos_token = "<end_of_turn>"
+        else:
+            config.eos_token = tokenizer.eos_token
+
+    def set_layer_norm(self, spec, layer_norm):
+        # Gemma checkpoints store RMSNorm weights as (gamma - 1), so add 1 back.
+        spec.gamma = layer_norm.weight + 1.0
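The `+ 1.0` in `set_layer_norm` is the Gemma zero-centered RMSNorm convention: the checkpoint stores `weight` such that the layer scales by `(1 + weight)`, while CTranslate2's RMSNorm multiplies by `gamma` directly. A minimal numpy sketch of the equivalence (illustrative, not part of the patch):

```python
import numpy as np

def gemma_rms_norm(x, weight, eps=1e-6):
    # Gemma-style RMSNorm: scale by (1 + weight).
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * (1.0 + weight)

def plain_rms_norm(x, gamma, eps=1e-6):
    # Conventional RMSNorm: scale by gamma directly.
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gamma

x = np.random.randn(4, 8).astype(np.float32)
w = np.random.randn(8).astype(np.float32)

# Exporting gamma = weight + 1 makes the two formulations identical.
assert np.allclose(gemma_rms_norm(x, w), plain_rms_norm(x, w + 1.0), atol=1e-5)
```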
+    def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
+        spec.scale_embeddings = True
+        spec.start_from_zero_embedding = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)  # Input
+        self.set_layer_norm(spec.layer_norm, module.norm)  # Output
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(layer_spec.input_layer_norm, layer.input_layernorm)
+
+            self.set_layer_norm(
+                layer_spec.post_attention_layer_norm, layer.post_attention_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.pre_feedforward_layer_norm, layer.pre_feedforward_layernorm
+            )
+
+            self.set_layer_norm(
+                layer_spec.post_feedforward_layer_norm, layer.post_feedforward_layernorm
+            )
+
+            # Set QK-norm weights (Gemma 3 uses this instead of soft-capping)
+            self.set_layer_norm(
+                layer_spec.self_attention.q_norm, layer.self_attn.q_norm
+            )
+            self.set_layer_norm(
+                layer_spec.self_attention.k_norm, layer.self_attn.k_norm
+            )
+
+            # Set attention projections
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(
+                split_layers[0], layer.self_attn.q_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[1], layer.self_attn.k_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                split_layers[2], layer.self_attn.v_proj, quant_type=quant_type
+            )
+
+            if quant_type == common_spec.Quantization.CT2:
+                utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            else:
+                cc_dim = 1 if quant_type == common_spec.Quantization.AWQ_GEMM else 0
+                utils.fuse_linear_prequant(
+                    layer_spec.self_attention.linear[0], split_layers, cc_dim
+                )
+
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+                quant_type=quant_type,
+            )
+
+            # Set FFN weights
+            self.set_linear(
+                layer_spec.ffn.linear_0, layer.mlp.gate_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_0_noact, layer.mlp.up_proj, quant_type=quant_type
+            )
+            self.set_linear(
+                layer_spec.ffn.linear_1, layer.mlp.down_proj, quant_type=quant_type
+            )
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("MistralConfig")
 class MistralLoader(ModelLoader):
     @property
diff --git a/python/ctranslate2/specs/attention_spec.py b/python/ctranslate2/specs/attention_spec.py
index f49d41121..97a33b1c2 100644
--- a/python/ctranslate2/specs/attention_spec.py
+++ b/python/ctranslate2/specs/attention_spec.py
@@ -32,6 +32,8 @@ def __init__(
         num_heads_kv=None,
         head_dim=None,
         sliding_window=None,
+        qk_norm=False,
+        qk_norm_rms=True,
     ):
         self.queries_scale = model_spec.OPTIONAL
 
@@ -40,6 +42,10 @@ def __init__(
             common_spec.LinearSpec() for _ in range(2 if self_attention else 3)
         ]
 
+        if qk_norm:
+            self.q_norm = common_spec.LayerNormSpec(rms_norm=qk_norm_rms)
+            self.k_norm = common_spec.LayerNormSpec(rms_norm=qk_norm_rms)
+
         if relative_position:
             self.relative_position_keys = None
             self.relative_position_values = None
diff --git a/python/ctranslate2/specs/transformer_spec.py b/python/ctranslate2/specs/transformer_spec.py
index 230e62cfd..334ecc041 100644
--- a/python/ctranslate2/specs/transformer_spec.py
+++ b/python/ctranslate2/specs/transformer_spec.py
@@ -109,6 +109,7 @@ def __init__(
         quant_type: Optional[common_spec.Quantization] = None,
         quant_group_size: Optional[int] = None,
         quant_bits: Optional[int] = None,
+        qk_norm: bool = False,
     ):
         """Initializes a Transformer decoder specification.
 
@@ -222,6 +223,7 @@ def __init__(
                 num_heads_kv=num_heads_kv,
                 head_dim=head_dim,
                 sliding_window=sliding_window,
+                qk_norm=qk_norm,
             )
             for _ in range(num_layers)
         ]
@@ -286,6 +288,7 @@ def __init__(
         num_heads_kv=None,
         head_dim=None,
         sliding_window=None,
+        qk_norm=False,
     ):
         self.self_attention = attention_spec.MultiHeadAttentionSpec(
             self_attention=True,
@@ -302,6 +305,7 @@ def __init__(
             num_heads_kv=num_heads_kv,
             head_dim=head_dim,
             sliding_window=sliding_window,
+            qk_norm=qk_norm,
         )
 
         if with_encoder_attention:
@@ -309,6 +313,7 @@ def __init__(
                 rms_norm=rms_norm,
                 num_heads_kv=num_heads_kv,
                 sliding_window=sliding_window,
+                qk_norm=qk_norm,
             )
 
         self.ffn = FeedForwardSpec(glu=ffn_glu, rms_norm=rms_norm)
@@ -557,6 +562,7 @@ def from_config(
         quant_type: Optional[common_spec.Quantization] = None,
         quant_group_size: Optional[int] = None,
         quant_bits: Optional[int] = None,
+        qk_norm: bool = False,
     ):
         """Creates a Transformer decoder model specification.
 
@@ -631,6 +637,7 @@ def from_config(
             quant_type=quant_type,
             quant_group_size=quant_group_size,
             quant_bits=quant_bits,
+            qk_norm=qk_norm,
         )
 
         return cls(decoder)
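Conceptually, the `utils.fuse_linear` call in `set_decoder` above concatenates the separate q/k/v projection weights along the output dimension so that a single GEMM produces all three projections, which a later split undoes. A numpy sketch of that equivalence (illustrative sizes, not part of the patch):

```python
import numpy as np

d_model, q_out, kv_out = 16, 16, 8  # GQA: fewer key/value rows than query rows
w_q = np.random.randn(q_out, d_model).astype(np.float32)
w_k = np.random.randn(kv_out, d_model).astype(np.float32)
w_v = np.random.randn(kv_out, d_model).astype(np.float32)

# Fuse: stack the weight matrices along the output dimension.
w_fused = np.concatenate([w_q, w_k, w_v], axis=0)

x = np.random.randn(3, d_model).astype(np.float32)
fused = x @ w_fused.T  # one matmul instead of three

# Splitting the fused output recovers the individual projections.
q, k, v = np.split(fused, [q_out, q_out + kv_out], axis=-1)
assert np.allclose(q, x @ w_q.T)
assert np.allclose(k, x @ w_k.T)
assert np.allclose(v, x @ w_v.T)
```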
diff --git a/src/layers/attention.cc b/src/layers/attention.cc
index 6ad344410..9afd773be 100644
--- a/src/layers/attention.cc
+++ b/src/layers/attention.cc
@@ -310,6 +310,8 @@ namespace ctranslate2 {
                                      && !_relative_position_keys
                                      && !_relative_position_values)
       ,_cache_time_dim(_merge_time_and_head_dims ? 1 : 2)
+      , _q_norm(build_optional_layer<LayerNorm>(model, scope + "/q_norm"))
+      , _k_norm(build_optional_layer<LayerNorm>(model, scope + "/k_norm"))
     {
       if (_relative_position_keys)
         _maximum_relative_position = (_relative_position_keys->dim(0) - 1) / 2;
@@ -379,6 +381,12 @@ namespace ctranslate2 {
         } else {
           split_heads(fused_proj, 2 * _num_heads, values_padder);
           ops::Split(1)(fused_proj, keys_proj, values_proj);
+
+          if (_k_norm) {
+            StorageView keys_normed(keys_proj.dtype(), keys_proj.device());
+            (*_k_norm)(keys_proj, keys_normed);
+            keys_proj = std::move(keys_normed);
+          }
         }
 
         if (cached_keys != nullptr) {
@@ -387,6 +395,13 @@ namespace ctranslate2 {
         }
       }
 
+      if (_q_norm) {
+        StorageView queries_normed(queries_proj.dtype(), queries_proj.device());
+        (*_q_norm)(queries_proj, queries_normed);
+        queries_proj = std::move(queries_normed);
+      }
+
       if (queries_proj.dim(1) == 1 && cached_keys)
         beam_size = queries_proj.dim(0) / cached_keys->dim(0);
@@ -409,11 +424,36 @@ namespace ctranslate2 {
         if (_merge_time_and_head_dims) {
           queries_proj.reshape({queries_proj.dim(0), -1, _d_head});
+
+          if (_q_norm) {
+            StorageView queries_normed(queries_proj.dtype(), queries_proj.device());
+            (*_q_norm)(queries_proj, queries_normed);
+            queries_proj = std::move(queries_normed);
+          }
+
+          if (_k_norm) {
+            StorageView keys_normed(keys_proj.dtype(), keys_proj.device());
+            (*_k_norm)(keys_proj, keys_normed);
+            keys_proj = std::move(keys_normed);
+          }
         } else {
           split_heads(queries_proj, _num_heads);
           split_heads(keys_proj, _num_heads_kv);
           split_heads(values_proj, _num_heads_kv);
 
+          if (_q_norm) {
+            StorageView queries_normed(queries_proj.dtype(), queries_proj.device());
+            (*_q_norm)(queries_proj, queries_normed);
+            queries_proj = std::move(queries_normed);
+          }
+
+          if (_k_norm) {
+            StorageView keys_normed(keys_proj.dtype(), keys_proj.device());
+            (*_k_norm)(keys_proj, keys_normed);
+            keys_proj = std::move(keys_normed);
+          }
+
           replicate_heads(keys_proj, _num_heads / _num_heads_kv);
           replicate_heads(values_proj, _num_heads / _num_heads_kv);
         }
@@ -421,6 +461,18 @@ namespace ctranslate2 {
       } else {
         split_heads(fused_proj, 3 * _num_heads, queries_padder);
         ops::Split(1)(fused_proj, queries_proj, keys_proj, values_proj);
+
+        if (_q_norm) {
+          StorageView queries_normed(queries_proj.dtype(), queries_proj.device());
+          (*_q_norm)(queries_proj, queries_normed);
+          queries_proj = std::move(queries_normed);
+        }
+
+        if (_k_norm) {
+          StorageView keys_normed(keys_proj.dtype(), keys_proj.device());
+          (*_k_norm)(keys_proj, keys_normed);
+          keys_proj = std::move(keys_normed);
+        }
       }
 
       if (_rotary_embeddings) {
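At runtime, each `_q_norm`/`_k_norm` application above is an RMSNorm over the per-head feature dimension, applied after the heads are split and before the rotary embeddings and the QK^T product. A numpy sketch of the computation (illustrative shapes and unit gammas, not part of the patch):

```python
import numpy as np

def rms_norm(x, gamma, eps=1e-6):
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return x / rms * gamma

batch, heads, time, d_head = 2, 4, 5, 8
q = np.random.randn(batch, heads, time, d_head).astype(np.float32)
k = np.random.randn(batch, heads, time, d_head).astype(np.float32)

q_gamma = np.ones(d_head, dtype=np.float32)  # learned per-dimension weights in the real model
k_gamma = np.ones(d_head, dtype=np.float32)

q = rms_norm(q, q_gamma)  # what (*_q_norm)(queries_proj, ...) computes
k = rms_norm(k, k_gamma)  # what (*_k_norm)(keys_proj, ...) computes

scores = q @ k.transpose(0, 1, 3, 2) / np.sqrt(d_head)  # QK^T follows the norms
```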