@@ -1956,6 +1956,114 @@ def set_decoder(self, spec, module, quant_type=common_spec.Quantization.CT2):
             gc.collect()


+@register_loader("Qwen2Config")
+class Qwen2Loader(ModelLoader):
+    @property
+    def architecture_name(self):
+        return "Qwen2ForCausalLM"
+
+    def get_model_spec(self, model):
+        num_layers = model.config.num_hidden_layers
+
+        num_heads = model.config.num_attention_heads
+        num_heads_kv = getattr(model.config, "num_key_value_heads", num_heads)
+        if num_heads_kv == num_heads:
+            num_heads_kv = None
+
+        rope_scaling = getattr(model.config, "rope_scaling", None)
+        if rope_scaling:
+            rope_type = rope_scaling.get("type") or rope_scaling["rope_type"]
+            rotary_scaling_type = _SUPPORTED_ROPE_SCALING.get(rope_type)
+            rotary_scaling_factor = rope_scaling["factor"]
+
+            if rotary_scaling_type is None:
+                raise NotImplementedError(
+                    "RoPE scaling type '%s' is not yet implemented. "
+                    "The following RoPE scaling types are currently supported: %s"
+                    % (rope_type, ", ".join(_SUPPORTED_ROPE_SCALING.keys()))
+                )
+        else:
+            rotary_scaling_type = None
+            rotary_scaling_factor = 1
+
+        spec = transformer_spec.TransformerDecoderModelSpec.from_config(
+            num_layers,
+            num_heads,
+            activation=common_spec.Activation.SWISH,
+            pre_norm=True,
+            ffn_glu=True,
+            rms_norm=True,
+            rotary_dim=0,
+            rotary_interleave=False,
+            rotary_scaling_type=rotary_scaling_type,
+            rotary_scaling_factor=rotary_scaling_factor,
+            rotary_base=getattr(model.config, "rope_theta", 10000),
+            num_heads_kv=num_heads_kv,
+        )
+
+        self.set_decoder(spec.decoder, model.model)
+        self.set_linear(spec.decoder.projection, model.lm_head)
+        return spec
+
+    def get_vocabulary(self, model, tokenizer):
+        tokens = super().get_vocabulary(model, tokenizer)
+
+        extra_ids = model.config.vocab_size - len(tokens)
+        for i in range(extra_ids):
+            tokens.append("<extra_id_%d>" % i)
+        return tokens
+
+    def set_vocabulary(self, spec, tokens):
+        spec.register_vocabulary(tokens)
+
+    def set_config(self, config, model, tokenizer):
+        config.bos_token = (
+            tokenizer.bos_token
+            if tokenizer.bos_token is not None
+            else tokenizer.pad_token
+        )
+        config.eos_token = tokenizer.eos_token
+        config.unk_token = (
+            tokenizer.unk_token if tokenizer.unk_token is not None else ""
+        )
+        config.layer_norm_epsilon = model.config.rms_norm_eps
+
+    def set_layer_norm(self, spec, layer_norm):
+        spec.gamma = layer_norm.weight
+
+    def set_decoder(self, spec, module):
+        spec.scale_embeddings = False
+        self.set_embeddings(spec.embeddings, module.embed_tokens)
+        self.set_layer_norm(spec.layer_norm, module.norm)
+
+        for layer_spec, layer in zip(spec.layer, module.layers):
+            self.set_layer_norm(
+                layer_spec.self_attention.layer_norm, layer.input_layernorm
+            )
+            self.set_layer_norm(
+                layer_spec.ffn.layer_norm, layer.post_attention_layernorm
+            )
+
+            split_layers = [common_spec.LinearSpec() for _ in range(3)]
+            self.set_linear(split_layers[0], layer.self_attn.q_proj)
+            self.set_linear(split_layers[1], layer.self_attn.k_proj)
+            self.set_linear(split_layers[2], layer.self_attn.v_proj)
+
+            utils.fuse_linear(layer_spec.self_attention.linear[0], split_layers)
+            self.set_linear(
+                layer_spec.self_attention.linear[1],
+                layer.self_attn.o_proj,
+            )
+
+            self.set_linear(layer_spec.ffn.linear_0, layer.mlp.gate_proj)
+            self.set_linear(layer_spec.ffn.linear_0_noact, layer.mlp.up_proj)
+            self.set_linear(layer_spec.ffn.linear_1, layer.mlp.down_proj)
+
+            delattr(layer, "self_attn")
+            delattr(layer, "mlp")
+            gc.collect()
+
+
 @register_loader("MixFormerSequentialConfig")
 class MixFormerSequentialLoader(ModelLoader):
     @property
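Not part of the diff above, a minimal sketch of how the new loader would be exercised end to end through the public CTranslate2 Python API: the converter dispatches on the checkpoint's Qwen2Config to the Qwen2Loader registered here. The model id, output directory, and prompt are illustrative placeholders, and int8 quantization is optional.

import ctranslate2
import transformers

model_id = "Qwen/Qwen2-7B-Instruct"  # placeholder checkpoint id
output_dir = "qwen2_ct2"  # placeholder output directory

# Convert the Hugging Face checkpoint; this path goes through Qwen2Loader.
converter = ctranslate2.converters.TransformersConverter(model_id)
converter.convert(output_dir, quantization="int8", force=True)

# Load the converted model and generate a few tokens as a smoke test.
generator = ctranslate2.Generator(output_dir)
tokenizer = transformers.AutoTokenizer.from_pretrained(model_id)
start_tokens = tokenizer.convert_ids_to_tokens(tokenizer.encode("Hello, world"))
results = generator.generate_batch([start_tokens], max_length=32)
print(tokenizer.decode(results[0].sequences_ids[0]))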