Commit 6b54dc8

Add mamba tensors
1 parent 8c526a1 commit 6b54dc8

File tree

1 file changed (+47, -6)

convert_hf_to_gguf.py

Lines changed: 47 additions & 6 deletions
@@ -2190,6 +2190,26 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None):
 @ModelBase.register("Plamo2ForCausalLM")
 class Plamo2Model(LlamaModel):
     model_arch = gguf.MODEL_ARCH.PLAMO2
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+        # Add custom mappings for Plamo2's unique structure
+        # Plamo2 uses "mixer" for Mamba layers instead of standard attention
+        tensor_map = gguf.get_tensor_name_map(self.model_arch, self.block_count)
+
+        # Add Mamba-specific mappings
+        for i in range(self.block_count):
+            # SSM/Mamba tensors
+            tensor_map[f"model.layers.{i}.mixer.in_proj"] = f"blk.{i}.ssm_in"
+            tensor_map[f"model.layers.{i}.mixer.conv1d"] = f"blk.{i}.ssm_conv1d"
+            tensor_map[f"model.layers.{i}.mixer.x_proj"] = f"blk.{i}.ssm_x"
+            tensor_map[f"model.layers.{i}.mixer.dt_proj"] = f"blk.{i}.ssm_dt"
+            tensor_map[f"model.layers.{i}.mixer.A_log"] = f"blk.{i}.ssm_a"
+            tensor_map[f"model.layers.{i}.mixer.D"] = f"blk.{i}.ssm_d"
+            tensor_map[f"model.layers.{i}.mixer.out_proj"] = f"blk.{i}.ssm_out"
+
+        self.tensor_map = tensor_map

     def set_vocab(self):
         # Plamo2 uses sentencepiece tokenizer similar to Llama
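
The new __init__ builds a plain rename table from Hugging Face tensor names to GGUF tensor names on top of the architecture's default map. A minimal standalone sketch of the same idea, using an ordinary dict and a hypothetical map_name helper (block_count is an assumed value here, not read from a real config):

    block_count = 2  # assumed; the real converter derives this from the model config

    tensor_map: dict[str, str] = {}
    for i in range(block_count):
        # Same Mamba entries as the commit adds above
        tensor_map[f"model.layers.{i}.mixer.in_proj"] = f"blk.{i}.ssm_in"
        tensor_map[f"model.layers.{i}.mixer.conv1d"] = f"blk.{i}.ssm_conv1d"
        tensor_map[f"model.layers.{i}.mixer.x_proj"] = f"blk.{i}.ssm_x"
        tensor_map[f"model.layers.{i}.mixer.dt_proj"] = f"blk.{i}.ssm_dt"
        tensor_map[f"model.layers.{i}.mixer.A_log"] = f"blk.{i}.ssm_a"
        tensor_map[f"model.layers.{i}.mixer.D"] = f"blk.{i}.ssm_d"
        tensor_map[f"model.layers.{i}.mixer.out_proj"] = f"blk.{i}.ssm_out"

    def map_name(hf_name: str) -> str:
        # Hypothetical helper: fall back to the input name when no entry matches
        return tensor_map.get(hf_name, hf_name)

    assert map_name("model.layers.0.mixer.A_log") == "blk.0.ssm_a"
    assert map_name("lm_head.weight") == "lm_head.weight"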
@@ -2220,12 +2240,33 @@ def set_gguf_parameters(self):
     def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iterable[tuple[str, Tensor]]:
         # Handle Plamo2 specific tensor naming
         # The model has both attention and Mamba layers
-
-        # Handle Mamba-specific tensors if present
-        if "mamba" in name:
-            # Mamba layers might need special handling
-            # For now, pass through with standard naming
-            pass
+
+        # Handle the nested layer structure: layers.layers.X
+        if name.startswith("model.layers.layers."):
+            # Extract the layer number and rest of the name
+            parts = name.split(".")
+            layer_num = parts[3]  # The layer number
+            rest = ".".join(parts[4:])  # Everything after the layer number
+
+            # Reconstruct the name without the duplicate "layers"
+            name = f"model.layers.{layer_num}.{rest}"
+
+        # Handle Mamba-specific A_log tensor transformation
+        if name.endswith(".A_log"):
+            # Map the tensor name first
+            new_name = self.map_tensor_name(name)
+            logger.debug(f"A_log --> A ==> {new_name}")
+            # Transform A_log to A: A = -exp(A_log)
+            data_torch = -torch.exp(data_torch)
+            return [(new_name, data_torch)]
+
+        # Handle Mamba conv1d tensor shape adjustment
+        if "mixer.conv1d" in name:
+            new_name = self.map_tensor_name(name)
+            # Squeeze the conv1d tensor if needed
+            if len(data_torch.shape) == 4:
+                data_torch = data_torch.squeeze()
+            return [(new_name, data_torch)]

         return super().modify_tensors(data_torch, name, bid)
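
The layers.layers handling in the second hunk is pure string surgery, so it can be checked in isolation. A small sketch with a hypothetical flatten_name helper (not part of the commit):

    def flatten_name(name: str) -> str:
        # Collapse "model.layers.layers.N.<rest>" to "model.layers.N.<rest>"
        if name.startswith("model.layers.layers."):
            parts = name.split(".")
            layer_num = parts[3]        # "N", the layer index
            rest = ".".join(parts[4:])  # everything after the index
            return f"model.layers.{layer_num}.{rest}"
        return name

    assert flatten_name("model.layers.layers.7.mixer.in_proj.weight") == "model.layers.7.mixer.in_proj.weight"
    assert flatten_name("model.embed_tokens.weight") == "model.embed_tokens.weight"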
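
The A_log branch stores the state matrix as A = -exp(A_log), the usual Mamba parameterization that keeps A strictly negative. A toy torch sketch of just that transform (tensor values are made up for illustration):

    import torch

    a_log = torch.tensor([[0.0, 1.0], [2.0, 3.0]])  # assumed toy values

    # A = -exp(A_log); exp(x) > 0 for all x, so A is strictly negative
    a = -torch.exp(a_log)

    assert torch.all(a < 0)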
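The conv1d branch only reshapes: if a checkpoint ships the weight as 4-D, squeeze() drops the singleton dimensions before writing. A sketch with an assumed toy shape (real shapes come from the checkpoint):

    import torch

    w = torch.randn(128, 1, 4, 1)  # assumed 4-D conv1d weight with singleton dims

    # Mirror the commit's guard: only squeeze 4-D tensors
    if len(w.shape) == 4:
        w = w.squeeze()  # removes all size-1 dims -> shape (128, 4)

    assert w.shape == (128, 4)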