Skip to content

Commit ba489b4

Browse files
committed
wip minicpmv
1 parent 9716c7b commit ba489b4

File tree

1 file changed

+97
-8
lines changed

1 file changed

+97
-8
lines changed

convert_hf_to_gguf.py

Lines changed: 97 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,10 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
204204
f"Missing tensors: {missing}\n"
205205
f"Extra tensors: {extra}")
206206

207-
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight") -> str:
208-
if key not in gguf.MODEL_TENSORS[self.model_arch]:
209-
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {self.model_arch!r}")
207+
def format_tensor_name(self, key: gguf.MODEL_TENSOR, bid: int | None = None, suffix: str = ".weight", is_vision = False) -> str:
208+
arch = self.vision_arch if is_vision and self.vision_arch is not None else self.model_arch
209+
if key not in gguf.MODEL_TENSORS[arch]:
210+
raise ValueError(f"Missing {key!r} for MODEL_TENSORS of {arch!r}")
210211
name: str = gguf.TENSOR_NAMES[key]
211212
if "{bid}" in name:
212213
assert bid is not None
@@ -2144,6 +2145,7 @@ def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims:
21442145
class MiniCPMModel(Model):
21452146
model_arch = gguf.MODEL_ARCH.MINICPM
21462147
proj_type: gguf.constants.CLIPProjectorType | None
2148+
resampler_n_embd = 0
21472149

21482150
def __init__(self, *args, **kwargs):
21492151
super().__init__(*args, **kwargs)
@@ -2162,6 +2164,12 @@ def __init__(self, *args, **kwargs):
21622164
self.proj_type = gguf.constants.CLIPProjectorType.MINICPMV_2_6
21632165
else:
21642166
raise ValueError(f"Unsupported MiniCPM-V version: {version}")
2167+
# TODO: how to do this without reading the whole safetensor file?
2168+
for tname, tensor in self.get_tensors():
2169+
if tname == "resampler.ln_post.bias":
2170+
self.resampler_n_embd = tensor.shape[0]
2171+
if self.resampler_n_embd < 2:
2172+
raise ValueError("Failed to detect resampler embedding size")
21652173

21662174
if self.vparams is not None and self.vision_arch is not None and self.preprocessor_config is not None:
21672175
self.preprocessor_config["image_mean"] = [0.5, 0.5, 0.5]
@@ -2220,6 +2228,12 @@ def generate_extra_tensors(self) -> Iterable[tuple[str, Tensor]]:
22202228
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_LONG), torch.tensor(long_factors, dtype=torch.float32))
22212229
yield (self.format_tensor_name(gguf.MODEL_TENSOR.ROPE_FACTORS_SHORT), torch.tensor(short_factors, dtype=torch.float32))
22222230

2231+
if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV:
2232+
yield (
2233+
self.format_tensor_name(gguf.MODEL_TENSOR.V_RESMPL_POS_EMBD_K, is_vision=True),
2234+
torch.from_numpy(self._get_2d_sincos_pos_embed(self.resampler_n_embd, (70, 70)))
2235+
)
2236+
22232237
def set_vocab(self):
22242238
if self.vision_arch == gguf.MODEL_ARCH.VISION_MINICPMV:
22252239
# undocumented anywhere, I only found this thanks to https://huggingface.co/openbmb/MiniCPM-V-2_6-gguf
@@ -2233,11 +2247,23 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
22332247
# For vision model
22342248
if name.startswith("llm."):
22352249
name = name.replace("llm.", "")
2236-
# attention, someone mess up and use underscore instead of dot
2237-
if name.endswith("in_proj_weight"):
2238-
name = name.replace("_weight", ".weight")
2239-
if name.endswith("in_proj_bias"):
2240-
name = name.replace("_bias", ".bias")
2250+
2251+
# split the resampler.attn.in_proj_(weight|bias) tensors into q, k, v
2252+
if name.endswith("in_proj_weight") or name.endswith("in_proj_bias"):
2253+
assert data_torch.shape[0] == 3 * self.resampler_n_embd
2254+
split_tensor = data_torch.chunk(3, dim=0)
2255+
name_q = name.replace("in_proj_", "in_proj_q.") # in_proj_q.(weight|bias)
2256+
name_k = name.replace("in_proj_", "in_proj_k.") # in_proj_k.(weight|bias)
2257+
name_v = name.replace("in_proj_", "in_proj_v.") # in_proj_v.(weight|bias)
2258+
return [
2259+
(self.map_tensor_name(name_q), split_tensor[0]),
2260+
(self.map_tensor_name(name_k), split_tensor[1]),
2261+
(self.map_tensor_name(name_v), split_tensor[2]),
2262+
]
2263+
2264+
if name == "resampler.proj" or name == "resampler.query":
2265+
name += ".weight"
2266+
22412267
if "post_layernorm" in name:
22422268
return [] # skip post_layernorm
22432269

