diff --git a/src/transformers/integrations/ggml.py b/src/transformers/integrations/ggml.py
index 57f0af5667e6..1aa11cf43ca5 100644
--- a/src/transformers/integrations/ggml.py
+++ b/src/transformers/integrations/ggml.py
@@ -330,7 +330,7 @@
         "rope.dimension_count": None,
         "rope.freq_base": "rope_theta",
         "attention.head_count": "num_attention_heads",
-        "attention.head_count_kv": "num_key_value_heads",
+        "attention.head_count_kv": "num_kv_heads",
         "attention.layer_norm_rms_epsilon": "rms_norm_eps",
         "vocab_size": "vocab_size",
     },
diff --git a/src/transformers/modeling_gguf_pytorch_utils.py b/src/transformers/modeling_gguf_pytorch_utils.py
index 00c080fbea81..4cdeec8268f4 100644
--- a/src/transformers/modeling_gguf_pytorch_utils.py
+++ b/src/transformers/modeling_gguf_pytorch_utils.py
@@ -69,6 +69,42 @@ def process(self, weights, name, **kwargs):
         return GGUFTensor(weights, name, {})
 
 
+class FalconTensorProcessor(TensorProcessor):
+    def __init__(self, config=None):
+        super().__init__(config=config)
+
+    def process(self, weights, name, **kwargs):
+        if "qkv" in name:
+            shape = weights.shape
+            weights_copy = weights.copy()
+            parsed_parameters = kwargs.get("parsed_parameters")
+            num_attention_heads = parsed_parameters["config"]["num_attention_heads"]
+            num_key_value_heads = parsed_parameters["config"]["num_kv_heads"]
+            hidden_size = parsed_parameters["config"]["hidden_size"]
+            head_dim = hidden_size // num_attention_heads
+
+            # Split the weights array into q, k, v
+            split_indices = [
+                num_attention_heads * head_dim,
+                num_attention_heads * head_dim + num_key_value_heads * head_dim,
+            ]
+            q, k, v = np.split(weights_copy, split_indices)
+
+            # Reshape q, k, and v as needed
+            q = q.reshape(num_key_value_heads, num_attention_heads // num_key_value_heads, head_dim, hidden_size)
+            k = k.reshape(num_key_value_heads, 1, head_dim, hidden_size)
+            v = v.reshape(num_key_value_heads, 1, head_dim, hidden_size)
+
+            # Concatenate q, k, and v along the second dimension
+            qkv = np.concatenate((q, k, v), axis=1)
+
+            # Reshape qkv back to the original shape
+            weights = qkv.reshape(shape)
+
+        return GGUFTensor(weights, name, {})
+
+
 class LlamaTensorProcessor(TensorProcessor):
     def __init__(self, config=None):
         super().__init__(config=config)
@@ -246,6 +282,7 @@ def process(self, weights, name, **kwargs):
     "t5encoder": T5TensorProcessor,
     "gpt2": GPT2TensorProcessor,
     "mamba": MambaTensorProcessor,
+    "falcon": FalconTensorProcessor,
 }
 
 
@@ -321,6 +358,8 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
                 f"From file name, cannot determine the number of parameters for {architecture} architecture"
             )
         model_size = m.group().strip("-")  # only keeps `7b`
+        if model_size == "40b":
+            parsed_parameters["config"]["new_decoder_architecture"] = True
 
     if architecture + model_size not in GGUF_SUPPORTED_ARCHITECTURES:
         raise ValueError(f"Architecture {architecture + model_size} not supported")
diff --git a/tests/quantization/ggml/test_ggml.py b/tests/quantization/ggml/test_ggml.py
index 508975865c27..8851747322df 100644
--- a/tests/quantization/ggml/test_ggml.py
+++ b/tests/quantization/ggml/test_ggml.py
@@ -64,6 +64,7 @@ class GgufIntegrationTests(unittest.TestCase):
     mamba_model_id = "jpodivin/mamba-2.8b-hf-GGUF"
     nemotron_original_model_id = "nvidia/Nemotron-Mini-4B-Instruct"
     nemotron_model_id = "bartowski/Nemotron-Mini-4B-Instruct-GGUF"
+    falcon40b_model_id = "tensorblock/falcon-40b-GGUF"
 
     # standard quants
     q4_0_gguf_model_id = "tinyllama-1.1b-chat-v1.0.Q4_0.gguf"
@@ -111,6 +112,7 @@ class GgufIntegrationTests(unittest.TestCase):
     fp16_mamba_model_id = "ggml-model-f16.gguf"
     q6_k_nemotron_model_id = "Nemotron-Mini-4B-Instruct-Q6_K.gguf"
     fp16_nemotron_model_id = "Nemotron-Mini-4B-Instruct-f16.gguf"
+    q2_falcon40b_id = "falcon-40b-Q2_K.gguf"
 
     example_text = "Hello"
 
@@ -612,7 +614,7 @@ def test_falcon40b_q2_k(self):
         text = tokenizer(self.example_text, return_tensors="pt").to(torch_device)
 
         out = model.generate(**text, max_new_tokens=10)
-        EXPECTED_TEXT = "Hello All,\nI am new to this forum."
+        EXPECTED_TEXT = "Hello All,\nOn ne sait plus quoi manger,"
         self.assertEqual(tokenizer.decode(out[0], skip_special_tokens=True), EXPECTED_TEXT)
 
     def test_falcon7b_q2_k(self):
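
For reviewers, here is a self-contained sketch of the re-layout that `FalconTensorProcessor.process` applies to the fused QKV tensor. It reproduces the same split/reshape/concatenate steps on dummy data; the head counts and sizes below are illustrative (not read from a real GGUF file), and the comments reflect my reading of the layout in the diff rather than anything guaranteed by the GGUF spec.

```python
import numpy as np

# Illustrative dimensions only; a real Falcon-40B checkpoint is much larger.
num_attention_heads = 8
num_key_value_heads = 2
head_dim = 4
hidden_size = num_attention_heads * head_dim  # 32

# Assumption: the GGUF fused weight stacks all Q rows, then K rows, then V rows on dim 0.
rows = (num_attention_heads + 2 * num_key_value_heads) * head_dim
weights = np.arange(rows * hidden_size, dtype=np.float32).reshape(rows, hidden_size)

# Split into the q, k, v blocks, exactly as in FalconTensorProcessor.
split_indices = [
    num_attention_heads * head_dim,
    num_attention_heads * head_dim + num_key_value_heads * head_dim,
]
q, k, v = np.split(weights, split_indices)

# Group query heads per KV head, then interleave each group's K and V rows.
q = q.reshape(num_key_value_heads, num_attention_heads // num_key_value_heads, head_dim, hidden_size)
k = k.reshape(num_key_value_heads, 1, head_dim, hidden_size)
v = v.reshape(num_key_value_heads, 1, head_dim, hidden_size)
qkv = np.concatenate((q, k, v), axis=1).reshape(weights.shape)

print(qkv.shape)  # (48, 32) -- same shape as the input, only the row order changes
```

The overall shape is preserved; the transform only reorders rows from `[all Q heads | all K heads | all V heads]` into per-KV-head groups of (query heads, K, V), which is what the grouped-query Falcon fused QKV expects.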
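End to end, the new test exercises roughly the path below. This is a sketch mirroring the existing GGUF tests: the repository and file names are taken from the diff above, while `device_map` and `torch_dtype` are illustrative choices rather than values copied from the test; loading the GGUF file this way dequantizes the weights, so the 40B model needs a large amount of memory.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "tensorblock/falcon-40b-GGUF"
gguf_file = "falcon-40b-Q2_K.gguf"

# gguf_file tells from_pretrained to dequantize the GGUF checkpoint into a regular Falcon model.
tokenizer = AutoTokenizer.from_pretrained(model_id, gguf_file=gguf_file)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    gguf_file=gguf_file,
    device_map="auto",        # illustrative; adjust to the available hardware
    torch_dtype=torch.float16,
)

inputs = tokenizer("Hello", return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=10)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```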