
Commit 2870fe3

Support qwen2 (#1820)

* support qwen2
* fix flake
1 parent 8460993 commit 2870fe3

2 files changed: +109, -1 lines

README.md

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ The project implements a custom runtime that applies many performance optimizati
 The following model types are currently supported:
 
 * Encoder-decoder models: Transformer base/big, M2M-100, NLLB, BART, mBART, Pegasus, T5, Whisper
-* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon
+* Decoder-only models: GPT-2, GPT-J, GPT-NeoX, OPT, BLOOM, MPT, Llama, Mistral, Gemma, CodeGen, GPTBigCode, Falcon, Qwen2
 * Encoder-only models: BERT, DistilBERT, XLM-RoBERTa
 
 Compatible models should be first converted into an optimized model format. The library includes converters for multiple frameworks:
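
For reference, the conversion mentioned in the README is done with the Transformers converter that ships with the library. A minimal sketch of converting a Qwen2-family checkpoint from the command line (the model name, output directory, and quantization flag are illustrative, not part of this commit):

ct2-transformers-converter --model Qwen/Qwen1.5-0.5B-Chat --output_dir qwen2_ct2 --quantization int8

The resulting directory can then be loaded by the CTranslate2 runtime, as sketched after the converter diff below.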

python/ctranslate2/converters/transformers.py

Lines changed: 108 additions & 0 deletions
@@ -1956,6 +1956,114 @@ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
             gc.collect()
 
 
+@register_loader("Qwen2Config")
+class Qwen2Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "Qwen2ForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        rope_scaling = getattr(model.config, "rope_scaling", None)
+        if rope_scaling:
+            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
+            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
+            rotary_scaling_factor = rope_scaling["factor"]
+
+            if rotary_scaling_type is None:
+                raise NotImplementedError(
+                    "RoPE scaling type '%s' is not yet implemented. "
+                    "The following RoPE scaling types are currently supported: %s"
+                    % (rope_scaling["type"], ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
+                )
+        else:
+            rotary_scaling_type = None
+            rotary_scaling_factor = 1
+
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=common_spec.Activation.SWISH,
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=0,
+            rotary_interleave=False,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=getattr(model.config, "rope_theta", 10000),
+            num_heads_kv=num_heads_kv,
+        )
+
+        self.set_decoder(spec.decoder, model.model)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = (
+            tokenizer.bos_token
+            if tokenizer.bos_token is not None
+            else tokenizer.pad_token
+        )
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = (
+            tokenizer.unk_token if tokenizer.unk_token is not None else ""
+        )
+        config.layer_norm_epsilon = model.config.rms_norm_eps
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.input_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
+            )
+
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(split_layers[0], layer.self_attn.q_proj)
+            self.set_linear(split_layers[1], layer.self_attn.k_proj)
+            self.set_linear(split_layers[2], layer.self_attn.v_proj)
+
+            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+            )
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
+            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
+            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("MixFormerSequentialConfig")
 class MixFormerSequentialLoader(ModelLoader):
     @property
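
For context, a minimal sketch of how a model handled by the new Qwen2Loader might be used once converted. The checkpoint name, output directory, prompt, and sampling settings below are illustrative assumptions, not part of this commit:

import ctranslate2
import transformers

# Directory produced by the converter (see the conversion sketch after the README diff above).
generator = ctranslate2.Generator("qwen2_ct2", device="cpu")
tokenizer = transformers.AutoTokenizer.from_pretrained("Qwen/Qwen1.5-0.5B-Chat")

# generate_batch expects string tokens, so encode the prompt and map the ids back to tokens.
prompt = "Write a haiku about the sea."
start_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode(prompt))

# Sample a continuation; the parameters here are illustrative.
results = generator.generate_batch([start_tokens], max_length=64, sampling_topk=10)
print(tokenizer.decode(results[0].sequences_ids[0]))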

0 commit comments