@@ -2251,6 +2277,69 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
22512277
data_torch = LlamaModel.permute(data_torch, n_head, n_kv_head)
22522278

22532279
return [(self.map_tensor_name(name), data_torch)]
2280+
2281+
def tensor_force_quant(self, name: str, new_name: str, bid: int | None, n_dims: int) -> gguf.GGMLQuantizationType | bool:
2282+
del name, bid # unused
2283+
if "v.resmpl.query" in new_name or "v.resmpl.pos_embd_k" in new_name:
2284+
return gguf.GGMLQuantizationType.F32
2285+
if "v.resmpl." in new_name:
2286+
return gguf.GGMLQuantizationType.F32 if n_dims == 1 else gguf.GGMLQuantizationType.F16
2287+
return False
2288+
2289+
# utils to work with MiniCPM-V resampler
2290+
2291+
# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20
2292+
def _get_2d_sincos_pos_embed(self, embed_dim: int, grid_size: tuple[int, int] | int, cls_token=False) -> np.ndarray:
2293+
"""
2294+
grid_size: int of the grid height and width
2295+
return:
2296+
pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
2297+
"""
2298+
if isinstance(grid_size, int):
2299+
grid_h_size, grid_w_size = grid_size, grid_size
2300+
else:
2301+
grid_h_size, grid_w_size = grid_size[0], grid_size[1]
2302+
2303+
grid_h = np.arange(grid_h_size, dtype=np.float32)
2304+
grid_w = np.arange(grid_w_size, dtype=np.float32)
2305+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
2306+
grid = np.stack(grid, axis=0)
2307+
2308+
grid = grid.reshape([2, 1, grid_h_size, grid_w_size])
2309+
pos_embed = self._get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
2310+
if cls_token:
2311+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
2312+
return pos_embed
2313+
2314+
def _get_2d_sincos_pos_embed_from_grid(self, embed_dim: int, grid: np.ndarray) -> np.ndarray:
2315+
assert embed_dim % 2 == 0
2316+
2317+
# use half of dimensions to encode grid_h
2318+
emb_h = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
2319+
emb_w = self._get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
2320+
2321+
emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
2322+
return emb
2323+
2324+
def _get_1d_sincos_pos_embed_from_grid(self, embed_dim: int, pos: np.ndarray) -> np.ndarray:
2325+
"""
2326+
embed_dim: output dimension for each position
2327+
pos: a list of positions to be encoded: size (M,)
2328+
out: (M, D)
2329+
"""
2330+
assert embed_dim % 2 == 0
2331+
omega = np.arange(embed_dim // 2, dtype=np.float32)
2332+
omega /= embed_dim / 2.
2333+
omega = 1. / 10000 ** omega # (D/2,)
2334+
2335+
pos = pos.reshape(-1) # (M,)
2336+
out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
2337+
2338+
emb_sin = np.sin(out) # (M, D/2)
2339+
emb_cos = np.cos(out) # (M, D/2)
2340+
2341+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
2342+
return emb
22542343

22552344

22562345
@Model.register("MiniCPM3ForCausalLM")

0 commit comments

Comments
 (0)