
Commit f7d51a5

Update modify_tensors
1 parent 32eeac0 commit f7d51a5


convert_hf_to_gguf.py

Lines changed: 16 additions & 7 deletions
@@ -2352,29 +2352,38 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
             # Transform A_log to A: A = -exp(A_log)
             data_torch = -torch.exp(data_torch)

-            # PLaMo2 A_log is shape {d_state} but llama.cpp expects {d_state, d_inner}
+            # PLaMo2 A_log is shape {d_state} but llama.cpp expects {d_inner, d_state}
             # Expand the tensor to the correct shape
             if len(data_torch.shape) == 1:
                 d_state = data_torch.shape[0]  # 64
                 d_inner = 8192  # SSM inner size for PLaMo2

-                # Create tensor with correct shape {d_state, d_inner} = {64, 8192}
-                # Each row of the matrix should contain the same value from the original 1D tensor
-                new_tensor = data_torch.new_zeros((d_state, d_inner))
+                # Create tensor with correct shape {d_inner, d_state} = {8192, 64}
+                # Each column of the matrix should contain the same value from the original 1D tensor
+                new_tensor = data_torch.new_zeros((d_inner, d_state))
                 for i in range(d_state):
-                    new_tensor[i, :] = data_torch[i]  # Broadcast the single value across the inner dimension
+                    new_tensor[:, i] = data_torch[i]  # Broadcast the single value across the inner dimension
                 data_torch = new_tensor
                 logger.debug(f"Expanded A tensor from {d_state} to shape: {data_torch.shape}")

             return [(new_name, data_torch)]

-        # Handle Mamba D tensor - ensure .weight suffix
+        # Handle Mamba D tensor - ensure .weight suffix and expand shape
         if name.endswith("mixer.D") or name.endswith("ssm.D"):
             new_name = self.map_tensor_name(name)
             # Add .weight suffix if not present
             if not new_name.endswith(".weight"):
                 new_name += ".weight"
-            logger.debug(f"D tensor ==> {new_name}")
+            logger.debug(f"D tensor ==> {new_name}, original shape: {data_torch.shape}")
+
+            # PLaMo2 D is shape {64} but llama.cpp expects {8192}
+            # Expand D to broadcast across d_inner dimension
+            if len(data_torch.shape) == 1 and data_torch.shape[0] == 64:
+                d_inner = 8192  # SSM inner size for PLaMo2
+                # Repeat D values across inner dimension
+                data_torch = data_torch.repeat(d_inner // data_torch.shape[0])
+                logger.debug(f"Expanded D tensor from 64 to shape: {data_torch.shape}")
+
             return [(new_name, data_torch)]

         # Handle Mamba conv1d tensor shape adjustment
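For context, here is a minimal sketch, not part of the commit, that checks the column-wise loop in the diff above against an equivalent vectorized broadcast. The stand-in sizes 4 and 16 replace the real d_state = 64 and d_inner = 8192 so the check runs instantly, and a_log is a random placeholder for PLaMo2's actual A_log tensor.

```python
import torch

d_state, d_inner = 4, 16      # stand-ins for 64 and 8192 from the diff
a_log = torch.randn(d_state)  # random placeholder for PLaMo2's A_log
a = -torch.exp(a_log)         # A = -exp(A_log), as in the diff

# Loop version from the commit: column i of every row holds a[i].
looped = a.new_zeros((d_inner, d_state))
for i in range(d_state):
    looped[:, i] = a[i]

# Vectorized equivalent: treat a as one row and expand it across d_inner rows.
# expand() returns a view; call .contiguous() if a real copy is needed.
vectorized = a.unsqueeze(0).expand(d_inner, d_state)

assert torch.equal(looped, vectorized)
print(vectorized.shape)  # torch.Size([16, 4])
```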

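Similarly, a small sketch, again not from the commit, illustrating what Tensor.repeat does to the D tensor: it tiles the whole 64-element vector d_inner // 64 times rather than repeating each element in place. The diff does not say which layout llama.cpp expects, so repeat_interleave is shown only for contrast, with hypothetical small values standing in for the real sizes.

```python
import torch

d = torch.tensor([1., 2., 3., 4.])  # stand-in for the 64-element D tensor
d_inner = 12                        # stand-in for 8192

# Tensor.repeat tiles the whole vector, as the commit does:
print(d.repeat(d_inner // d.shape[0]))
# tensor([1., 2., 3., 4., 1., 2., 3., 4., 1., 2., 3., 4.])

# repeat_interleave would instead repeat each element in place:
print(d.repeat_interleave(d_inner // d.shape[0]))
# tensor([1., 1., 1., 2., 2., 2., 3., 3., 3., 4., 4., 4.])
```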