diff --git a/.gitignore b/.gitignore index 0c9ef52c..7d2f64a9 100644 --- a/.gitignore +++ b/.gitignore @@ -23,7 +23,10 @@ cache/ #GGUF *.gguf -# txt -*.txt +# # txt +# *.txt *.http + +#log +log diff --git a/TODO.md b/TODO.md new file mode 100644 index 00000000..a6258c3f --- /dev/null +++ b/TODO.md @@ -0,0 +1,18 @@ +1. 目标: + + (miniCPM + Fastcache) x (DCU & 摩尔) + (llava + Fastcache) x (DCU & 摩尔) + + +2. 具体工作拆分: + + a. DCU平台端到端跑通:llava encoder部分正确性调试【目前可上手的工作,缺的算子暂时先占位】 + 把encoder/Fastcache/llm拼到一起。 + + b. 我今天搞:摩尔平台计算资源申请 + + c. 我最近两天:两个平台缺的算子搞定 + +3. ddl: 10天后,本月25号 + +4. weight: +108:/home/weight/MiniCPM-V-2_6;/home/weight/llava-1.5-7b-hf \ No newline at end of file diff --git a/compress_ckpt/llava_mlp.bin b/compress_ckpt/llava_mlp.bin new file mode 100644 index 00000000..0a9cbe16 Binary files /dev/null and b/compress_ckpt/llava_mlp.bin differ diff --git a/compress_ckpt/llava_mlp_layerwise.bin b/compress_ckpt/llava_mlp_layerwise.bin new file mode 100644 index 00000000..fd8d529e Binary files /dev/null and b/compress_ckpt/llava_mlp_layerwise.bin differ diff --git a/compress_ckpt/minicpm_mlp.pth b/compress_ckpt/minicpm_mlp.pth new file mode 100644 index 00000000..74e52b84 Binary files /dev/null and b/compress_ckpt/minicpm_mlp.pth differ diff --git a/compress_ckpt/minicpmv_mlp.bin b/compress_ckpt/minicpmv_mlp.bin new file mode 100644 index 00000000..d4bb392c Binary files /dev/null and b/compress_ckpt/minicpmv_mlp.bin differ diff --git a/debug_data/qkv_debug.txt b/debug_data/qkv_debug.txt new file mode 100644 index 00000000..e69de29b diff --git a/docs/KVCacheCompressionMapping.md b/docs/KVCacheCompressionMapping.md new file mode 100644 index 00000000..063e2f96 --- /dev/null +++ b/docs/KVCacheCompressionMapping.md @@ -0,0 +1,34 @@ +# KV Cache Compression Weight Mapping (llava_mlp.bin) + +## 前缀与含义 +权重来自 Fastcache 的 KVCacheLinearDecoupleCompressor,`.pth` 结构位于 `compressor` 子树,键模式为 `...weight`。当前导出的 bin 包含以下前缀(已按排序写入): + +- `compress_tk`: 文本 K 压缩/投影相关权重 +- `compress_tv`: 文本 V 压缩/投影相关权重 +- `compress_iv`: 图像 V 压缩/投影相关权重(命名可能沿用 image/value 缩写) +- `compress_ik`: 图像 K 压缩/投影相关权重 +- `attention`: 压缩器内部的注意力/门控线性层(小头数,通常 slot=0..7 等) + +> 注:原始 PyTorch 权重中未见 bias;转换脚本若发现 bias 长度不匹配会跳过。 + +## 排序与写入顺序 +排序键:`prefix` 优先级(`compress_tk` → `compress_tv` → `compress_iv` → `compress_ik` → `attention`),然后 `layer` 升序,再 `slot` 升序。同一 `(prefix,layer,slot)` 下先 weight 后 bias(若存在)。 + +## 形状推断与 hidden_size +- 头部的 `hidden_size` 来自首个权重的列数(当前为 640)。 +- 每个 weight 块记录 `rows`、`cols`。可视为线性层 `out = W * in`,`W` 形状为 `[rows, cols]`,输出维度 = `rows`。 +- bias(若存在且长度==rows)紧随其后。 + +## 可能的计算图猜测(供 C++ 实现对齐) +- `compress_tk`: 对文本 K 做降维/解耦,slot 多个表示分阶段或多头混合投影。 +- `compress_tv`: 对文本 V 做降维/解耦。 +- `compress_iv`: 对图像 V 做降维/解耦。 +- `compress_ik`: 对图像 K 做降维/解耦。 +- `attention`: 压缩器内部的小型注意力/门控 MLP,用于生成压缩映射或融合文本/图像特征。 + +实际计算顺序需结合 Fastcache 的 Python 源码(`KVCacheLinearDecoupleCompressor.forward`)逐层映射,将上述权重映射到具体的线性/激活/重排操作。 + +## 与 bin 对齐的校验 +- 使用 `scripts/verify_llava_mlp_bin.py` 可对比 `.pth` 与 `.bin`:会打印头部、逐块形状及 max diff。 +- 当前验证结果:`num_layers=32`,`weight_count_per_layer=12`,384 个 weight 块,max diff=0(fp16)。 + diff --git a/docs/KVCacheCompressionOpsChecklist.md b/docs/KVCacheCompressionOpsChecklist.md new file mode 100644 index 00000000..c9b3f02a --- /dev/null +++ b/docs/KVCacheCompressionOpsChecklist.md @@ -0,0 +1,34 @@ +# KV Cache Compression 算法拆解与算子需求(llava_mlp.bin 基线) + +## 模块拆解(基于权重前缀推断) +- `compress_tk`: 文本 K 路径压缩/解耦。若 slot 多个,可能对应多阶段或多头混合投影。 +- `compress_tv`: 文本 V 路径压缩/解耦。 +- `compress_iv`: 图像 V 路径压缩/解耦。 +- `compress_ik`: 图像 K 路径压缩/解耦。 +- `attention`: 
压缩器内部的小型注意力/门控线性层(可能用于融合/生成映射)。 + +## 可能的计算流程(参考 Fastcache 思路,需结合源码逐条对齐) +1) 对 KV 按类别(文本/图像)分支,按头/slot 做线性变换降维或投影。 +2) 可选的 gating/注意力:使用 `attention.*` 权重对压缩特征做融合或生成索引/权重。 +3) 生成压缩后的 K/V(seq 维缩短或维度降维),并记录映射(indices/scale)。 +4) 解压路径:根据保存的映射/scale,将压缩 K/V 恢复到注意力可消费的形式(或直接在注意力中使用压缩格式)。 + +## 需要的核心算子(优先复用 InfiniCore) +- 矩阵乘 + bias:`linear`(已有)。需支持 fp16/bf16。 +- 激活:SiLU/GELU(确认是否已有;缺失则补充逐元素 kernel)。 +- 张量重排:view/reshape/permute/slice(`Tensor` 已支持)。 +- 可能的归一化/缩放:元素级乘加(已有基础算子可组合)。 +- 可选:索引/聚合(若压缩逻辑需要采样或根据权重重排 seq)。 + +## 需要明确的映射关系(待确认) +- 每个 prefix 对应的输入/输出形状:`[B, heads, seq, dim]` → `[...]`,slot 如何映射。 +- 压缩因子作用位置:seq 维还是隐藏维(或两者结合)。 +- `attention` 权重的输入特征来源与输出用途(生成哪类权重/索引)。 +- 是否需要存储 indices/scale 以便解压或稀疏注意力。 + +## 实现阶段建议 +1) **占位压缩**:先实现保留最近 N/截断的简版,打通链路。 +2) **权重映射对齐**:阅读 Python 版 `KVCacheLinearDecoupleCompressor.forward`,写出每个 prefix 的线性/激活顺序和张量维度。 +3) **算子补齐**:如缺 SiLU/GELU,新增简洁 kernel;其余用现有 linear/elemwise 组合。 +4) **解压策略**:选择解压到密集 KV(改动小)或直接改注意力支持压缩格式(二选一)。 +5) **验证**:构造 C++/ctypes 小测试,随机 KV → 压缩 → 解压 → 对比误差,量化开销与收益。 diff --git a/docs/KVCacheCompressionWeightFormat.md b/docs/KVCacheCompressionWeightFormat.md new file mode 100644 index 00000000..575c5ba6 --- /dev/null +++ b/docs/KVCacheCompressionWeightFormat.md @@ -0,0 +1,57 @@ +# KV Cache Compression Weight Format (Binary, No PyTorch Dependency) + +## File Layout +All values are little-endian unless stated otherwise. Strings are ASCII, null-terminated. + +- Header (fixed size) + - `uint32` magic = 0x4B56434D ("KV C M") + - `uint32` version = 1 + - `uint16` dtype code: 0 = fp16, 1 = bf16, 2 = fp32 + - `uint16` reserved = 0 + - `uint32` num_layers + - `uint32` num_heads + - `uint32` head_dim + - `uint32` hidden_size + - `uint32` compression_factor (e.g., 4, 5) + - `uint32` min_seq_len + - `uint32` weight_count_per_layer (for sanity check) + - `uint32` metadata_size_bytes (future expansion; set 0 for now) +- Layer blocks (repeat `num_layers` times) + - For each weight tensor (order defined below): + - `uint32` rows + - `uint32` cols + - `uint32` has_bias (0/1) + - data blob for weight: `rows * cols * sizeof(dtype)` + - optional bias blob: `cols * sizeof(dtype)` when `has_bias==1` +- Footer + - `uint32` checksum (optional; set 0 if not used) + +## Weight Order per Layer (example for linear-decouple MLP) +Adjust if实际模型结构不同,但顺序需在导出和加载一致。 +1. `proj_k` weight (+bias) +2. `proj_v` weight (+bias) +3. `compress_k` weight (+bias) +4. `compress_v` weight (+bias) +5. `decompress_k` weight (+bias) +6. `decompress_v` weight (+bias) +7. `gate`/`mlp` weights (+bias) if算法需要 + +`weight_count_per_layer` = 实际包含的权重项数,便于解析时校验。 + +## Export Steps (one-time, in external Python env) +1) 使用 PyTorch 读取原 `.pth`:`state = torch.load(...)`. +2) 提取压缩器权重到固定顺序的列表;统一 dtype(fp16/bf16): + ```python + weights = [ + (state['proj_k.weight'], state.get('proj_k.bias')), + ... 
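+        # the ellipsis above stands for the remaining (weight, bias) pairs, kept in the
+        # fixed per-layer order defined earlier in this document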
+ ] + ``` +3) 写入头部;逐层写元信息 + 数据;按 `dtype` 转为字节(fp16 用 `np.float16.tobytes()`)。 +4) 填充 footer(可置 0)。 + +## Loader Expectations (C++/InfiniCore) +- 读取并验证 magic/version/dtype/层数/weight_count_per_layer。 +- 为每个权重创建 `Tensor::weight`,dtype 与头部一致。 +- 如果缺少某些权重(has_bias=0),按约定跳过 bias。 +- 解析出的权重按同样顺序存入压缩器对象,以确保前向逻辑正确。 diff --git a/include/infinicore_infer.h b/include/infinicore_infer.h index 0bed7bc7..4c73a85d 100644 --- a/include/infinicore_infer.h +++ b/include/infinicore_infer.h @@ -2,9 +2,11 @@ #define INFINICORE_INFER_H #include "infinicore_infer/cache.h" +#include "infinicore_infer/kv_compression.h" #include "infinicore_infer/weights_loader.h" #include "infinicore_infer/models/deepseek.h" #include "infinicore_infer/models/jiuge.h" +#include "infinicore_infer/models/minicpmv.h" #endif /* INFINICORE_INFER_H */ diff --git a/include/infinicore_infer/kv_compression.h b/include/infinicore_infer/kv_compression.h new file mode 100644 index 00000000..3cfc54ab --- /dev/null +++ b/include/infinicore_infer/kv_compression.h @@ -0,0 +1,25 @@ +#ifndef KV_COMPRESSION_H +#define KV_COMPRESSION_H + +#include + +#include + +struct KVCache; + +typedef struct { + uint32_t enable; + uint32_t compression_factor; + uint32_t min_seq_len; + uint32_t image_kv_len; + const char *weight_path; // path to .bin weights (see docs/KVCacheCompressionWeightFormat.md) +} KVCompressionConfig; + +// Compress KVCache in-place: +// - Reads KV from [0, seq_len) and writes compressed KV back into the same cache prefix [0, new_len). +// - Returns new_len on success; returns seq_len on no-op/failure. +__C __export uint32_t +compressKVCacheInplace(struct KVCache *kv_cache, uint32_t seq_len, const KVCompressionConfig *cfg); + +#endif // KV_COMPRESSION_H + diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h index 1cae1223..d1b3accb 100644 --- a/include/infinicore_infer/models/jiuge.h +++ b/include/infinicore_infer/models/jiuge.h @@ -30,11 +30,15 @@ typedef struct // [dvoc, d] const void *output_embd; // nlayer * [d] - const void *const *attn_norm; + const void *const *attn_norm; // 指针数组,每层一个RMSNorm权重 // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh, d] const void *const *attn_qkv; // nlayer * [ndev, (nh + 2 * nkvh) / ndev * dh] const void *const *attn_qkv_b; + // nlayer * [dh] + const void *const *attn_q_norm; + // nlayer * [dh] + const void *const *attn_k_norm; // nlayer * [ndev, d, nkvh / ndev * dh] const void *const *attn_o; // nlayer * [d] @@ -80,6 +84,43 @@ inferBatchJiuge(struct JiugeModel *, const float *temperature, const uint32_t *topk, const float *topp, uint32_t *output); +/// @brief 批次推理一轮,并采样出新的 token(RoPE 位置与 KV 写入位置可解耦,用于 KV 压缩) +/// @param req_pos 位置 id 基址(用于 RoPE/pos_ids 计算) +/// @param kv_pos KVCache 写入/读取基址(用于 past_len/total_len 计算) +__C __export void +inferBatchJiugeEx(struct JiugeModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output); + +/// @brief 批次推理一轮,并采样出新的 token,同时输出 logits +/// @param logits 输出 logits 数组 +__C __export void +inferBatchJiugeWithLogits(struct JiugeModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *logits); + +/// @brief 批次推理一轮(RoPE 位置与 KV 写入位置可解耦),同时输出 logits +/// 
@param req_pos 位置 id 基址(用于 RoPE/pos_ids 计算) +/// @param kv_pos KVCache 写入/读取基址(用于 past_len/total_len 计算) +/// @param logits 输出 logits 数组 +__C __export void +inferBatchJiugeExWithLogits(struct JiugeModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *logits); + /// @brief 批次推理一轮,输出 output embedding 后的 logits /// @param tokens 输入 token 地址 /// @param ntok 输入 token 数量 @@ -95,4 +136,94 @@ forwardBatchJiuge(struct JiugeModel *, struct KVCache **kv_caches, void *logits); +/// @brief 批次推理一轮,输出 logits(RoPE 位置与 KV 写入位置可解耦,用于 KV 压缩) +__C __export void +forwardBatchJiugeEx(struct JiugeModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + void *logits); + +/// @brief 批次推理一轮,支持对指定 token 位置的输入 embedding 做覆盖(用于多模态 image embedding 注入) +/// @note override_pos 需要按升序排列,且每个位置最多出现一次 +/// @param n_override 覆盖位置数量 +/// @param override_pos 覆盖位置(基于拼接后的 tokens 序列下标,范围 [0, ntok)) +/// @param override_embeds 覆盖 embedding,shape [n_override, d],dtype = meta.dt_logits +__C __export void +inferBatchJiugeWithOverrides(struct JiugeModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output); + +/// @brief 批次推理一轮(RoPE 位置与 KV 写入位置可解耦),支持 embedding 覆盖 +__C __export void +inferBatchJiugeWithOverridesEx(struct JiugeModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output); + +/// @brief 批次推理一轮,支持 embedding 覆盖,同时输出 logits +__C __export void +inferBatchJiugeWithOverridesWithLogits(struct JiugeModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *logits); + +// /// @brief 批次推理一轮(RoPE 位置与 KV 写入位置可解耦),支持 embedding 覆盖,同时输出 logits +// __C __export void +// inferBatchJiugeWithOverridesExWithLogits(struct JiugeModel *, +// const uint32_t *tokens, uint32_t ntok, +// const uint32_t *req_lens, uint32_t nreq, +// const uint32_t *req_pos, +// const uint32_t *kv_pos, +// struct KVCache **kv_caches, +// uint32_t n_override, +// const uint32_t *override_pos, +// const void *override_embeds, +// const float *temperature, const uint32_t *topk, const float *topp, +// uint32_t *output, void *logits); + +/// @brief 批次推理一轮,输出 logits,支持对指定 token 位置的输入 embedding 做覆盖 +__C __export void +forwardBatchJiugeWithOverrides(struct JiugeModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + void *logits); + +/// @brief 批次推理一轮,输出 
logits(RoPE 位置与 KV 写入位置可解耦),支持 embedding 覆盖 +__C __export void +forwardBatchJiugeWithOverridesEx(struct JiugeModel *, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + void *logits); + #endif diff --git a/include/infinicore_infer/models/llava.h b/include/infinicore_infer/models/llava.h new file mode 100644 index 00000000..aa8354e9 --- /dev/null +++ b/include/infinicore_infer/models/llava.h @@ -0,0 +1,208 @@ +#ifndef MODEL_LLAVA_H +#define MODEL_LLAVA_H + +#include +#include +#include + +#include + +struct LlavaModel; + +// Vision Encoder Meta +typedef struct { + size_t image_size; + size_t patch_size; + size_t num_patches; + size_t vision_embed_dim; + size_t vision_num_layers; + size_t vision_num_heads; + size_t vision_intermediate_size; // mlp_dim + float vision_epsilon; +} LlavaVisionMeta; + +// Language Model Meta (reuses Jiuge structure) +typedef struct { + infiniDtype_t dt_logits; + size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc; + float epsilon, theta; + uint32_t end_token; +} LlavaLanguageMeta; + +// MultiModal Projector Meta +typedef struct { + size_t vision_embed_dim; + size_t text_embed_dim; + size_t projector_hidden_size; +} LlavaProjectorMeta; + +typedef struct { + LlavaVisionMeta vision_meta; + LlavaLanguageMeta language_meta; + LlavaProjectorMeta projector_meta; +} LlavaMeta; + +typedef struct { + // Vision Encoder Weights + size_t vision_nlayer; + const void *vision_patch_embed_weight; // [num_patches, vision_embed_dim] + const void *vision_class_token; // [vision_embed_dim] + const void *vision_position_embedding; // [num_patches + 1, vision_embed_dim] + const void *const *vision_encoder_weights; // vision_layers * [various vision weights] // 应该没用到 + + const void *vision_pre_layernorm_weight; // [vision_embed_dim] + const void *vision_pre_layernorm_bias; // [vision_embed_dim] + const void *vision_post_layernorm_weight; // [vision_embed_dim] + const void *vision_post_layernorm_bias; // [vision_embed_dim] + + const void *const *vision_q_weights; // vision_layers * [vision_q_weight] + const void *const *vision_q_biases; // vision_layers * [vision_q_bias] + const void *const *vision_k_weights; // vision_layers * [vision_k_weight] + const void *const *vision_k_biases; // vision_layers * [vision_k_bias] + const void *const *vision_v_weights; // vision_layers * [vision_v_weight] + const void *const *vision_v_biases; // vision_layers * [vision_v_bias] + + + const void *const *vision_in_layer_pre_norm_weights; // vision_layers * [vision_embed_dim] + const void *const *vision_in_layer_pre_norm_biases; // vision_layers * [vision_embed_dim] + + // out_proj / proj (注意:是 attention 的输出投影) + const void *const *vision_proj_weight; // vision_layers * [embed_dim, embed_dim] + const void *const *vision_proj_bias; // vision_layers * [embed_dim] + + // post attention layernorm(等价 torch: self.layer_norm2 或类似) + const void *const *vision_in_layer_post_norm_weight; // vision_layers * [embed_dim] + const void *const *vision_post_norm_bias; // vision_layers * [embed_dim] + + // MLP 层:fc1 + const void *const *vision_mlp_fc1_weight; // vision_layers * [mlp_dim, embed_dim] + const void *const *vision_mlp_fc1_bias; // vision_layers * [mlp_dim] // 4096, vision_intermediate_size + + // MLP 层:fc2 + const void *const *vision_mlp_fc2_weight; // vision_layers * [embed_dim, mlp_dim] + const void *const 
*vision_mlp_fc2_bias; // vision_layers * [embed_dim] + + + // MultiModal Projector Weights + const void *projector_weight_1; // linear_1: [projector_hidden_size, vision_embed_dim] + const void *projector_bias_1; // linear_1: [projector_hidden_size] + const void *projector_weight_2; // linear_2: [text_embed_dim, projector_hidden_size] + const void *projector_bias_2; // linear_2: [text_embed_dim] + + // Language Model Weights (reuses Jiuge structure) + size_t nlayer; + infiniDtype_t dt_norm, dt_mat; + int transpose_linear_weights; + + // Language model weights + const void *input_embd; + const void *output_norm; + const void *output_embd; + const void *const *attn_norm; + const void *const *attn_qkv; + const void *const *attn_qkv_b; + const void *const *attn_q_norm; + const void *const *attn_k_norm; + const void *const *attn_o; + const void *const *ffn_norm; + const void *const *ffn_gate_up; + const void *const *ffn_down; +} LlavaWeights; + +struct LlavaKVCache; + +// Vision debug stages for alignment with HF. +// Output dtype is always meta.language_meta.dt_logits. +// - PRE_LN: [1, 577, 1024] +// - SELECT_ALL: [1, 577, 1024] (vision_feature_layer = -2, includes class token) +// - SELECT_PATCH: [1, 576, 1024] (vision_feature_layer = -2, patch-only) +// - PROJECTOR: [1, 576, 4096] (projector on patch-only tokens) +// - PROJECTOR_ALL:[1, 577, 4096] (projector on all tokens, for debugging) +#define LLAVA_VISION_STAGE_PRE_LN 0u +#define LLAVA_VISION_STAGE_SELECT_ALL 1u +#define LLAVA_VISION_STAGE_SELECT_PATCH 2u +#define LLAVA_VISION_STAGE_PROJECTOR 3u +#define LLAVA_VISION_STAGE_PROJECTOR_ALL 4u + +//////////////////// APIs /////////////////////// +/// @brief 创建LLaVA模型 +/// @param vision_meta 视觉编码器元信息 +/// @param language_meta 语言模型元信息 +/// @param projector_meta 多模态投影器元信息 +/// @param weights 模型权重 +/// @param device 协处理器种类 +/// @param ndev 协处理器数量 +/// @param dev_ids 协处理器编号,长度为 ndev +__C __export struct LlavaModel * +createLlavaModel(const LlavaMeta *meta, + const LlavaWeights *weights, + infiniDevice_t device, + int ndev, + const int *dev_ids); + +/// @brief 销毁LLaVA模型 +/// @param model 模型实例 +__C __export void +destroyLlavaModel(struct LlavaModel *model); + +/// @brief 视觉编码前向推理 +/// @param model 模型实例 +/// @param image_tensor 输入图像张量 +/// @param output 输出视觉特征 +__C __export void +encodeVision(struct LlavaModel *model, + const void *image_tensor, + void *output); + +/// @brief 批量视觉编码推理(用于Python接口) +/// @param model 模型实例 +/// @param image_data 图像数据指针 +/// @param output 输出缓冲区 +__C __export void +inferBatchLlavaVison(struct LlavaModel *model, + const void *image_data, + void *output); + +/// @brief Batch vision forward for intermediate alignment (HF nodes). +/// @param stage One of LLAVA_VISION_STAGE_*. 
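+/// @note Illustrative usage sketch (an assumption for this doc, not a guaranteed calling
+///       convention): size the output buffer for the chosen stage, e.g.
+///       LLAVA_VISION_STAGE_PROJECTOR yields [1, 576, 4096] elements of
+///       meta.language_meta.dt_logits (fp16 assumed below):
+///
+///         uint16_t *out = malloc((size_t)576 * 4096 * sizeof(uint16_t)); /* [1, 576, 4096] */
+///         inferBatchLlavaVisionStage(model, image_data, LLAVA_VISION_STAGE_PROJECTOR, out);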
+__C __export void +inferBatchLlavaVisionStage(struct LlavaModel *model, + const void *image_data, + uint32_t stage, + void *output); + +/// @brief 多模态投影前向推理 +/// @param model 模型实例 +/// @param vision_features 视觉特征 +/// @param output 投影后的文本嵌入 +__C __export void +projectMultiModal(struct LlavaModel *model, + const void *vision_features, + void *output); + +/// @brief 语言模型批处理推理 (复用Jiuge逻辑) +/// @param model 模型实例 +/// @param tokens 输入tokens +/// @param ntok tokens长度 +/// @param req_lens 每个请求长度 +/// @param nreq 请求数量 +/// @param req_pos 每个请求当前的位置 +/// @param kv_caches KV缓存 +/// @param temperature 温度参数 +/// @param topk Top-K采样参数 +/// @param topp Top-P采样参数 +/// @param output 输出token IDs +__C __export void +inferBatchLlavaLanguage(struct LlavaModel *model, + const uint32_t *tokens, + size_t ntok, + const size_t *req_lens, + size_t nreq, + const size_t *req_pos, + struct LlavaKVCache *kv_caches, + const float *temperature, + const float *topk, + const float *topp, + uint32_t *output); + +#endif diff --git a/include/infinicore_infer/models/minicpmv.h b/include/infinicore_infer/models/minicpmv.h new file mode 100644 index 00000000..38b204a7 --- /dev/null +++ b/include/infinicore_infer/models/minicpmv.h @@ -0,0 +1,198 @@ +#ifndef MODEL_MINICPMV_H +#define MODEL_MINICPMV_H + +#include +#include +#include + +#include + +struct MiniCPMVModel; + +typedef struct { + // Vision encoder (SigLIP-NaViT) + size_t patch_size; // 14 + size_t vision_embed_dim; // 1152 + size_t vision_num_layers; // 27 + size_t vision_num_heads; // 16 + size_t vision_intermediate_size;// 4304 + float vision_layer_norm_eps; // 1e-6 + size_t vision_image_size; // 980 (pos-embed grid: 70x70) + size_t vision_num_positions; // 4900 +} MiniCPMVVisionMeta; + +typedef struct { + // Resampler (Perceiver-style, one cross-attn) + size_t num_queries; // 64 + size_t embed_dim; // 3584 + size_t num_heads; // 28 + size_t kv_dim; // 1152 + float layer_norm_eps; // 1e-6 + size_t max_patches_h; // 70 + size_t max_patches_w; // 70 +} MiniCPMVResamplerMeta; + +typedef struct { + // Language model meta (same layout as JiugeMeta) + infiniDtype_t dt_logits; + size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc; + float epsilon, theta; + uint32_t end_token; +} MiniCPMVLanguageMeta; + +typedef struct { + MiniCPMVVisionMeta vision_meta; + MiniCPMVResamplerMeta resampler_meta; + MiniCPMVLanguageMeta language_meta; +} MiniCPMVMeta; + +typedef struct { + // LayerNorm + const void *layer_norm1_weight; + const void *layer_norm1_bias; + const void *layer_norm2_weight; + const void *layer_norm2_bias; + + // Self-attention + const void *q_weight; + const void *q_bias; + const void *k_weight; + const void *k_bias; + const void *v_weight; + const void *v_bias; + const void *out_weight; + const void *out_bias; + + // MLP + const void *fc1_weight; + const void *fc1_bias; + const void *fc2_weight; + const void *fc2_bias; +} MiniCPMVSiglipLayerWeights; + +typedef struct { + // SigLIP patch embedding conv2d + const void *vpm_patch_embedding_weight; // [vision_embed_dim, 3, patch, patch] + const void *vpm_patch_embedding_bias; // [vision_embed_dim] + // SigLIP position embedding + const void *vpm_position_embedding; // [vision_num_positions, vision_embed_dim] + // SigLIP encoder layers + const MiniCPMVSiglipLayerWeights *vpm_layers; // [vision_num_layers] + // SigLIP final LN + const void *vpm_post_layernorm_weight; // [vision_embed_dim] + const void *vpm_post_layernorm_bias; // [vision_embed_dim] + + // Resampler + const void *resampler_query; // [num_queries, embed_dim] + // 
NOTE: For the current CPU reference implementation, these weights are expected + // to be pre-transposed to "in x out" layout for GEMM: [in_dim, out_dim]. + const void *resampler_kv_proj_weight; // [kv_dim, embed_dim] + const void *resampler_attn_in_proj_weight; // [embed_dim, 3*embed_dim] + const void *resampler_attn_in_proj_bias; // [3*embed_dim] + const void *resampler_attn_out_proj_weight; // [embed_dim, embed_dim] + const void *resampler_attn_out_proj_bias; // [embed_dim] + const void *resampler_ln_q_weight; // [embed_dim] + const void *resampler_ln_q_bias; // [embed_dim] + const void *resampler_ln_kv_weight; // [embed_dim] + const void *resampler_ln_kv_bias; // [embed_dim] + const void *resampler_ln_post_weight; // [embed_dim] + const void *resampler_ln_post_bias; // [embed_dim] + const void *resampler_proj; // [embed_dim, embed_dim] + + // Language model weights (reuse Jiuge layout) + size_t nlayer; + infiniDtype_t dt_norm, dt_mat; + int transpose_linear_weights; + + const void *input_embd; + const void *output_norm; + const void *output_embd; + const void *const *attn_norm; + const void *const *attn_qkv; + const void *const *attn_qkv_b; + const void *const *attn_q_norm; + const void *const *attn_k_norm; + const void *const *attn_o; + const void *const *ffn_norm; + const void *const *ffn_gate_up; + const void *const *ffn_down; +} MiniCPMVWeights; + +//////////////////// APIs /////////////////////// +__C __export struct MiniCPMVModel * +createMiniCPMVModel(const MiniCPMVMeta *meta, + const MiniCPMVWeights *weights, + infiniDevice_t device, + int ndev, + const int *dev_ids); + +__C __export void +destroyMiniCPMVModel(struct MiniCPMVModel *model); + +/// @brief Resampler forward (CPU reference path). +/// @note This API is for step-by-step validation; it currently supports CPU only. +/// @param x Vision features, shape [seq_len, kv_dim], dtype = meta.language_meta.dt_logits +/// @param seq_len Must equal tgt_h * tgt_w (no padding supported in this API) +/// @param tgt_h Patch grid height +/// @param tgt_w Patch grid width +/// @param output Output, shape [num_queries, embed_dim], dtype = meta.language_meta.dt_logits +__C __export void +inferMiniCPMVResampler(struct MiniCPMVModel *model, + const void *x, size_t seq_len, + uint32_t tgt_h, uint32_t tgt_w, + void *output); + +/// @brief SigLIP patch embedding + position embedding (CPU reference path). +/// @note This API is for step-by-step validation; it currently supports CPU only. +/// @param pixel_values Input packed as [1, 3, patch_size, seq_len * patch_size], where seq_len == tgt_h * tgt_w. +/// @param output Output embeddings, shape [seq_len, vision_embed_dim], dtype = meta.language_meta.dt_logits +__C __export void +inferMiniCPMVSiglipEmbeddings(struct MiniCPMVModel *model, + const void *pixel_values, + size_t seq_len, + uint32_t tgt_h, + uint32_t tgt_w, + void *output); + +/// @brief SigLIP encoder layer0 forward (CPU reference path). +/// @note This API is for step-by-step validation; it currently supports CPU only and ignores attention masks. +/// @param hidden_states Input, shape [seq_len, vision_embed_dim] +/// @param output Output, shape [seq_len, vision_embed_dim] +__C __export void +inferMiniCPMVSiglipLayer0(struct MiniCPMVModel *model, + const void *hidden_states, + size_t seq_len, + void *output); + +/// @brief SigLIP encoder layer forward (CPU reference path). +/// @note This API is for step-by-step validation; it currently supports CPU only and ignores attention masks. +/// @param layer_idx Which encoder layer to run. 
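+/// @note Illustrative call sketch (assumes dt_logits is fp16 and vision_embed_dim is 1152):
+///
+///         /* input and output are both [seq_len, vision_embed_dim] */
+///         uint16_t *out = malloc(seq_len * 1152 * sizeof(uint16_t));
+///         inferMiniCPMVSiglipLayer(model, 3, hidden_states, seq_len, out); /* run layer 3 */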
+__C __export void +inferMiniCPMVSiglipLayer(struct MiniCPMVModel *model, + uint32_t layer_idx, + const void *hidden_states, + size_t seq_len, + void *output); + +/// @brief SigLIP encoder forward for the first `num_layers` layers, followed by post-layernorm (CPU reference path). +/// @note This API is for step-by-step validation; it currently supports CPU only and ignores attention masks. +__C __export void +inferMiniCPMVSiglipEncoder(struct MiniCPMVModel *model, + uint32_t num_layers, + const void *hidden_states, + size_t seq_len, + void *output); + +/// @brief Vision forward: SigLIP embeddings -> SigLIP encoder -> resampler (CPU reference path). +/// @note This API is for step-by-step validation; it currently supports CPU only. +/// @param pixel_values Input packed as [1, 3, patch_size, seq_len * patch_size], where seq_len == tgt_h * tgt_w. +/// @param output Output, shape [num_queries, embed_dim], dtype = meta.language_meta.dt_logits +__C __export void +inferMiniCPMVVisionResampler(struct MiniCPMVModel *model, + const void *pixel_values, + size_t seq_len, + uint32_t tgt_h, + uint32_t tgt_w, + void *output); + +#endif diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 00000000..37ff22df --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +transformers==4.57.1 \ No newline at end of file diff --git a/scripts/deepseek.py b/scripts/deepseek.py index bba5a373..29d81ce1 100644 --- a/scripts/deepseek.py +++ b/scripts/deepseek.py @@ -662,11 +662,7 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. output_tokens = self.batch_infer_one_round([infer_task]) end_time = time.time() steps += 1 - output_str = ( - self.tokenizer._tokenizer.id_to_token(output_tokens[0]) - .replace("▁", " ") - .replace("<0x0A>", "\n") - ) + output_str = self.tokenizer.decode(output_tokens[0]) output_content += output_str print(output_str, end="", flush=True) if output_tokens[0] in self.eos_token_id: diff --git a/scripts/image_processing_minicpmv.py b/scripts/image_processing_minicpmv.py new file mode 100644 index 00000000..67ff552f --- /dev/null +++ b/scripts/image_processing_minicpmv.py @@ -0,0 +1,418 @@ +from typing import Optional, Union, Dict, Any, List + +import torch +import math +import PIL.Image +import PIL.ImageSequence +import numpy as np +import PIL +from PIL import Image + +from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device +from transformers.image_processing_utils import BaseImageProcessor, BatchFeature +from transformers import AutoImageProcessor +from transformers.image_transforms import to_channel_dimension_format +from transformers.image_utils import ( + ImageInput, + make_list_of_images, + valid_images, + is_torch_tensor, + is_batched, + to_numpy_array, + infer_channel_dimension_format, + ChannelDimension +) + + +def recursive_converter(converter, value): + if isinstance(value, list): + new_value = [] + for v in value: + new_value += [recursive_converter(converter, v)] + return new_value + else: + return converter(value) + + +class MiniCPMVBatchFeature(BatchFeature): + r""" + Extend from BatchFeature for supporting various image size + """ + def __init__(self, data: Optional[Dict[str, Any]] = None, tensor_type: Union[None, str, TensorType] = None): + super().__init__(data) + self.convert_to_tensors(tensor_type=tensor_type) + + def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None): + if tensor_type is None: + return self + + is_tensor, as_tensor = 
self._get_is_as_tensor_fns(tensor_type) + + def converter(value): + try: + if not is_tensor(value): + tensor = as_tensor(value) + return tensor + except: # noqa E722 + if key == "overflowing_values": + raise ValueError("Unable to create tensor returning overflowing values of different lengths. ") + raise ValueError( + "Unable to create tensor, you should probably activate padding " + "with 'padding=True' to have batched tensors with the same length." + ) + + + for key, value in self.items(): + self[key] = recursive_converter(converter, value) + return self + + def to(self, *args, **kwargs) -> "MiniCPMVBatchFeature": + requires_backends(self, ["torch"]) + import torch + + def cast_tensor(v): + # check if v is a floating point + if torch.is_floating_point(v): + # cast and send to device + return v.to(*args, **kwargs) + elif device is not None: + return v.to(device=device) + else: + return v + + new_data = {} + device = kwargs.get("device") + # Check if the args are a device or a dtype + if device is None and len(args) > 0: + # device should be always the first argument + arg = args[0] + if is_torch_dtype(arg): + # The first argument is a dtype + pass + elif isinstance(arg, str) or is_torch_device(arg) or isinstance(arg, int): + device = arg + else: + # it's something else + raise ValueError(f"Attempting to cast a BatchFeature to type {str(arg)}. This is not supported.") + # We cast only floating point tensors to avoid issues with tokenizers casting `LongTensor` to `FloatTensor` + for k, v in self.items(): + new_data[k] = recursive_converter(cast_tensor, v) + self.data = new_data + return self + + +class MiniCPMVImageProcessor(BaseImageProcessor): + model_input_names = ["pixel_values"] + + def __init__( + self, + max_slice_nums=9, + scale_resolution=448, + patch_size=14, + **kwargs): + super().__init__(**kwargs) + self.max_slice_nums = max_slice_nums + self.scale_resolution = scale_resolution + self.patch_size = patch_size + self.use_image_id = kwargs.pop("use_image_id", False) + self.image_feature_size = kwargs.pop("image_feature_size", 64) + self.im_start_token = kwargs.pop("im_start", "") + self.im_end_token = kwargs.pop("im_end", "") + self.slice_start_token = kwargs.pop("slice_start", "") + self.slice_end_token = kwargs.pop("slice_end", "") + self.unk_token = kwargs.pop("unk", "") + self.im_id_start = kwargs.pop("im_id_start", "") + self.im_id_end = kwargs.pop("im_id_end", "") + self.slice_mode = kwargs.pop("slice_mode", True) + self.mean = np.array(kwargs.pop("norm_mean", [0.5, 0.5, 0.5])) + self.std = np.array(kwargs.pop("norm_std", [0.5, 0.5, 0.5])) + self.version = kwargs.pop("version", 2.0) + + def ensure_divide(self, length, patch_size): + return max(round(length / patch_size) * patch_size, patch_size) + + def find_best_resize(self, + original_size, + scale_resolution, + patch_size, + allow_upscale=False): + width, height = original_size + if (width * height > + scale_resolution * scale_resolution) or allow_upscale: + r = width / height + height = int(scale_resolution / math.sqrt(r)) + width = int(height * r) + best_width = self.ensure_divide(width, patch_size) + best_height = self.ensure_divide(height, patch_size) + return (best_width, best_height) + + def get_refine_size(self, + original_size, + grid, + scale_resolution, + patch_size, + allow_upscale=False): + width, height = original_size + grid_x, grid_y = grid + + refine_width = self.ensure_divide(width, grid_x) + refine_height = self.ensure_divide(height, grid_y) + + grid_width = refine_width / grid_x + grid_height = 
refine_height / grid_y + + best_grid_size = self.find_best_resize((grid_width, grid_height), + scale_resolution, + patch_size, + allow_upscale=allow_upscale) + refine_size = (best_grid_size[0] * grid_x, best_grid_size[1] * grid_y) + return refine_size + + def split_to_patches(self, image, grid): + patches = [] + width, height = image.size + grid_x = int(width / grid[0]) + grid_y = int(height / grid[1]) + for i in range(0, height, grid_y): + images = [] + for j in range(0, width, grid_x): + box = (j, i, j + grid_x, i + grid_y) + patch = image.crop(box) + images.append(patch) + patches.append(images) + return patches + + def slice_image( + self, image, max_slice_nums=9, scale_resolution=448, patch_size=14, never_split=False + ): + original_size = image.size + source_image = None + best_grid = self.get_sliced_grid(original_size, max_slice_nums, never_split) + patches = [] + + if best_grid is None: + # dont need to slice, upsample + best_size = self.find_best_resize( + original_size, scale_resolution, patch_size, allow_upscale=True + ) + source_image = image.resize(best_size, resample=Image.Resampling.BICUBIC) + else: + # source image, down-sampling and ensure divided by patch_size + best_resize = self.find_best_resize(original_size, scale_resolution, patch_size) + source_image = image.copy().resize(best_resize, resample=Image.Resampling.BICUBIC) + refine_size = self.get_refine_size( + original_size, best_grid, scale_resolution, patch_size, allow_upscale=True + ) + refine_image = image.resize(refine_size, resample=Image.Resampling.BICUBIC) + patches = self.split_to_patches(refine_image, best_grid) + + return source_image, patches, best_grid + + def get_grid_placeholder(self, grid): + if grid is None: + return "" + slice_image_placeholder = ( + self.slice_start_token + + self.unk_token * self.image_feature_size + + self.slice_end_token + ) + + cols = grid[0] + rows = grid[1] + slices = [] + for i in range(rows): + lines = [] + for j in range(cols): + lines.append(slice_image_placeholder) + slices.append("".join(lines)) + + slice_placeholder = "\n".join(slices) + return slice_placeholder + + def get_image_id_placeholder(self, idx=0): + return f"{self.im_id_start}{idx}{self.im_id_end}" + + def get_sliced_images(self, image, max_slice_nums=None): + slice_images = [] + + if not self.slice_mode: + return [image] + + max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums) + assert max_slice_nums > 0 + source_image, patches, sliced_grid = self.slice_image( + image, + max_slice_nums, # default: 9 + self.scale_resolution, # default: 448 + self.patch_size # default: 14 + ) + + slice_images.append(source_image) + if len(patches) > 0: + for i in range(len(patches)): + for j in range(len(patches[0])): + slice_images.append(patches[i][j]) + return slice_images + + def get_sliced_grid(self, image_size, max_slice_nums, nerver_split=False): + original_width, original_height = image_size + log_ratio = math.log(original_width / original_height) + ratio = original_width * original_height / (self.scale_resolution * self.scale_resolution) + multiple = min(math.ceil(ratio), max_slice_nums) + if multiple <= 1 or nerver_split: + return None + candidate_split_grids_nums = [] + for i in [multiple - 1, multiple, multiple + 1]: + if i == 1 or i > max_slice_nums: + continue + candidate_split_grids_nums.append(i) + + candidate_grids = [] + for split_grids_nums in candidate_split_grids_nums: + m = 1 + while m <= split_grids_nums: + if split_grids_nums % m == 0: + candidate_grids.append([m, 
split_grids_nums // m]) + m += 1 + + best_grid = [1, 1] + min_error = float("inf") + for grid in candidate_grids: + error = abs(log_ratio - math.log(grid[0] / grid[1])) + if error < min_error: + best_grid = grid + min_error = error + + return best_grid + + def get_slice_image_placeholder(self, image_size, image_idx=0, max_slice_nums=None, use_image_id=None): + max_slice_nums = self.max_slice_nums if max_slice_nums is None else int(max_slice_nums) + assert max_slice_nums > 0 + grid = self.get_sliced_grid(image_size=image_size, max_slice_nums=max_slice_nums) + + image_placeholder = ( + self.im_start_token + + self.unk_token * self.image_feature_size + + self.im_end_token + ) + use_image_id = self.use_image_id if use_image_id is None else bool(use_image_id) + if use_image_id: + final_placeholder = self.get_image_id_placeholder(image_idx) + image_placeholder + else: + final_placeholder = image_placeholder + + if self.slice_mode: + final_placeholder = final_placeholder + self.get_grid_placeholder(grid=grid) + return final_placeholder + + def to_pil_image(self, image, rescale=None) -> PIL.Image.Image: + """ + Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if + needed. + + Args: + image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`): + The image to convert to the PIL Image format. + rescale (`bool`, *optional*): + Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will + default to `True` if the image type is a floating type, `False` otherwise. + """ + if isinstance(image, PIL.Image.Image): + return image + if is_torch_tensor(image): + image = image.numpy() + + if isinstance(image, np.ndarray): + if rescale is None: + # rescale default to the array being of floating type. + rescale = isinstance(image.flat[0], np.floating) + # If the channel as been moved to first dim, we put it back at the end. + if image.ndim == 3 and image.shape[0] in [1, 3]: + image = image.transpose(1, 2, 0) + if rescale: + image = image * 255 + image = image.astype(np.uint8) + return PIL.Image.fromarray(image) + return image + + def reshape_by_patch(self, image): + """ + :param image: shape [3, H, W] + :param patch_size: + :return: [3, patch_size, HW/patch_size] + """ + image = torch.from_numpy(image) + patch_size = self.patch_size + patches = torch.nn.functional.unfold( + image, + (patch_size, patch_size), + stride=(patch_size, patch_size) + ) + + patches = patches.reshape(image.size(0), patch_size, patch_size, -1) + patches = patches.permute(0, 1, 3, 2).reshape(image.size(0), patch_size, -1) + return patches.numpy() + + def preprocess( + self, + images: Union[Image.Image, List[Image.Image], List[List[Image.Image]]], + do_pad: Optional[bool] = True, # TODO: add pad for MiniCPM-Llama3-V-2_5 + max_slice_nums: int = None, + return_tensors: Optional[Union[str, TensorType]] = None, + **kwargs + ) -> MiniCPMVBatchFeature: + if isinstance(images, Image.Image): + images_list = [[images]] + elif isinstance(images[0], Image.Image): + images_list = [images] + else: + images_list = images + + new_images_list = [] + image_sizes_list = [] + tgt_sizes_list = [] + + for _images in images_list: + if _images is None or len(_images) == 0: + new_images_list.append([]) + image_sizes_list.append([]) + tgt_sizes_list.append([]) + continue + if not valid_images(_images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "torch.Tensor, tf.Tensor or jax.ndarray." 
+ ) + + _images = [self.to_pil_image(image).convert("RGB") for image in _images] + input_data_format = infer_channel_dimension_format(np.array(_images[0])) + + new_images = [] + image_sizes = [image.size for image in _images] + tgt_sizes = [] + for image in _images: + image_patches = self.get_sliced_images(image, max_slice_nums) + image_patches = [to_numpy_array(image).astype(np.float32) / 255 for image in image_patches] + image_patches = [ + self.normalize(image=image, mean=self.mean, std=self.std, input_data_format=input_data_format) + for image in image_patches + ] + image_patches = [ + to_channel_dimension_format(image, ChannelDimension.FIRST, input_channel_dim=input_data_format) + for image in image_patches + ] + for slice_image in image_patches: + new_images.append(self.reshape_by_patch(slice_image)) + tgt_sizes.append(np.array((slice_image.shape[1] // self.patch_size, slice_image.shape[2] // self.patch_size))) + + if tgt_sizes: + tgt_sizes = np.vstack(tgt_sizes) + + new_images_list.append(new_images) + image_sizes_list.append(image_sizes) + tgt_sizes_list.append(tgt_sizes) + return MiniCPMVBatchFeature( + data={"pixel_values": new_images_list, "image_sizes": image_sizes_list, "tgt_sizes": tgt_sizes_list}, tensor_type=return_tensors + ) + +AutoImageProcessor.register("MiniCPMVImageProcessor", MiniCPMVImageProcessor) diff --git a/scripts/img/150352.jpg b/scripts/img/150352.jpg new file mode 100644 index 00000000..b8ffde9d Binary files /dev/null and b/scripts/img/150352.jpg differ diff --git a/scripts/img/1592576.jpg b/scripts/img/1592576.jpg new file mode 100644 index 00000000..a7a1e007 Binary files /dev/null and b/scripts/img/1592576.jpg differ diff --git a/scripts/img/1592725.jpg b/scripts/img/1592725.jpg new file mode 100644 index 00000000..30fe812f Binary files /dev/null and b/scripts/img/1592725.jpg differ diff --git a/scripts/img/2315433.jpg b/scripts/img/2315433.jpg new file mode 100644 index 00000000..ce76ce50 Binary files /dev/null and b/scripts/img/2315433.jpg differ diff --git a/scripts/img/2826.jpg b/scripts/img/2826.jpg new file mode 100644 index 00000000..1d0aa4cb Binary files /dev/null and b/scripts/img/2826.jpg differ diff --git a/scripts/img/285689.jpg b/scripts/img/285689.jpg new file mode 100644 index 00000000..41a99368 Binary files /dev/null and b/scripts/img/285689.jpg differ diff --git a/scripts/img/3700.jpg b/scripts/img/3700.jpg new file mode 100644 index 00000000..a5d8b734 Binary files /dev/null and b/scripts/img/3700.jpg differ diff --git a/scripts/img/654.jpg b/scripts/img/654.jpg new file mode 100644 index 00000000..8f963831 Binary files /dev/null and b/scripts/img/654.jpg differ diff --git a/scripts/infer_task.py b/scripts/infer_task.py index 0d1231b7..9508fb6d 100644 --- a/scripts/infer_task.py +++ b/scripts/infer_task.py @@ -35,6 +35,62 @@ def next(self, out_token): else: self.tokens = [out_token] + def __str__(self): + """返回用户友好的字符串表示""" + return ( + f"InferTask(id={self.id}, " + f"tokens_len={len(self.tokens)}, " + f"pos={self.pos}, " + f"max_tokens={self.max_tokens}, " + f"temperature={self.temperature}, " + f"topk={self.topk}, " + f"topp={self.topp}, " + f"finish_reason={self.finish_reason}, " + f"has_kv_cache={self._kv_cache is not None})" + ) + + def __repr__(self): + """返回开发者友好的详细表示""" + # 显示前10个token,避免输出过长 + tokens_preview = self.tokens[:10] if len(self.tokens) > 10 else self.tokens + if len(self.tokens) > 10: + tokens_preview_str = str(tokens_preview) + f" ... 
(total {len(self.tokens)} tokens)" + else: + tokens_preview_str = str(tokens_preview) + + return ( + f"InferTask(\n" + f" id={self.id},\n" + f" tokens={tokens_preview_str},\n" + f" pos={self.pos},\n" + f" max_tokens={self.max_tokens},\n" + f" temperature={self.temperature},\n" + f" topk={self.topk},\n" + f" topp={self.topp},\n" + f" end_tokens={self.end_tokens},\n" + f" finish_reason={self.finish_reason},\n" + f" has_kv_cache={self._kv_cache is not None}\n" + f" _kv_cache={self._kv_cache}\n" + f")" + ) + + def debug_info(self): + """返回详细的调试信息""" + return { + "id": self.id, + "tokens": self.tokens, + "tokens_len": len(self.tokens), + "pos": self.pos, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "topk": self.topk, + "topp": self.topp, + "end_tokens": self.end_tokens, + "finish_reason": self.finish_reason, + "has_kv_cache": self._kv_cache is not None, + "remaining_tokens": self.max_tokens - self.pos + } + class KVCache: def __init__(self, model): @@ -57,3 +113,40 @@ def update_tokens(self, tokens, pos): end = max_len self.tokens[pos:end] = tokens + + def __str__(self): + """返回用户友好的字符串表示""" + # 计算非零token数量(已使用的token) + used_tokens = sum(1 for t in self.tokens if t != 0) + return f"KVCache(used_tokens={used_tokens}, max_capacity={len(self.tokens)})" + + def __repr__(self): + """返回开发者友好的详细表示""" + # 显示前20个token,避免输出过长 + tokens_preview = self.tokens[:20] if len(self.tokens) > 20 else self.tokens + if len(self.tokens) > 20: + tokens_preview_str = str(tokens_preview) + f" ... (total {len(self.tokens)} slots)" + else: + tokens_preview_str = str(tokens_preview) + + used_tokens = sum(1 for t in self.tokens if t != 0) + + return ( + f"KVCache(\n" + f" tokens={tokens_preview_str},\n" + f" used_tokens={used_tokens},\n" + f" max_capacity={len(self.tokens)},\n" + f" usage_ratio={used_tokens/len(self.tokens):.2%}\n" + f")" + ) + def debug_info(self): + """返回详细的调试信息""" + used_tokens = sum(1 for t in self.tokens if t != 0) + return { + # "tokens": self.tokens, + "total_slots": len(self.tokens), + "used_tokens": used_tokens, + "empty_slots": len(self.tokens) - used_tokens, + "usage_ratio": used_tokens / len(self.tokens), + "is_full": used_tokens >= len(self.tokens) + } diff --git a/scripts/jiuge.py b/scripts/jiuge.py index 091ccca7..269fa142 100644 --- a/scripts/jiuge.py +++ b/scripts/jiuge.py @@ -25,56 +25,72 @@ class LlamaWeightsNaming: + def __init__(self, prefix: str = ""): + # Optional prefix like "llm." for MiniCPM-V. 
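+        # e.g. LlamaWeightsNaming("llm.").attn_q(0) -> "llm.model.layers.0.self_attn.q_proj.weight"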
+ self.prefix = prefix + + def _p(self, name: str) -> str: + return f"{self.prefix}{name}" if self.prefix else name + def input_embd(self): - return "model.embed_tokens.weight" + return self._p("model.embed_tokens.weight") def output_norm(self): - return "model.norm.weight" + return self._p("model.norm.weight") def output_embd(self): - return "lm_head.weight" + return self._p("lm_head.weight") def attn_norm(self, i): - return f"model.layers.{i}.input_layernorm.weight" + return self._p(f"model.layers.{i}.input_layernorm.weight") def attn_q(self, i): - return f"model.layers.{i}.self_attn.q_proj.weight" + return self._p(f"model.layers.{i}.self_attn.q_proj.weight") def attn_k(self, i): - return f"model.layers.{i}.self_attn.k_proj.weight" + return self._p(f"model.layers.{i}.self_attn.k_proj.weight") def attn_v(self, i): - return f"model.layers.{i}.self_attn.v_proj.weight" + return self._p(f"model.layers.{i}.self_attn.v_proj.weight") def attn_o(self, i): - return f"model.layers.{i}.self_attn.o_proj.weight" + return self._p(f"model.layers.{i}.self_attn.o_proj.weight") def attn_q_b(self, i): - return f"model.layers.{i}.self_attn.q_proj.bias" + return self._p(f"model.layers.{i}.self_attn.q_proj.bias") def attn_k_b(self, i): - return f"model.layers.{i}.self_attn.k_proj.bias" + return self._p(f"model.layers.{i}.self_attn.k_proj.bias") def attn_v_b(self, i): - return f"model.layers.{i}.self_attn.v_proj.bias" + return self._p(f"model.layers.{i}.self_attn.v_proj.bias") + + def attn_q_norm(self, i): + return self._p(f"model.layers.{i}.self_attn.q_norm.weight") + + def attn_k_norm(self, i): + return self._p(f"model.layers.{i}.self_attn.k_norm.weight") def ffn_norm(self, i): - return f"model.layers.{i}.post_attention_layernorm.weight" + return self._p(f"model.layers.{i}.post_attention_layernorm.weight") def gate(self, i): - return f"model.layers.{i}.mlp.gate_proj.weight" + return self._p(f"model.layers.{i}.mlp.gate_proj.weight") def up(self, i): - return f"model.layers.{i}.mlp.up_proj.weight" + return self._p(f"model.layers.{i}.mlp.up_proj.weight") def down(self, i): - return f"model.layers.{i}.mlp.down_proj.weight" + return self._p(f"model.layers.{i}.mlp.down_proj.weight") - def match(state_dict): - return ( - "model.norm.weight" in state_dict - and "model.layers.0.self_attn.q_proj.weight" in state_dict - ) + @staticmethod + def match(state_dict, prefix: str = ""): + def p(n: str) -> str: + return f"{prefix}{n}" if prefix else n + + return p("model.norm.weight") in state_dict and p( + "model.layers.0.self_attn.q_proj.weight" + ) in state_dict class JiugeMetaFromLlama(JiugeMetaCStruct): @@ -117,14 +133,18 @@ def __init__(self, config, dtype=torch.float16, max_tokens=None): if "num_key_value_heads" in config else config["num_attention_heads"] ), - dh=config["hidden_size"] // config["num_attention_heads"], + dh=( + config["head_dim"] + if "head_dim" in config + else config["hidden_size"] // config["num_attention_heads"] + ), di=config["intermediate_size"], dctx=( config["max_position_embeddings"] if max_tokens is None else max_tokens ), dvoc=config["vocab_size"], epsilon=config["rms_norm_eps"], - theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), + theta=(config["rope_theta"] if "rope_theta" in config else 10000.0), end_token=2, ) self.torch_dtype_logits = dtype @@ -275,6 +295,35 @@ def qkv_b_slices(_i): else: self.attn_qkv_b = None + if naming.attn_q_norm(0) in state_dict: + self.attn_q_norm_tensors = [ + state_dict[naming.attn_q_norm(i)] + .reshape([2, dh // 2]) + .transpose(0, 1) + 
.contiguous() + .to(torch_dt_norm) + for i in range(nlayer) + ] + self.attn_q_norm_ptrs = [ + self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs) + self.attn_k_norm_tensors = [ + state_dict[naming.attn_k_norm(i)] + .reshape([2, dh // 2]) + .transpose(0, 1) + .contiguous() + .to(torch_dt_norm) + for i in range(nlayer) + ] + self.attn_k_norm_ptrs = [ + self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs) + else: + self.attn_q_norm = None + self.attn_k_norm = None + self.attn_o_tensor = [ ( state_dict[naming.attn_o(i)] @@ -389,7 +438,12 @@ def input_args(self): class JiugeForCauslLM: def __init__( - self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None + self, + model_dir_path, + device=DeviceType.DEVICE_TYPE_CPU, + ndev=1, + max_tokens=None, + dtype_override=None, ): def load_all_safetensors_from_dir(dir_path_: str): tensors_ = {} @@ -406,6 +460,8 @@ def load_all_safetensors_from_dir(dir_path_: str): with open(os.path.join(model_dir_path, "config.json"), "r") as f: config = json.load(f) self.config = config + + print(f"Model config: {self.config}") eos_token_id = self.config["eos_token_id"] self.eos_token_id = ( [eos_token_id] if type(eos_token_id) == int else eos_token_id @@ -415,6 +471,11 @@ def load_all_safetensors_from_dir(dir_path_: str): ) # y = xW is faster than y=xW^T on Ascend self.jiuge_model = JiugeModel() + # JiugeModel是构造函数,负责创建多线程和设备资源 + # ✅ 创建Python的JiugeModel对象 + # ✅ 继承自BaseModel,加载C++库 + # ❌ 还没有创建C++的工作线程 + # ❌ 还没有分配设备资源 if "llama" == config["model_type"]: model = ( @@ -481,8 +542,46 @@ def load_all_safetensors_from_dir(dir_path_: str): ) else: raise ValueError("Unsupported weight naming") - elif "qwen2" == config["model_type"]: + elif "minicpmv" == config["model_type"]: + if any( + file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir() + ): + state_dict = load_all_safetensors_from_dir(model_dir_path) + else: + state_dict = torch.load( + os.path.join(model_dir_path, "pytorch_model.bin"), + weights_only=True, + map_location="cpu", + ) + naming_prefix = "llm." 
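+            # MiniCPM-V checkpoints keep the language model under the "llm." prefix
+            # (e.g. "llm.model.norm.weight"), so match() and LlamaWeightsNaming receive it explicitly.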
+ if not LlamaWeightsNaming.match(state_dict, prefix=naming_prefix): + raise ValueError("Unsupported MiniCPM-V LLM weight naming") + if dtype_override is not None: + llm_dtype = dtype_override + else: + torch_dtype_str = config.get("torch_dtype", "float16") + dtype_map = { + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + llm_dtype = dtype_map.get(torch_dtype_str, torch.float16) + self.meta = JiugeMetaFromLlama(config, dtype=llm_dtype, max_tokens=max_tokens) + naming = LlamaWeightsNaming(prefix=naming_prefix) + self.weights = JiugeWeightsImpl( + self.meta, + naming, + state_dict, + torch_dt_mat=llm_dtype, + ndev=ndev, + transpose_weight=transpose_weight, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) + elif "qwen2" == config["model_type"] or "qwen3" == config["model_type"]: state_dict = load_all_safetensors_from_dir(model_dir_path) + #print(f"state_dict keys: {list(state_dict.keys())[:50]} ...") if LlamaWeightsNaming.match(state_dict): self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) self.weights = JiugeWeightsImpl( @@ -498,6 +597,26 @@ def load_all_safetensors_from_dir(dir_path_: str): else: raise ValueError("Unsupported model architecture") + if "llama" == config["model_type"]: + from tokenizers import decoders as _dec + + backend = getattr(self.tokenizer, "backend_tokenizer", None) + target = getattr(backend, "_tokenizer", backend) + norm = getattr(target, "normalizer", None) + dec = getattr(target, "decoder", None) + sn = repr(norm)[:800] if norm is not None else "" + sd = repr(dec)[:800] if dec is not None else "" + has_prepend = "Prepend" in sn + has_strip = "Strip" in sd + if has_prepend and has_strip: + target.decoder = _dec.Sequence( + [ + _dec.Replace("▁", " "), + _dec.ByteFallback(), + _dec.Fuse(), + ] + ) + load_end_time = time.time() print(f"Time used: {load_end_time - load_start_time:.3f}s") @@ -507,6 +626,7 @@ def load_all_safetensors_from_dir(dir_path_: str): self.ndev = ndev self.device = device + # # 2. 创建C++模型实例(这里创建了工作线程!) self.model_instance = self.jiuge_model.create_model( byref(self.meta), byref(self.weights), @@ -514,6 +634,14 @@ def load_all_safetensors_from_dir(dir_path_: str): ndev, self.dev_ids, ) + # ✅ 调用C++的createJiugeModel函数 + # ✅ 执行C++的JiugeModel构造函数 + # ✅ 创建工作线程! + # ✅ 分配设备资源! + # ✅ 线程开始运行并进入休眠等待! 
+ + + load_end_time = time.time() print(f"Time used: {load_end_time - load_start_time:.3f}s") @@ -537,23 +665,40 @@ def drop_kv_cache(self, kv_cache): self.jiuge_model.drop_kv_cache(kv_cache) def batch_infer_one_round(self, tasks: List[InferTask]): - output = (c_uint * len(tasks))() + output = (c_uint * len(tasks))() # 得到一个可以在C函数中传递的数组对象 + # c_uint: + # output: <__main__.c_uint_Array_1 object at 0x7f76727ef240> + batch_inputs = JiugeBatchedTask(tasks) + # 传给下边这个infer_batch的就是包了一层啥东西的batch_inputs反正 + # infer_batch 等价于 inferBatchJiuge + # output传递给C函数 self.jiuge_model.infer_batch( self.model_instance, *(batch_inputs.input_args()), - output, + output, # 这里传递的是C数组的指针 ) return list(output) - def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): + def generate( + self, + input_content, + max_steps, + topp_=1.0, + topk_=1, + temperature_=1.0, + verbose=False, + ): input_content = self.tokenizer.apply_chat_template( conversation=[{"role": "user", "content": input_content}], add_generation_prompt=True, tokenize=False, ) + print("=== Input Prompt ===") print(input_content, end="", flush=True) + print("=== Input Prompt End ===") tokens = self.tokenizer.encode(input_content) + print(f"Input token IDs: {tokens}") infer_task = InferTask( 0, tokens, @@ -563,22 +708,98 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. topp_, self.eos_token_id, ) + print(f"KV Cache: {infer_task.kvcache()}") + + # 更详细的调试信息 + print("\n=== InferTask 详细信息 ===") + print(repr(infer_task)) + + print("\n=== KV Cache 详细信息 ===") + print(repr(infer_task.kvcache())) + + # # 调试信息字典 + # print("\n=== 调试信息字典 ===") + # debug_info = infer_task.debug_info() + # for key, value in debug_info.items(): + # print(f"{key}: {value}") + infer_task.bind_kvcache(KVCache(self)) + # print(f"\nKV Cache 调试信息:") + # kv_debug = infer_task.kvcache().debug_info() + # for key, value in kv_debug.items(): + # print(f" {key}: {value}") + + # print("\n" + "="*50) + steps = 0 total_time = 0 + prefill_time = 0 + decode_time = 0 output_content = "" - for step_i in range(max_steps): + # Prefill phase - process initial prompt + prefill_start_time = time.time() + output_tokens = self.batch_infer_one_round([infer_task]) + print(f"\nOutput token IDs from prefill: {output_tokens}") + print(f"infer_task after prefill: {repr(infer_task)}") + prefill_end_time = time.time() + prefill_time = prefill_end_time - prefill_start_time + steps += 1 + + output_str = self.tokenizer.decode(output_tokens[0]) + print(f"Decoded output from prefill: {output_str}\n") + output_content += output_str + print(output_str, end="", flush=True) + if output_tokens[0] in self.eos_token_id: + # If generation ends after prefill, calculate metrics + total_time = prefill_time + total_tokens = len(tokens) + 1 # input tokens + first output token + + print("\n") + print(f"Time per step: {total_time * 1000:.3f}ms") + + if verbose: + overall_throughput = total_tokens / total_time + prefill_throughput = len(tokens) / prefill_time + decode_throughput = 1 / 0.001 # Avoid division by zero, use small value + + print("=" * 50) + print("PERFORMANCE METRICS") + print("=" * 50) + print(f"Input tokens: {len(tokens)}") + print(f"Generated tokens: 1") + print(f"Total tokens: {total_tokens}") + print(f"Total time: {total_time * 1000:.3f}ms") + print(f"Prefill time: {prefill_time * 1000:.3f}ms") + print(f"Decode time: 0.000ms") + print("-" * 50) + print(f"Time per step: {total_time * 1000:.3f}ms") + print( + f"Avg prefill time per token: {prefill_time * 1000 / 
len(tokens):.3f}ms" + ) + print(f"Avg decode time per token: N/A") + print("-" * 50) + print(f"Overall throughput: {overall_throughput:.2f} tokens/s") + print(f"Prefill throughput: {prefill_throughput:.2f} tokens/s") + print(f"Decode throughput: N/A") + print("=" * 50) + + return output_content, total_time * 1000 + + infer_task.next(output_tokens[0]) + + # Decode phase - generate subsequent tokens + decode_start_time = time.time() + for step_i in range(1, max_steps): start_time = time.time() output_tokens = self.batch_infer_one_round([infer_task]) + # print(f"\nOutput token IDs from step {step_i}: {output_tokens}") + # print(f"infer_task after step {step_i}: {repr(infer_task)}") end_time = time.time() steps += 1 - output_str = ( - self.tokenizer._tokenizer.id_to_token(output_tokens[0]) - .replace("▁", " ") - .replace("<0x0A>", "\n") - ) + output_str = self.tokenizer.decode(output_tokens[0]) + output_content += output_str print(output_str, end="", flush=True) if output_tokens[0] in self.eos_token_id: @@ -588,12 +809,65 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. if step_i > 0: total_time += end_time - start_time + decode_end_time = time.time() + decode_time = decode_end_time - decode_start_time + print("\n") - avg_time = total_time * 1000 / (steps - 1) - print(f"Time per step: {avg_time:.3f}ms") + + # Calculate performance metrics + total_time = prefill_time + decode_time + input_tokens = len(tokens) + generated_tokens = steps # including first token from prefill + + # Time per token calculations + avg_time_per_step = ( + total_time * 1000 / (steps - 1) if steps > 1 else total_time * 1000 + ) + + print(f"Time per step: {avg_time_per_step:.3f}ms") + + # Only print detailed metrics if verbose flag is set + if verbose: + total_tokens = input_tokens + generated_tokens + + # Throughput calculations + overall_throughput = total_tokens / total_time # tokens per second + prefill_throughput = input_tokens / prefill_time if prefill_time > 0 else 0 + decode_throughput = ( + (generated_tokens - 1) / decode_time if decode_time > 0 else 0 + ) # exclude first token from prefill + + # Time per token calculations + avg_prefill_time_per_token = ( + prefill_time * 1000 / input_tokens if input_tokens > 0 else 0 + ) + avg_decode_time_per_token = ( + decode_time * 1000 / (generated_tokens - 1) + if generated_tokens > 1 + else 0 + ) + + print("=" * 50) + print("PERFORMANCE METRICS") + print("=" * 50) + print(f"Input tokens: {input_tokens}") + print(f"Generated tokens: {generated_tokens}") + print(f"Total tokens: {total_tokens}") + print(f"Total time: {total_time * 1000:.3f}ms") + print(f"Prefill time: {prefill_time * 1000:.3f}ms") + print(f"Decode time: {decode_time * 1000:.3f}ms") + print("-" * 50) + print(f"Time per step: {avg_time_per_step:.3f}ms") + print(f"Avg prefill time per token: {avg_prefill_time_per_token:.3f}ms") + print(f"Avg decode time per token: {avg_decode_time_per_token:.3f}ms") + print("-" * 50) + print(f"Overall throughput: {overall_throughput:.2f} tokens/s") + print(f"Prefill throughput: {prefill_throughput:.2f} tokens/s") + print(f"Decode throughput: {decode_throughput:.2f} tokens/s") + print("=" * 50) infer_task._kv_cache.drop(self) - return output_content, avg_time + return output_content, avg_time_per_step def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): tasks = [ @@ -656,11 +930,27 @@ def destroy_model_instance(self): def test(): if len(sys.argv) < 3: print( - "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | 
--ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] [n_device]" + "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] [n_device] [--verbose]" ) sys.exit(1) model_path = sys.argv[2] device_type = DeviceType.DEVICE_TYPE_CPU + verbose = "--verbose" in sys.argv + init_only = "--init-only" in sys.argv + + max_steps = 64 + prompt = "山东最高的山是?" + if "--max-steps" in sys.argv: + try: + max_steps = int(sys.argv[sys.argv.index("--max-steps") + 1]) + except Exception: + raise ValueError("--max-steps requires an integer value") + if "--prompt" in sys.argv: + try: + prompt = sys.argv[sys.argv.index("--prompt") + 1] + except Exception: + raise ValueError("--prompt requires a string value") + if sys.argv[1] == "--cpu": device_type = DeviceType.DEVICE_TYPE_CPU elif sys.argv[1] == "--nvidia": @@ -681,13 +971,24 @@ def test(): device_type = DeviceType.DEVICE_TYPE_HYGON else: print( - "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] [n_device]" + "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] [n_device] [--verbose]" ) sys.exit(1) - ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 + # Find n_device argument (first numeric positional after model_dir) + ndev = 1 + for arg in sys.argv[3:]: + if arg.startswith("--"): + continue + try: + ndev = int(arg) + break + except ValueError: + continue + model = JiugeForCauslLM(model_path, device_type, ndev) - model.generate("山东最高的山是?", 500) + if not init_only: + model.generate(prompt, max_steps, verbose=verbose) model.destroy_model_instance() diff --git a/scripts/jiuge_awq.py b/scripts/jiuge_awq.py index a14836b7..5191efeb 100644 --- a/scripts/jiuge_awq.py +++ b/scripts/jiuge_awq.py @@ -256,11 +256,6 @@ def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1. 
output_tokens = self.batch_infer_one_round([infer_task]) end_time = time.time() steps += 1 - # output_str = ( - # self.tokenizer._tokenizer.id_to_token(output_tokens[0]) - # .replace("▁", " ") - # .replace("<0x0A>", "\n") - # ) output_str = self.tokenizer.decode(output_tokens[0]) output_content += output_str print(output_str, end="", flush=True) diff --git a/scripts/jiuge_override_smoke.py b/scripts/jiuge_override_smoke.py new file mode 100644 index 00000000..04999f36 --- /dev/null +++ b/scripts/jiuge_override_smoke.py @@ -0,0 +1,127 @@ +import argparse +from ctypes import POINTER, c_float, c_int, c_uint + +import torch + +from libinfinicore_infer import DeviceType, JiugeModel + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--model-dir", required=True) + ap.add_argument("--max-tokens", type=int, default=512) + args = ap.parse_args() + + # Load model via existing helper (loads weights + creates C++ model instance) + from jiuge import JiugeForCauslLM + + llm = JiugeForCauslLM( + args.model_dir, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=args.max_tokens + ) + + model: JiugeModel = llm.jiuge_model + handle = llm.model_instance + meta = llm.meta + + # Deterministic greedy sampling + temperature = (c_float * 1)(1.0) + topk = (c_uint * 1)(1) + topp = (c_float * 1)(1.0) + + # One request with a short prompt + tokens_t = torch.tensor( + [151644, 8948, 198, 2610, 525, 264, 10950, 17847, 13, 151645, 198], + dtype=torch.int32, + ) + tokens = tokens_t.numpy().astype("uint32") + ntok = int(tokens_t.numel()) + tokens_c = (c_uint * ntok)(*tokens.tolist()) + + req_lens = (c_uint * 1)(ntok) + req_pos = (c_uint * 1)(0) + + dev_ids = (c_int * 1)(0) + + # Create KV caches explicitly + kv0 = model.create_kv_cache( + meta.nlayer, + meta.dctx, + meta.nkvh, + meta.dh, + meta.dh, + meta.dt_logits, + DeviceType.DEVICE_TYPE_CPU, + dev_ids, + 1, + ) + kv1 = model.create_kv_cache( + meta.nlayer, + meta.dctx, + meta.nkvh, + meta.dh, + meta.dh, + meta.dt_logits, + DeviceType.DEVICE_TYPE_CPU, + dev_ids, + 1, + ) + + from libinfinicore_infer import KVCacheCStruct + + kv_caches0 = (POINTER(KVCacheCStruct) * 1)(kv0) + kv_caches1 = (POINTER(KVCacheCStruct) * 1)(kv1) + + out0 = (c_uint * 1)() + model.infer_batch( + handle, + tokens_c, + ntok, + req_lens, + 1, + req_pos, + kv_caches0, + temperature, + topk, + topp, + out0, + ) + + # Build overrides equal to original embeddings for a few positions. 
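+    # The buffer handed to C below must be contiguous CPU memory with the same
+    # dtype as the embedding table (dt_logits); only data_ptr() crosses the
+    # boundary, so the tensor must stay alive until the call returns.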
+ emb = llm.weights.input_embd_tensor # [dvoc, d], dtype == dt_logits + d = int(meta.d) + override_pos_list = [0, 3, ntok - 1] + override_pos = (c_uint * len(override_pos_list))(*override_pos_list) + + override_embeds = torch.empty((len(override_pos_list), d), dtype=emb.dtype) + for j, p in enumerate(override_pos_list): + override_embeds[j].copy_(emb[int(tokens_t[p].item())]) + + out1 = (c_uint * 1)() + model.infer_batch_with_overrides( + handle, + tokens_c, + ntok, + req_lens, + 1, + req_pos, + kv_caches1, + len(override_pos_list), + override_pos, + override_embeds.data_ptr(), + temperature, + topk, + topp, + out1, + ) + + print("jiuge override smoke:") + print(" out_no_override:", int(out0[0])) + print(" out_with_override:", int(out1[0])) + + model.drop_kv_cache(kv0) + model.drop_kv_cache(kv1) + model.destroy_model(handle) + + +if __name__ == "__main__": + main() diff --git a/scripts/launch_server.py b/scripts/launch_server.py index 350e811e..3fb3ee6b 100644 --- a/scripts/launch_server.py +++ b/scripts/launch_server.py @@ -226,11 +226,8 @@ async def chat_stream(id_, request_data, request: Request): break token = await infer_task.output_queue.async_q.get() - content = ( - request.app.state.model.tokenizer._tokenizer.id_to_token(token) - .replace("▁", " ") - .replace("<0x0A>", "\n") - ) + content = request.app.state.model.tokenizer.decode(token) + chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False) yield f"data: {chunk}\n\n" @@ -255,11 +252,7 @@ async def chat(id_, request_data, request: Request): break token = await infer_task.output_queue.async_q.get() - content = ( - request.app.state.model.tokenizer._tokenizer.id_to_token(token) - .replace("▁", " ") - .replace("<0x0A>", "\n") - ) + content = request.app.state.model.tokenizer.decode(token) output.append(content) output_text = "".join(output).strip() diff --git a/scripts/libinfinicore_infer/__init__.py b/scripts/libinfinicore_infer/__init__.py index 8fc5f4db..46e588b4 100644 --- a/scripts/libinfinicore_infer/__init__.py +++ b/scripts/libinfinicore_infer/__init__.py @@ -1,4 +1,4 @@ -from .base import DataType, DeviceType, KVCacheCStruct +from .base import DataType, DeviceType, KVCacheCStruct, KVCompressionConfigCStruct from .jiuge import JiugeModel, JiugeMetaCStruct, JiugeWeightsCStruct from .jiuge_awq import JiugeAWQModel, JiugeAWQMetaCStruct, ModelWeightsCStruct from .deepseek_v3 import ( @@ -8,11 +8,30 @@ DeepSeekV3WeightLoaderCStruct, DeepSeekV3CacheCStruct, ) +from .llava import ( + LlavaModel, + LlavaMetaCStruct, + LlavaWeightsCStruct, + LlavaKVCacheCStruct, + LlavaVisionMetaCStruct, + LlavaLanguageMetaCStruct, + LlavaProjectorMetaCStruct, +) +from .minicpmv import ( + MiniCPMVModel, + MiniCPMVMetaCStruct, + MiniCPMVWeightsCStruct, + MiniCPMVVisionMetaCStruct, + MiniCPMVResamplerMetaCStruct, + MiniCPMVLanguageMetaCStruct, + MiniCPMVSiglipLayerWeightsCStruct, +) __all__ = [ "DataType", "DeviceType", "KVCacheCStruct", + "KVCompressionConfigCStruct", "JiugeModel", "JiugeMetaCStruct", "JiugeWeightsCStruct", @@ -23,5 +42,20 @@ "DeepSeekV3MetaCStruct", "DeepSeekV3WeightsCStruct", "DeepSeekV3WeightLoaderCStruct", + "DeepSeekV3CacheCStruct", + "LlavaModel", + "LlavaMetaCStruct", + "LlavaWeightsCStruct", + "LlavaKVCacheCStruct", + "LlavaVisionMetaCStruct", + "LlavaLanguageMetaCStruct", + "LlavaProjectorMetaCStruct", + "MiniCPMVModel", + "MiniCPMVMetaCStruct", + "MiniCPMVWeightsCStruct", + "MiniCPMVVisionMetaCStruct", + "MiniCPMVResamplerMetaCStruct", + "MiniCPMVLanguageMetaCStruct", + "MiniCPMVSiglipLayerWeightsCStruct", 
"ModelRegister", ] diff --git a/scripts/libinfinicore_infer/base.py b/scripts/libinfinicore_infer/base.py index bed65b2e..c1d1cb53 100644 --- a/scripts/libinfinicore_infer/base.py +++ b/scripts/libinfinicore_infer/base.py @@ -42,6 +42,16 @@ class KVCacheCStruct(ctypes.Structure): pass +class KVCompressionConfigCStruct(ctypes.Structure): + _fields_ = [ + ("enable", c_uint), + ("compression_factor", c_uint), + ("min_seq_len", c_uint), + ("image_kv_len", c_uint), + ("weight_path", c_char_p), + ] + + # Model registration system _model_registry = [] diff --git a/scripts/libinfinicore_infer/jiuge.py b/scripts/libinfinicore_infer/jiuge.py index 42b3082b..e9d90e6a 100644 --- a/scripts/libinfinicore_infer/jiuge.py +++ b/scripts/libinfinicore_infer/jiuge.py @@ -1,4 +1,4 @@ -from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model +from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, KVCompressionConfigCStruct, register_model from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, POINTER, Structure, byref @@ -31,6 +31,8 @@ class JiugeWeightsCStruct(Structure): ("attn_norm", POINTER(c_void_p)), ("attn_qkv", POINTER(c_void_p)), ("attn_qkv_b", POINTER(c_void_p)), + ("attn_q_norm", POINTER(c_void_p)), + ("attn_k_norm", POINTER(c_void_p)), ("attn_o", POINTER(c_void_p)), ("ffn_norm", POINTER(c_void_p)), ("ffn_gate_up", POINTER(c_void_p)), @@ -72,6 +74,13 @@ def register_lib(cls, lib): lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)] + lib.compressKVCacheInplace.argtypes = [ + POINTER(KVCacheCStruct), + c_uint, + POINTER(KVCompressionConfigCStruct), + ] + lib.compressKVCacheInplace.restype = c_uint + lib.inferBatchJiuge.argtypes = [ POINTER(JiugeModelCStruct), POINTER(c_uint), @@ -86,6 +95,52 @@ def register_lib(cls, lib): POINTER(c_uint), ] + lib.inferBatchJiugeWithLogits.argtypes = [ + POINTER(JiugeModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + c_void_p, # logits + ] + + lib.inferBatchJiugeEx.argtypes = [ + POINTER(JiugeModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), # req_pos (RoPE) + POINTER(c_uint), # kv_pos (cache) + POINTER(POINTER(KVCacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + lib.inferBatchJiugeExWithLogits.argtypes = [ + POINTER(JiugeModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), # req_pos (RoPE) + POINTER(c_uint), # kv_pos (cache) + POINTER(POINTER(KVCacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + c_void_p, # logits + ] + lib.forwardBatchJiuge.argtypes = [ POINTER(JiugeModelCStruct), POINTER(c_uint), @@ -97,6 +152,119 @@ def register_lib(cls, lib): c_void_p, ] + lib.forwardBatchJiugeEx.argtypes = [ + POINTER(JiugeModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), # req_pos (RoPE) + POINTER(c_uint), # kv_pos (cache) + POINTER(POINTER(KVCacheCStruct)), + c_void_p, + ] + + lib.inferBatchJiugeWithOverrides.argtypes = [ + POINTER(JiugeModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + c_uint, # n_override + POINTER(c_uint), # override_pos + c_void_p, # override_embeds + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + 
lib.inferBatchJiugeWithOverridesEx.argtypes = [ + POINTER(JiugeModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), # req_pos (RoPE) + POINTER(c_uint), # kv_pos (cache) + POINTER(POINTER(KVCacheCStruct)), + c_uint, # n_override + POINTER(c_uint), # override_pos + c_void_p, # override_embeds + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + lib.inferBatchJiugeWithOverridesWithLogits.argtypes = [ + POINTER(JiugeModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + c_uint, # n_override + POINTER(c_uint), # override_pos + c_void_p, # override_embeds + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + c_void_p, # logits + ] + + # lib.inferBatchJiugeWithOverridesExWithLogits.argtypes = [ + # POINTER(JiugeModelCStruct), + # POINTER(c_uint), + # c_uint, + # POINTER(c_uint), + # c_uint, + # POINTER(c_uint), # req_pos (RoPE) + # POINTER(c_uint), # kv_pos (cache) + # POINTER(POINTER(KVCacheCStruct)), + # c_uint, # n_override + # POINTER(c_uint), # override_pos + # c_void_p, # override_embeds + # POINTER(c_float), + # POINTER(c_uint), + # POINTER(c_float), + # POINTER(c_uint), + # c_void_p, # logits + # ] + + lib.forwardBatchJiugeWithOverrides.argtypes = [ + POINTER(JiugeModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + c_uint, # n_override + POINTER(c_uint), # override_pos + c_void_p, # override_embeds + c_void_p, # logits + ] + + lib.forwardBatchJiugeWithOverridesEx.argtypes = [ + POINTER(JiugeModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), # req_pos (RoPE) + POINTER(c_uint), # kv_pos (cache) + POINTER(POINTER(KVCacheCStruct)), + c_uint, # n_override + POINTER(c_uint), # override_pos + c_void_p, # override_embeds + c_void_p, # logits + ] + def create_model(self, meta, weights, device_type, ndev, dev_ids): return self.lib.createJiugeModel(meta, weights, device_type, ndev, dev_ids) @@ -106,6 +274,7 @@ def destroy_model(self, model): def create_kv_cache( self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev ): + #import pdb;pdb.set_trace() return self.lib.createKVCache( nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev ) @@ -113,6 +282,9 @@ def create_kv_cache( def drop_kv_cache(self, kv_cache): self.lib.dropKVCache(kv_cache) + def compress_kv_cache_inplace(self, kv_cache, seq_len, cfg: KVCompressionConfigCStruct): + return self.lib.compressKVCacheInplace(kv_cache, seq_len, byref(cfg)) + def infer_batch( self, model, @@ -141,9 +313,310 @@ def infer_batch( output, ) + def infer_batch_with_logits( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + logits, + ): + self.lib.inferBatchJiugeWithLogits( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + logits, + ) + + def infer_batch_ex( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_pos, + kv_caches, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchJiugeEx( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_pos, + kv_caches, + temperature, + topk, + topp, + output, + ) + + def infer_batch_ex_with_logits( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_pos, + kv_caches, + temperature, + topk, + topp, + output, + logits, + ): + 
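+        # Ex variant: separate req_pos (RoPE) and kv_pos (cache) positions plus
+        # a raw `logits` pointer, mirroring the inferBatchJiugeExWithLogits
+        # argtypes registered above.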
self.lib.inferBatchJiugeExWithLogits( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_pos, + kv_caches, + temperature, + topk, + topp, + output, + logits, + ) + def forward_batch( self, model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits ): self.lib.forwardBatchJiuge( model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits ) + + def forward_batch_ex( + self, model, tokens, ntok, req_lens, nreq, req_pos, kv_pos, kv_caches, logits + ): + self.lib.forwardBatchJiugeEx( + model, tokens, ntok, req_lens, nreq, req_pos, kv_pos, kv_caches, logits + ) + + def infer_batch_with_overrides( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchJiugeWithOverrides( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + temperature, + topk, + topp, + output, + ) + + def infer_batch_with_overrides_ex( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchJiugeWithOverridesEx( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + temperature, + topk, + topp, + output, + ) + + def infer_batch_with_overrides_with_logits( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + temperature, + topk, + topp, + output, + logits, + ): + self.lib.inferBatchJiugeWithOverridesWithLogits( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + temperature, + topk, + topp, + output, + logits, + ) + + # def infer_batch_with_overrides_ex_with_logits( + # self, + # model, + # tokens, + # ntok, + # req_lens, + # nreq, + # req_pos, + # kv_pos, + # kv_caches, + # n_override, + # override_pos, + # override_embeds, + # temperature, + # topk, + # topp, + # output, + # logits, + # ): + # self.lib.inferBatchJiugeWithOverridesExWithLogits( + # model, + # tokens, + # ntok, + # req_lens, + # nreq, + # req_pos, + # kv_pos, + # kv_caches, + # n_override, + # override_pos, + # override_embeds, + # temperature, + # topk, + # topp, + # output, + # logits, + # ) + + def forward_batch_with_overrides( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + logits, + ): + self.lib.forwardBatchJiugeWithOverrides( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + logits, + ) + + def forward_batch_with_overrides_ex( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + logits, + ): + self.lib.forwardBatchJiugeWithOverridesEx( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_pos, + kv_caches, + n_override, + override_pos, + override_embeds, + logits, + ) diff --git a/scripts/libinfinicore_infer/llava.py b/scripts/libinfinicore_infer/llava.py new file mode 100644 index 00000000..0c0fb216 --- /dev/null +++ b/scripts/libinfinicore_infer/llava.py @@ -0,0 +1,252 @@ +from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model +from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, 
POINTER, Structure, byref + + +class LlavaVisionMetaCStruct(Structure): + _fields_ = [ + ("image_size", c_size_t), + ("patch_size", c_size_t), + ("num_patches", c_size_t), + ("vision_embed_dim", c_size_t), + ("vision_num_layers", c_size_t), + ("vision_num_heads", c_size_t), + ("vision_intermediate_size", c_size_t), + ("vision_epsilon", c_float), + ] + + +class LlavaLanguageMetaCStruct(Structure): + _fields_ = [ + ("dt_logits", DataType), + ("nlayer", c_size_t), + ("d", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("dctx", c_size_t), + ("dvoc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_uint), + ] + + +class LlavaProjectorMetaCStruct(Structure): + _fields_ = [ + ("vision_embed_dim", c_size_t), + ("text_embed_dim", c_size_t), + ("projector_hidden_size", c_size_t), + ] + + +class LlavaMetaCStruct(Structure): + _fields_ = [ + ("vision_meta", LlavaVisionMetaCStruct), + ("language_meta", LlavaLanguageMetaCStruct), + ("projector_meta", LlavaProjectorMetaCStruct), + ] + + +class LlavaWeightsCStruct(Structure): + _fields_ = [ + # Vision Encoder Weights + ("vision_nlayer", c_void_p), + ("vision_patch_embed_weight", c_void_p), + ("vision_class_token", c_void_p), + ("vision_position_embedding", c_void_p), + ("vision_encoder_weights", POINTER(c_void_p)), # 好像没用 + ("vision_pre_layernorm_weight", c_void_p), + ("vision_pre_layernorm_bias", c_void_p), + ("vision_post_layernorm_weight", c_void_p), + ("vision_post_layernorm_bias", c_void_p), + + ("vision_q_weights", POINTER(c_void_p)), + ("vision_q_biases", POINTER(c_void_p)), + ("vision_k_weights", POINTER(c_void_p)), + ("vision_k_biases", POINTER(c_void_p)), + ("vision_v_weights", POINTER(c_void_p)), + ("vision_v_biases", POINTER(c_void_p)), + + ("vision_in_layer_pre_norm_weights", POINTER(c_void_p)), + ("vision_in_layer_pre_norm_biases", POINTER(c_void_p)), + + ("vision_proj_weight", POINTER(c_void_p)), + ("vision_proj_bias", POINTER(c_void_p)), + + ("vision_in_layer_post_norm_weight", POINTER(c_void_p)), + ("vision_post_norm_bias", POINTER(c_void_p)), + + ("vision_mlp_fc1_weight", POINTER(c_void_p)), + ("vision_mlp_fc1_bias", POINTER(c_void_p)), + + ("vision_mlp_fc2_weight", POINTER(c_void_p)), + ("vision_mlp_fc2_bias", POINTER(c_void_p)), + + + + # MultiModal Projector Weights + ("projector_weight_1", c_void_p), + ("projector_bias_1", c_void_p), + ("projector_weight_2", c_void_p), + ("projector_bias_2", c_void_p), + + # Language Model Weights (reuse Jiuge structure) + ("nlayer", c_size_t), + ("dt_norm", DataType), + ("dt_mat", DataType), + ("transpose_linear_weights", c_int), + ("input_embd", c_void_p), + ("output_norm", c_void_p), + ("output_embd", c_void_p), + ("attn_norm", POINTER(c_void_p)), + ("attn_qkv", POINTER(c_void_p)), + ("attn_qkv_b", POINTER(c_void_p)), + ("attn_q_norm", POINTER(c_void_p)), + ("attn_k_norm", POINTER(c_void_p)), + ("attn_o", POINTER(c_void_p)), + ("ffn_norm", POINTER(c_void_p)), + ("ffn_gate_up", POINTER(c_void_p)), + ("ffn_down", POINTER(c_void_p)), + ] + + +class LlavaKVCacheCStruct(Structure): + _fields_ = [ + ("past_key", c_void_p), + ("past_value", c_void_p), + ("past_seq_len", c_size_t), + ("max_seq_len", c_size_t), + ] + +class LlavaModelCStruct(Structure): + pass + +@register_model +class LlavaModel(BaseModel): + def __init__(self): + super().__init__() + + @classmethod + def register_lib(cls, lib): + # Setup function signatures + lib.createLlavaModel.restype = POINTER(LlavaModelCStruct) + lib.createLlavaModel.argtypes = [ + 
POINTER(LlavaMetaCStruct), + POINTER(LlavaWeightsCStruct), + DeviceType, # device + c_int, # ndev + POINTER(c_int), # dev_ids + ] + + lib.destroyLlavaModel.argtypes = [POINTER(LlavaModelCStruct)] + + lib.createKVCache.argtypes = [ + c_size_t, # nlayer + c_size_t, # max_len + c_size_t, # nkvh + c_size_t, # dk + c_size_t, # dv + DataType, # dtype + DeviceType, # device + POINTER(c_int), # dev_ids + c_size_t, # ndev + ] + lib.createKVCache.restype = POINTER(KVCacheCStruct) + lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)] + + # 新增:LLaVA Vision Encoding (用于batch_infer_vision) + lib.inferBatchLlavaVison.argtypes = [ + POINTER(LlavaModelCStruct), # model + c_void_p, # image_data + c_void_p, # output + ] + lib.inferBatchLlavaVison.restype = None + + lib.inferBatchLlavaVisionStage.argtypes = [ + POINTER(LlavaModelCStruct), # model + c_void_p, # image_data + c_uint, # stage + c_void_p, # output + ] + lib.inferBatchLlavaVisionStage.restype = None + + # lib.encodeVision.argtypes = [ + # POINTER(c_void_p), # model + # c_void_p, # image_tensor + # c_void_p, # output + # ] + # lib.encodeVision.restype = None + + # lib.projectMultiModal.argtypes = [ + # POINTER(c_void_p), # model + # c_void_p, # vision_features + # c_void_p, # output + # ] + # lib.projectMultiModal.restype = None + + # lib.inferBatchLlavaLanguage.argtypes = [ + # POINTER(c_void_p), # model + # POINTER(c_uint), # tokens + # c_size_t, # ntok + # POINTER(c_size_t), # req_lens + # c_size_t, # nreq + # POINTER(c_size_t), # req_pos + # POINTER(LlavaKVCacheCStruct), # kv_caches + # POINTER(c_float), # temperature + # POINTER(c_float), # topk + # POINTER(c_float), # topp + # POINTER(c_uint), # output + # ] + # lib.inferBatchLlavaLanguage.restype = None + + def create_model(self, meta, weights, device, ndev, dev_ids): + return self.lib.createLlavaModel( + meta, + weights, + device, + ndev, + dev_ids, + ) + + def destroy_model(self, model): + self.lib.destroyLlavaModel(model) + + def create_kv_cache( + self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ): + return self.lib.createKVCache( + nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ) + + def drop_kv_cache(self, kv_cache): + self.lib.dropKVCache(kv_cache) + + def infer_batch_vision(self, model, image_data, output): + """LLaVA Vision Encoding - 对应Python中的infer_batch_vision""" + self.lib.inferBatchLlavaVison(model, image_data, output) + + def infer_batch_vision_stage(self, model, image_data, stage, output): + self.lib.inferBatchLlavaVisionStage(model, image_data, stage, output) + + def encode_vision(self, model, image_tensor, output): + self.lib.encodeVision(model, image_tensor, output) + + def project_multimodal(self, model, vision_features, output): + self.lib.projectMultiModal(model, vision_features, output) + + def infer_batch_language(self, model, tokens, ntok, req_lens, nreq, req_pos, + kv_caches, temperature, topk, topp, output): + self.lib.inferBatchLlavaLanguage( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ) diff --git a/scripts/libinfinicore_infer/minicpmv.py b/scripts/libinfinicore_infer/minicpmv.py new file mode 100644 index 00000000..591ce630 --- /dev/null +++ b/scripts/libinfinicore_infer/minicpmv.py @@ -0,0 +1,224 @@ +from .base import BaseModel, DataType, DeviceType, register_model +from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, POINTER, Structure + + +class MiniCPMVVisionMetaCStruct(Structure): + _fields_ = [ + ("patch_size", c_size_t), + ("vision_embed_dim", 
c_size_t), + ("vision_num_layers", c_size_t), + ("vision_num_heads", c_size_t), + ("vision_intermediate_size", c_size_t), + ("vision_layer_norm_eps", c_float), + ("vision_image_size", c_size_t), + ("vision_num_positions", c_size_t), + ] + + +class MiniCPMVResamplerMetaCStruct(Structure): + _fields_ = [ + ("num_queries", c_size_t), + ("embed_dim", c_size_t), + ("num_heads", c_size_t), + ("kv_dim", c_size_t), + ("layer_norm_eps", c_float), + ("max_patches_h", c_size_t), + ("max_patches_w", c_size_t), + ] + + +class MiniCPMVLanguageMetaCStruct(Structure): + _fields_ = [ + ("dt_logits", DataType), + ("nlayer", c_size_t), + ("d", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("dctx", c_size_t), + ("dvoc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_uint), + ] + + +class MiniCPMVMetaCStruct(Structure): + _fields_ = [ + ("vision_meta", MiniCPMVVisionMetaCStruct), + ("resampler_meta", MiniCPMVResamplerMetaCStruct), + ("language_meta", MiniCPMVLanguageMetaCStruct), + ] + + +class MiniCPMVSiglipLayerWeightsCStruct(Structure): + _fields_ = [ + ("layer_norm1_weight", c_void_p), + ("layer_norm1_bias", c_void_p), + ("layer_norm2_weight", c_void_p), + ("layer_norm2_bias", c_void_p), + ("q_weight", c_void_p), + ("q_bias", c_void_p), + ("k_weight", c_void_p), + ("k_bias", c_void_p), + ("v_weight", c_void_p), + ("v_bias", c_void_p), + ("out_weight", c_void_p), + ("out_bias", c_void_p), + ("fc1_weight", c_void_p), + ("fc1_bias", c_void_p), + ("fc2_weight", c_void_p), + ("fc2_bias", c_void_p), + ] + + +class MiniCPMVWeightsCStruct(Structure): + _fields_ = [ + # Vision + ("vpm_patch_embedding_weight", c_void_p), + ("vpm_patch_embedding_bias", c_void_p), + ("vpm_position_embedding", c_void_p), + ("vpm_layers", POINTER(MiniCPMVSiglipLayerWeightsCStruct)), + ("vpm_post_layernorm_weight", c_void_p), + ("vpm_post_layernorm_bias", c_void_p), + # Resampler + ("resampler_query", c_void_p), + ("resampler_kv_proj_weight", c_void_p), + ("resampler_attn_in_proj_weight", c_void_p), + ("resampler_attn_in_proj_bias", c_void_p), + ("resampler_attn_out_proj_weight", c_void_p), + ("resampler_attn_out_proj_bias", c_void_p), + ("resampler_ln_q_weight", c_void_p), + ("resampler_ln_q_bias", c_void_p), + ("resampler_ln_kv_weight", c_void_p), + ("resampler_ln_kv_bias", c_void_p), + ("resampler_ln_post_weight", c_void_p), + ("resampler_ln_post_bias", c_void_p), + ("resampler_proj", c_void_p), + # Language (reuse Jiuge layout) + ("nlayer", c_size_t), + ("dt_norm", DataType), + ("dt_mat", DataType), + ("transpose_linear_weights", c_int), + ("input_embd", c_void_p), + ("output_norm", c_void_p), + ("output_embd", c_void_p), + ("attn_norm", POINTER(c_void_p)), + ("attn_qkv", POINTER(c_void_p)), + ("attn_qkv_b", POINTER(c_void_p)), + ("attn_q_norm", POINTER(c_void_p)), + ("attn_k_norm", POINTER(c_void_p)), + ("attn_o", POINTER(c_void_p)), + ("ffn_norm", POINTER(c_void_p)), + ("ffn_gate_up", POINTER(c_void_p)), + ("ffn_down", POINTER(c_void_p)), + ] + + +class MiniCPMVModelCStruct(Structure): + pass + + +@register_model +class MiniCPMVModel(BaseModel): + @classmethod + def register_lib(cls, lib): + lib.createMiniCPMVModel.restype = POINTER(MiniCPMVModelCStruct) + lib.createMiniCPMVModel.argtypes = [ + POINTER(MiniCPMVMetaCStruct), + POINTER(MiniCPMVWeightsCStruct), + DeviceType, + c_int, + POINTER(c_int), + ] + + lib.destroyMiniCPMVModel.argtypes = [POINTER(MiniCPMVModelCStruct)] + + lib.inferMiniCPMVResampler.argtypes = [ + POINTER(MiniCPMVModelCStruct), + 
c_void_p, # x + c_size_t, # seq_len + c_uint, # tgt_h + c_uint, # tgt_w + c_void_p, # output + ] + lib.inferMiniCPMVResampler.restype = None + + lib.inferMiniCPMVSiglipEmbeddings.argtypes = [ + POINTER(MiniCPMVModelCStruct), + c_void_p, # pixel_values + c_size_t, # seq_len + c_uint, # tgt_h + c_uint, # tgt_w + c_void_p, # output + ] + lib.inferMiniCPMVSiglipEmbeddings.restype = None + + lib.inferMiniCPMVSiglipLayer0.argtypes = [ + POINTER(MiniCPMVModelCStruct), + c_void_p, # hidden_states + c_size_t, # seq_len + c_void_p, # output + ] + lib.inferMiniCPMVSiglipLayer0.restype = None + + lib.inferMiniCPMVSiglipLayer.argtypes = [ + POINTER(MiniCPMVModelCStruct), + c_uint, # layer_idx + c_void_p, # hidden_states + c_size_t, # seq_len + c_void_p, # output + ] + lib.inferMiniCPMVSiglipLayer.restype = None + + lib.inferMiniCPMVSiglipEncoder.argtypes = [ + POINTER(MiniCPMVModelCStruct), + c_uint, # num_layers + c_void_p, # hidden_states + c_size_t, # seq_len + c_void_p, # output + ] + lib.inferMiniCPMVSiglipEncoder.restype = None + + lib.inferMiniCPMVVisionResampler.argtypes = [ + POINTER(MiniCPMVModelCStruct), + c_void_p, # pixel_values + c_size_t, # seq_len + c_uint, # tgt_h + c_uint, # tgt_w + c_void_p, # output + ] + lib.inferMiniCPMVVisionResampler.restype = None + + def create_model(self, meta, weights, device_type, ndev, dev_ids): + return self.lib.createMiniCPMVModel(meta, weights, device_type, ndev, dev_ids) + + def destroy_model(self, model): + self.lib.destroyMiniCPMVModel(model) + + def infer_resampler(self, model, x, seq_len, tgt_h, tgt_w, output): + self.lib.inferMiniCPMVResampler(model, x, seq_len, tgt_h, tgt_w, output) + + def infer_siglip_embeddings(self, model, pixel_values, seq_len, tgt_h, tgt_w, output): + self.lib.inferMiniCPMVSiglipEmbeddings( + model, pixel_values, seq_len, tgt_h, tgt_w, output + ) + + def infer_siglip_layer0(self, model, hidden_states, seq_len, output): + self.lib.inferMiniCPMVSiglipLayer0(model, hidden_states, seq_len, output) + + def infer_siglip_layer(self, model, layer_idx, hidden_states, seq_len, output): + self.lib.inferMiniCPMVSiglipLayer( + model, layer_idx, hidden_states, seq_len, output + ) + + def infer_siglip_encoder(self, model, num_layers, hidden_states, seq_len, output): + self.lib.inferMiniCPMVSiglipEncoder( + model, num_layers, hidden_states, seq_len, output + ) + + def infer_vision_resampler(self, model, pixel_values, seq_len, tgt_h, tgt_w, output): + self.lib.inferMiniCPMVVisionResampler( + model, pixel_values, seq_len, tgt_h, tgt_w, output + ) diff --git a/scripts/llava.py b/scripts/llava.py new file mode 100644 index 00000000..e86e3cb6 --- /dev/null +++ b/scripts/llava.py @@ -0,0 +1,1769 @@ +from typing import List, Optional, Sequence +import math +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import torch +import transformers +from transformers import AutoProcessor +import ctypes +from ctypes import c_int, c_void_p, c_uint, byref, POINTER, c_float +import numpy as np +# from PIL import Image +# import numpy as np + + + +from libinfinicore_infer import ( + JiugeModel, + JiugeMetaCStruct, + JiugeWeightsCStruct, + KVCacheCStruct, + KVCompressionConfigCStruct, + LlavaMetaCStruct, + LlavaVisionMetaCStruct, + LlavaLanguageMetaCStruct, + LlavaProjectorMetaCStruct, + LlavaWeightsCStruct, + LlavaModel, + DataType, + DeviceType, +) +from infer_task import InferTask, KVCache + + + +class LlamaWeightsNaming: + def input_embd(self): + return "language_model.model.embed_tokens.weight" + + def 
output_norm(self): + return "language_model.model.norm.weight" + + def output_embd(self): + return "language_model.lm_head.weight" + + def attn_norm(self, i): + return f"language_model.model.layers.{i}.input_layernorm.weight" + + def attn_q(self, i): + return f"language_model.model.layers.{i}.self_attn.q_proj.weight" + + def attn_k(self, i): + return f"language_model.model.layers.{i}.self_attn.k_proj.weight" + + def attn_v(self, i): + return f"language_model.model.layers.{i}.self_attn.v_proj.weight" + + def attn_o(self, i): + return f"language_model.model.layers.{i}.self_attn.o_proj.weight" + + def attn_q_b(self, i): + return f"language_model.model.layers.{i}.self_attn.q_proj.bias" + + def attn_k_b(self, i): + return f"language_model.model.layers.{i}.self_attn.k_proj.bias" + + def attn_v_b(self, i): + return f"model.layers.{i}.self_attn.v_proj.bias" + + def attn_q_norm(self, i): + return f"language_model.model.layers.{i}.self_attn.q_norm.weight" + + def attn_k_norm(self, i): + return f"language_model.model.layers.{i}.self_attn.k_norm.weight" + + def ffn_norm(self, i): + return f"language_model.model.layers.{i}.post_attention_layernorm.weight" + + def gate(self, i): + return f"language_model.model.layers.{i}.mlp.gate_proj.weight" + + def up(self, i): + return f"language_model.model.layers.{i}.mlp.up_proj.weight" + + def down(self, i): + return f"language_model.model.layers.{i}.mlp.down_proj.weight" + + def match(state_dict): + return ( + "model.norm.weight" in state_dict + and "model.layers.0.self_attn.q_proj.weight" in state_dict + ) + + + +class JiugeMetaFromLlama(JiugeMetaCStruct): + def __init__(self, config, dtype=torch.float16, max_tokens=None): + if dtype == torch.float16: + dt_ = DataType.INFINI_DTYPE_F16 + elif dtype == torch.float32: + dt_ = DataType.INFINI_DTYPE_F32 + elif dtype == torch.bfloat16: + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + self.scale_input = 1.0 + self.scale_output = 1.0 + self.scale_o = 1.0 + self.scale_down = 1.0 + if ( + config["model_type"] in ["fm9g", "minicpm"] + and "scale_emb" in config + and "scale_depth" in config + and "dim_model_base" in config + ): + self.scale_input = config["scale_emb"] + self.scale_output = config["hidden_size"] // config["dim_model_base"] + self.scale_o = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + self.scale_down = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + + super().__init__( + dt_logits=dt_, + # nlayer=config["num_hidden_layers"], + nlayer=32, # vicuna-7b-v1.5 config + d=4096, + nh=32, + nkvh=32, + dh=(4096 // 32), + di=11008, + dctx=( + 4096 if max_tokens is None else max_tokens + ), + dvoc=32064, + epsilon=1e-05, + theta=(config["rope_theta"] if "rope_theta" in config else 10000.0), + end_token=2, + ) + self.torch_dtype_logits = dtype + + + + +class JiugeWeightsImpl(JiugeWeightsCStruct): + def __init__( + self, + meta, + naming, + state_dict, + torch_dt_mat=torch.float16, + torch_dt_norm=torch.float32, + ndev=1, + transpose_weight=True, + ): + nlayer = meta.nlayer + nh = meta.nh + nkvh = meta.nkvh + dh = meta.dh + d = meta.d + di = meta.di + scale_input = meta.scale_input + scale_output = meta.scale_output + scale_o = meta.scale_o + scale_down = meta.scale_down + assert nh % nkvh == 0 + assert nh % ndev == 0 + assert nkvh % ndev == 0 + assert di % ndev == 0 + torch_dt_logits = meta.torch_dtype_logits + if torch_dt_mat == torch.float16: + self.dt_mat = DataType.INFINI_DTYPE_F16 + elif torch_dt_mat == torch.float32: + self.dt_mat = 
DataType.INFINI_DTYPE_F32 + elif torch_dt_mat == torch.bfloat16: + self.dt_mat = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported proj weight data type") + if torch_dt_norm == torch.float16: + self.dt_norm = DataType.INFINI_DTYPE_F16 + elif torch_dt_norm == torch.float32: + self.dt_norm = DataType.INFINI_DTYPE_F32 + elif torch_dt_norm == torch.bfloat16: + self.dt_norm = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported norm weight data type") + + input_embd_naming = ( + naming.input_embd() + if naming.input_embd() in state_dict + else naming.output_embd() + ) + output_embd_naming = ( + naming.output_embd() + if naming.output_embd() in state_dict + else naming.input_embd() + ) + self.transpose_linear_weights = 1 if transpose_weight else 0 + self.nlayer = nlayer + self.input_embd_tensor = ( + state_dict[input_embd_naming].to(torch_dt_logits) * scale_input + ) + self.input_embd = self.input_embd_tensor.data_ptr() + self.output_norm_tensor = ( + state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + ) + self.output_norm = self.output_norm_tensor.data_ptr() + self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) + if not transpose_weight: + self.output_embd_tensor = self.output_embd_tensor.transpose( + 0, 1 + ).contiguous() + self.output_embd = self.output_embd_tensor.data_ptr() + + self.attn_norm_tensors = [ + state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + ] + self.attn_norm_ptrs = [ + self.attn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs) + + def qkv_slices(_i): + _Q = ( + state_dict[naming.attn_q(_i)] + .reshape([nh, 2, dh // 2, d]) + .transpose(1, 2) + ) + _K = ( + state_dict[naming.attn_k(_i)] + .reshape([nkvh, 2, dh // 2, d]) + .transpose(1, 2) + ) + _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d]) + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :]) + _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :]) + _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) + return _result + + self.qkv_tensor = [ + torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer) + ] + if not transpose_weight: + for i in range(nlayer): + self.qkv_tensor[i] = ( + self.qkv_tensor[i] + .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d) + .transpose(1, 2) + .contiguous() + ) + self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs) + + def qkv_b_slices(_i): + _QB = ( + state_dict[naming.attn_q_b(_i)] + .reshape([nh, 2, dh // 2]) + .transpose(1, 2) + ) + _KB = ( + state_dict[naming.attn_k_b(_i)] + .reshape([nkvh, 2, dh // 2]) + .transpose(1, 2) + ) + _VB = state_dict[naming.attn_v_b(_i)].reshape([nkvh, dh // 2, 2]) + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :].flatten()) + _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) + _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) + return _result + + if naming.attn_q_b(0) in state_dict: + self.qkv_b_tensors = [ + torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer) + ] + self.qkv_b_tensor_ptrs = [ + self.qkv_b_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_qkv_b = (c_void_p * nlayer)(*self.qkv_b_tensor_ptrs) + else: 
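+            # Checkpoint carries no QKV bias (plain Llama/Vicuna); assigning None
+            # to the ctypes POINTER field passes a NULL pointer to the C side.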
+ self.attn_qkv_b = None + + if naming.attn_q_norm(0) in state_dict: + self.attn_q_norm_tensors = [ + state_dict[naming.attn_q_norm(i)] + .reshape([2, dh // 2]) + .transpose(0, 1) + .contiguous() + .to(torch_dt_norm) + for i in range(nlayer) + ] + self.attn_q_norm_ptrs = [ + self.attn_q_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_q_norm = (c_void_p * nlayer)(*self.attn_q_norm_ptrs) + self.attn_k_norm_tensors = [ + state_dict[naming.attn_k_norm(i)] + .reshape([2, dh // 2]) + .transpose(0, 1) + .contiguous() + .to(torch_dt_norm) + for i in range(nlayer) + ] + self.attn_k_norm_ptrs = [ + self.attn_k_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_k_norm = (c_void_p * nlayer)(*self.attn_k_norm_ptrs) + else: + self.attn_q_norm = None + self.attn_k_norm = None + + self.attn_o_tensor = [ + ( + state_dict[naming.attn_o(i)] + .to(torch_dt_mat) + .reshape([d, ndev, nh // ndev * dh]) + .transpose(0, 1) + .contiguous() + if transpose_weight + else state_dict[naming.attn_o(i)] + .transpose(0, 1) + .to(torch_dt_mat) + .contiguous() + ) + * scale_o + for i in range(nlayer) + ] + self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs) + + self.ffn_norm_tensors = [ + state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + ] + self.ffn_norm_ptrs = [ + self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs) + + def gate_up_slices(_i): + _result = [] + _di = di // ndev + for _idev in range(ndev): + _start = _idev * _di + _end = (_idev + 1) * _di + _result.append(state_dict[naming.gate(_i)][_start:_end, :]) + _result.append(state_dict[naming.up(_i)][_start:_end, :]) + return _result + + self.gate_up_tensors = [ + torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer) + ] + if not transpose_weight: + for i in range(nlayer): + self.gate_up_tensors[i] = ( + self.gate_up_tensors[i] + .reshape(ndev, 2 * di // ndev, d) + .transpose(1, 2) + .contiguous() + ) + self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)] + self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs) + + self.ffn_down_tensor = [ + ( + state_dict[naming.down(i)] + .to(torch_dt_mat) + .reshape([d, ndev, di // ndev]) + .transpose(0, 1) + .contiguous() + if transpose_weight + else state_dict[naming.down(i)] + .transpose(0, 1) + .to(torch_dt_mat) + .contiguous() + ) + * scale_down + for i in range(nlayer) + ] + self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)] + self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs) + + + + +class LlavaWeightsNaming: + """LLaVA权重命名映射类""" + + def input_embd(self): + """输入嵌入层权重名""" + return "language_model.model.embed_tokens.weight" + + def output_norm(self): + """输出层归一化权重名""" + return "language_model.model.norm.weight" + + def output_embd(self): + """输出嵌入层权重名""" + return "language_model.lm_head.weight" + + def vision_patch_embed_weight(self): + """视觉编码器patch嵌入权重名""" + return "vision_tower.vision_model.embeddings.patch_embedding.weight" + + def vision_position_embedding(self): + """视觉编码器位置嵌入权重名""" + return "vision_tower.vision_model.embeddings.position_embedding.weight" + + # def vision_class_embedding(self): + # return "vision_tower.vision_model.embeddings.class_embedding.weight" + + def vision_class_token(self): + """视觉编码器class token权重名""" + return "vision_tower.vision_model.embeddings.class_embedding" + + def vision_post_layernorm_bias(self): + 
"""视觉编码器 post_layernorm.bias 权重名""" + return "vision_tower.vision_model.post_layernorm.bias" + + def vision_post_layernorm_weight(self): + """视觉编码器 post_layernorm.weight 权重名""" + return "vision_tower.vision_model.post_layernorm.weight" + + def vision_pre_layernorm_bias(self): + """视觉编码器 pre_layernorm.bias 权重名""" + return "vision_tower.vision_model.pre_layrnorm.bias" + + def vision_pre_layernorm_weight(self): + """视觉编码器 pre_layernorm.weight 权重名""" + return "vision_tower.vision_model.pre_layrnorm.weight" + + def vision_in_layer_pre_norm_weights(self, layer_idx): + """视觉编码器前置归一化权重名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.layer_norm1.weight" + + def vision_in_layer_pre_norm_biases(self, layer_idx): + """视觉编码器前置归一化偏置名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.layer_norm1.bias" + + def vision_q_weights(self, layer_idx): + """视觉编码器Q权重名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.q_proj.weight" + + def vision_q_biases(self, layer_idx): + """视觉编码器Q偏置名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.q_proj.bias" + + def vision_k_weights(self, layer_idx): + """视觉编码器K权重名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.k_proj.weight" + + def vision_k_biases(self, layer_idx): + """视觉编码器K偏置名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.k_proj.bias" + + def vision_v_weights(self, layer_idx): + """视觉编码器V权重名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.v_proj.weight" + + def vision_v_biases(self, layer_idx): + """视觉编码器V偏置名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.v_proj.bias" + + # def vision_qkv_weight(self, layer_idx): + # """视觉编码器QKV合并权重名(如果存在)""" + # # 某些实现可能将QKV合并为一个权重 + # return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.qkv.weight" + + # def vision_qkv_bias(self, layer_idx): + # """视觉编码器QKV合并偏置名(如果存在)""" + # return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.qkv.bias" + + def vision_proj_weight(self, layer_idx): + """视觉编码器投影权重名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.out_proj.weight" + + def vision_proj_bias(self, layer_idx): + """视觉编码器投影偏置名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.self_attn.out_proj.bias" + + def vision_in_layer_post_norm_weight(self, layer_idx): + """视觉编码器后置归一化权重名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.layer_norm2.weight" + + def vision_post_norm_bias(self, layer_idx): + """视觉编码器后置归一化偏置名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.layer_norm2.bias" + + def vision_mlp_fc1_weight(self, layer_idx): + """视觉编码器MLP第一层权重名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.mlp.fc1.weight" + + def vision_mlp_fc1_bias(self, layer_idx): + """视觉编码器MLP第一层偏置名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.mlp.fc1.bias" + + def vision_mlp_fc2_weight(self, layer_idx): + """视觉编码器MLP第二层权重名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.mlp.fc2.weight" + + def vision_mlp_fc2_bias(self, layer_idx): + """视觉编码器MLP第二层偏置名""" + return f"vision_tower.vision_model.encoder.layers.{layer_idx}.mlp.fc2.bias" + + + + + + + def vision_post_norm_final_weight(self): + """视觉编码器最终归一化权重名""" + return "vision_tower.vision_model.post_layernorm.weight" + + def vision_post_norm_final_bias(self): + """视觉编码器最终归一化偏置名""" + return "vision_tower.vision_model.post_layernorm.bias" + + 
def projector_weight_1(self): + """多模态投影器第一层权重名""" + return "multi_modal_projector.linear_1.weight" + + def projector_bias_1(self): + """多模态投影器第一层偏置名""" + return "multi_modal_projector.linear_1.bias" + + def projector_weight_2(self): + """多模态投影器第二层权重名""" + return "multi_modal_projector.linear_2.weight" + + def projector_bias_2(self): + """多模态投影器第二层偏置名""" + return "multi_modal_projector.linear_2.bias" + + def attn_norm(self, layer_idx): + """注意力归一化权重名""" + return f"language_model.model.layers.{layer_idx}.input_layernorm.weight" + + def attn_q(self, layer_idx): + """注意力Q权重名""" + return f"language_model.model.layers.{layer_idx}.self_attn.q_proj.weight" + + def attn_k(self, layer_idx): + """注意力K权重名""" + return f"language_model.model.layers.{layer_idx}.self_attn.k_proj.weight" + + def attn_v(self, layer_idx): + """注意力V权重名""" + return f"language_model.model.layers.{layer_idx}.self_attn.v_proj.weight" + + def attn_o(self, layer_idx): + """注意力O权重名""" + return f"language_model.model.layers.{layer_idx}.self_attn.o_proj.weight" + + def attn_qkv(self, layer_idx): + """注意力QKV合并权重名""" + # 对于LLaMA,通常Q、K、V是分开的,但某些实现可能合并 + return f"language_model.model.layers.{layer_idx}.self_attn.qkv.weight" # 可能不存在 + + def attn_qkv_b(self, layer_idx): + """注意力QKV合并偏置名""" + return f"language_model.model.layers.{layer_idx}.self_attn.qkv.bias" # 可能不存在 + + def attn_q_norm(self, layer_idx): + """注意力Q归一化权重名(用于某些优化)""" + return f"language_model.model.layers.{layer_idx}.self_attn.q_norm.weight" # 可能不存在 + + def attn_k_norm(self, layer_idx): + """注意力K归一化权重名(用于某些优化)""" + return f"language_model.model.layers.{layer_idx}.self_attn.k_norm.weight" # 可能不存在 + + def ffn_gate_up(self, layer_idx): + """FFN gate和up合并权重名(某些实现的优化)""" + return f"language_model.model.layers.{layer_idx}.mlp.gate_up_proj.weight" # 可能不存在 + + def ffn_norm(self, layer_idx): + """FFN归一化权重名""" + return f"language_model.model.layers.{layer_idx}.post_attention_layernorm.weight" + + def ffn_gate(self, layer_idx): + """FFN门控权重名""" + return f"language_model.model.layers.{layer_idx}.mlp.gate_proj.weight" + + def ffn_up(self, layer_idx): + """FFN上投影权重名""" + return f"language_model.model.layers.{layer_idx}.mlp.up_proj.weight" + + def ffn_down(self, layer_idx): + """FFN下投影权重名""" + return f"language_model.model.layers.{layer_idx}.mlp.down_proj.weight" + +class LlavaMetaFromLlava(LlavaMetaCStruct): + def __init__(self, config, dtype=torch.float16, max_tokens=None): + # Data type conversion + if dtype == torch.float16: + dt_ = DataType.INFINI_DTYPE_F16 + elif dtype == torch.float32: + dt_ = DataType.INFINI_DTYPE_F32 + elif dtype == torch.bfloat16: + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + # Vision encoder meta (from vision_config) + vision_config = config.get("vision_config", {}) + print(f"[LlavaMetaFromLlava] vision_config: {vision_config}") + vision_meta = LlavaVisionMetaCStruct( + image_size=vision_config.get("image_size", 336), + patch_size=vision_config.get("patch_size", 14), + num_patches=(vision_config.get("image_size", 336) // vision_config.get("patch_size", 14)) ** 2, + vision_embed_dim=vision_config.get("hidden_size", 1024), + vision_num_layers=vision_config.get("num_hidden_layers", 24), + vision_num_heads=vision_config.get("num_attention_heads", 16), + vision_intermediate_size=vision_config.get("intermediate_size", 4096), + vision_epsilon=1e-5, # 来自 transformers + ) + + # Language model meta (from text_config or main config) + text_config = config.get("text_config", config) + + # Vicuna-7B-v1.5的完整配置 (LLaVA text_config可能不完整) + 
vicuna_config = { + "num_hidden_layers": 32, + "hidden_size": 4096, + "num_attention_heads": 32, + "num_key_value_heads": 32, + "intermediate_size": 11008, + "max_position_embeddings": 4096, + "rms_norm_eps": 1e-05, + "vocab_size": 32000, + "rope_theta": 10000.0, + "head_dim": 128, # 4096 // 32 + } + + # 合并配置:优先使用LLaVA的text_config,缺失的用Vicuna默认值 + language_meta = LlavaLanguageMetaCStruct( + dt_logits=dt_, + nlayer=text_config.get("num_hidden_layers", vicuna_config["num_hidden_layers"]), + d=text_config.get("hidden_size", vicuna_config["hidden_size"]), + nh=text_config.get("num_attention_heads", vicuna_config["num_attention_heads"]), + nkvh=text_config.get("num_key_value_heads", vicuna_config["num_key_value_heads"]), + dh=text_config.get("head_dim", vicuna_config["head_dim"]), + di=text_config.get("intermediate_size", vicuna_config["intermediate_size"]), + dctx=( + text_config.get("max_position_embeddings", vicuna_config["max_position_embeddings"]) + if max_tokens is None else max_tokens + ), + dvoc=text_config.get("vocab_size", vicuna_config["vocab_size"]), + epsilon=text_config.get("rms_norm_eps", vicuna_config["rms_norm_eps"]), + theta=text_config.get("rope_theta", vicuna_config["rope_theta"]), + end_token=2, + ) + + # Projector meta + projector_meta = LlavaProjectorMetaCStruct( + vision_embed_dim=vision_config.get("hidden_size", 1024), + text_embed_dim=text_config.get("hidden_size", vicuna_config["hidden_size"]), + projector_hidden_size=config.get("mm_hidden_size", 4096), + ) + + # Call parent constructor with three meta structures + super().__init__( + vision_meta=vision_meta, + language_meta=language_meta, + projector_meta=projector_meta, + ) + self.torch_dtype_logits = dtype + + + +class LlavaWeightsImpl(LlavaWeightsCStruct): + def __init__( + self, + meta, + naming, + state_dict, + torch_dt_mat=torch.float16, + torch_dt_norm=torch.float16, + ndev=1, + ): + nlayer = meta.language_meta.nlayer + vision_nlayer = meta.vision_meta.vision_num_layers + d = meta.language_meta.d + di = meta.language_meta.di + nh = meta.language_meta.nh + nkvh = meta.language_meta.nkvh + dh = meta.language_meta.dh + + # 数据类型转换 + if torch_dt_mat == torch.float16: + self.dt_mat = DataType.INFINI_DTYPE_F16 + elif torch_dt_mat == torch.float32: + self.dt_mat = DataType.INFINI_DTYPE_F32 + elif torch_dt_mat == torch.bfloat16: + self.dt_mat = DataType.INFINI_DTYPE_BF16 + else: + self.dt_mat = DataType.INFINI_DTYPE_F16 + + if torch_dt_norm == torch.float16: + self.dt_norm = DataType.INFINI_DTYPE_F16 + elif torch_dt_norm == torch.float32: + self.dt_norm = DataType.INFINI_DTYPE_F32 + elif torch_dt_norm == torch.bfloat16: + self.dt_norm = DataType.INFINI_DTYPE_BF16 + else: + self.dt_norm = DataType.INFINI_DTYPE_F32 + + # self.transpose_linear_weights = 1 if transpose_weight else 0 + self.nlayer = nlayer + self.vision_nlayer = vision_nlayer + + # === 视觉编码器权重 === + # Patch嵌入权重 + if naming.vision_patch_embed_weight() in state_dict: + self.vision_patch_embed_tensor = state_dict[naming.vision_patch_embed_weight()].to(torch_dt_mat) + # print(f"[Python LlavaWeightsImpl] torch_dt_mat: {torch_dt_mat} ") # torch.float16 + # print(f"[Python LlavaWeightsImpl] vision_patch_embed_tensor shape: {self.vision_patch_embed_tensor.shape} ") + self.vision_patch_embed_weight = self.vision_patch_embed_tensor.data_ptr() + # print(f"[Python LlavaWeightsImpl] vision_patch_embed_weight pointer: {hex(self.vision_patch_embed_weight)} " ) + # print(f"[Python LlavaWeightsImpl] first 10 vision_patch_embed_weight: 
{self.vision_patch_embed_tensor.flatten()[:10]} ") + # Print pointer address in 0x... format + try: + addr = int(self.vision_patch_embed_weight) + # print(f"[Python LlavaWeightsImpl] vision_patch_embed_weight address: {hex(addr)}") + except Exception as e: + print(f"[Python LlavaWeightsImpl] failed to get vision_patch_embed_weight address: {e}") + else: + self.vision_patch_embed_weight = 0 + + # 位置嵌入和class token + if naming.vision_position_embedding() in state_dict: + self.vision_position_embedding_tensor = state_dict[naming.vision_position_embedding()].to(torch_dt_mat) + self.vision_position_embedding = self.vision_position_embedding_tensor.data_ptr() + else: + self.vision_position_embedding = 0 + + # if naming.vision_class_embedding() in state_dict: + # self.vision_class_embedding_tensor = state_dict[naming.vision_class_embedding()].to(torch_dt_mat) + # self.vision_class_embedding = self.vision_class_embedding_tensor.data_ptr() + if naming.vision_class_token() in state_dict: + self.vision_class_token_tensor = state_dict[naming.vision_class_token()].to(torch_dt_mat) + # print(f"[Python LlavaWeightsImpl] vision_class_token_tensor: {self.vision_class_token_tensor} ") + # print(f"[Python LlavaWeightsImpl] vision_class_token_tensor shape: {self.vision_class_token_tensor.shape} " ) + # print(f"[Python LlavaWeightsImpl] vision_class_token_tensor dtype: {self.vision_class_token_tensor.dtype} " ) + self.vision_class_token = self.vision_class_token_tensor.data_ptr() + #print(f"[Python LlavaWeightsImpl] vision_class_token pointer: {hex(self.vision_class_token)} ") + else: + self.vision_class_token = 0 + + # pre_layernorm.weight + if naming.vision_pre_layernorm_weight() in state_dict: + self.vision_pre_layernorm_weight_tensor = state_dict[naming.vision_pre_layernorm_weight()].to(torch_dt_mat) + self.vision_pre_layernorm_weight = self.vision_pre_layernorm_weight_tensor.data_ptr() + #print(f"[Python LlavaWeightsImpl] vision_pre_layernorm_weight pointer: {hex(self.vision_pre_layernorm_weight)} ") + else: + self.vision_pre_layernorm_weight = 0 + + # pre_layernorm.bias + if naming.vision_pre_layernorm_bias() in state_dict: + self.vision_pre_layernorm_bias_tensor = state_dict[naming.vision_pre_layernorm_bias()].to(torch_dt_mat) + self.vision_pre_layernorm_bias = self.vision_pre_layernorm_bias_tensor.data_ptr() + else: + self.vision_pre_layernorm_bias = 0 + # post_layernorm.weight + if naming.vision_post_layernorm_weight() in state_dict: + self.vision_post_layernorm_weight_tensor = state_dict[naming.vision_post_layernorm_weight()].to(torch_dt_mat) + self.vision_post_layernorm_weight = self.vision_post_layernorm_weight_tensor.data_ptr() + else: + self.vision_post_layernorm_weight = 0 + + # post_layernorm.bias + if naming.vision_post_layernorm_bias() in state_dict: + self.vision_post_layernorm_bias_tensor = state_dict[naming.vision_post_layernorm_bias()].to(torch_dt_mat) + self.vision_post_layernorm_bias = self.vision_post_layernorm_bias_tensor.data_ptr() + else: + self.vision_post_layernorm_bias = 0 + + # in_layer pre_norm weights + self.vision_in_layer_pre_norm_weight_tensors = [ + state_dict[naming.vision_in_layer_pre_norm_weights(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_in_layer_pre_norm_weight_ptrs = [ + self.vision_in_layer_pre_norm_weight_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_in_layer_pre_norm_weights = (c_void_p * vision_nlayer)(*self.vision_in_layer_pre_norm_weight_ptrs) + + # in_layer pre_norm biases + 
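Every per-layer block in this constructor repeats the same three steps: cast the tensor to the target dtype, keep the Python tensor referenced on `self` (the C side only sees raw addresses, and `data_ptr()` dangles the moment PyTorch frees the storage), and pack the pointers into a `c_void_p` array. A hedged helper capturing that pattern (the function name is hypothetical, not part of the bindings):

```python
from ctypes import c_void_p

def make_ptr_array(state_dict, key_fn, nlayer, dtype):
    """Load one per-layer weight family and return (tensors, c_void_p array).

    The returned tensor list must stay referenced for as long as the C++ model
    uses the pointers; only raw addresses cross the ctypes boundary.
    """
    tensors = [state_dict[key_fn(i)].to(dtype).contiguous() for i in range(nlayer)]
    ptrs = (c_void_p * nlayer)(*[t.data_ptr() for t in tensors])
    return tensors, ptrs

# Usage sketch with the naming methods above:
# self.vision_q_weight_tensors, self.vision_q_weights = \
#     make_ptr_array(state_dict, naming.vision_q_weights, vision_nlayer, torch_dt_mat)
```

The per-layer pre-norm biases that follow, and every per-layer array after them, are built with exactly this shape.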
self.vision_in_layer_pre_norm_bias_tensors = [ + state_dict[naming.vision_in_layer_pre_norm_biases(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_in_layer_pre_norm_bias_ptrs = [ + self.vision_in_layer_pre_norm_bias_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_in_layer_pre_norm_biases = (c_void_p * vision_nlayer)(*self.vision_in_layer_pre_norm_bias_ptrs) + + # q weights + self.vision_q_weight_tensors = [ + state_dict[naming.vision_q_weights(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_q_weight_ptrs = [ + self.vision_q_weight_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_q_weights = (c_void_p * vision_nlayer)(*self.vision_q_weight_ptrs) + # q biases + self.vision_q_bias_tensors = [ + state_dict[naming.vision_q_biases(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_q_bias_ptrs = [ + self.vision_q_bias_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_q_biases = (c_void_p * vision_nlayer)(*self.vision_q_bias_ptrs) + # k weights + self.vision_k_weight_tensors = [ + state_dict[naming.vision_k_weights(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_k_weight_ptrs = [ + self.vision_k_weight_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_k_weights = (c_void_p * vision_nlayer)(*self.vision_k_weight_ptrs) + # k biases + self.vision_k_bias_tensors = [ + state_dict[naming.vision_k_biases(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_k_bias_ptrs = [ + self.vision_k_bias_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_k_biases = (c_void_p * vision_nlayer)(*self.vision_k_bias_ptrs) + # v weights + self.vision_v_weight_tensors = [ + state_dict[naming.vision_v_weights(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_v_weight_ptrs = [ + self.vision_v_weight_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_v_weights = (c_void_p * vision_nlayer)(*self.vision_v_weight_ptrs) + # v biases + self.vision_v_bias_tensors = [ + state_dict[naming.vision_v_biases(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_v_bias_ptrs = [ + self.vision_v_bias_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_v_biases = (c_void_p * vision_nlayer)(*self.vision_v_bias_ptrs) + + ############################################### + # out_proj.weight / out_proj.bias + ############################################### + + self.vision_proj_weight_tensors = [ + state_dict[naming.vision_proj_weight(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_proj_weight_ptrs = [ + self.vision_proj_weight_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_proj_weight = (c_void_p * vision_nlayer)(*self.vision_proj_weight_ptrs) + + self.vision_proj_bias_tensors = [ + state_dict[naming.vision_proj_bias(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_proj_bias_ptrs = [ + self.vision_proj_bias_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_proj_bias = (c_void_p * vision_nlayer)(*self.vision_proj_bias_ptrs) + + + ############################################### + # post norm (after attention) weight / bias + ############################################### + + self.vision_in_layer_post_norm_tensors = [ + state_dict[naming.vision_in_layer_post_norm_weight(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_in_layer_post_norm_ptrs = [ + 
self.vision_in_layer_post_norm_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_in_layer_post_norm_weight = (c_void_p * vision_nlayer)(*self.vision_in_layer_post_norm_ptrs) + # print(f"[Python LlavaWeightsImpl] vision_in_layer_post_norm_weight pointers: {[hex(ptr) for ptr in self.vision_in_layer_post_norm_ptrs]} ") + + self.vision_post_norm_bias_tensors = [ + state_dict[naming.vision_post_norm_bias(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_post_norm_bias_ptrs = [ + self.vision_post_norm_bias_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_post_norm_bias = (c_void_p * vision_nlayer)(*self.vision_post_norm_bias_ptrs) + + + ############################################### + # MLP: fc1 / fc2 + ############################################### + + # fc1.weight + self.vision_mlp_fc1_weight_tensors = [ + state_dict[naming.vision_mlp_fc1_weight(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_mlp_fc1_weight_ptrs = [ + self.vision_mlp_fc1_weight_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_mlp_fc1_weight = (c_void_p * vision_nlayer)(*self.vision_mlp_fc1_weight_ptrs) + + # fc1.bias + self.vision_mlp_fc1_bias_tensors = [ + state_dict[naming.vision_mlp_fc1_bias(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_mlp_fc1_bias_ptrs = [ + self.vision_mlp_fc1_bias_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_mlp_fc1_bias = (c_void_p * vision_nlayer)(*self.vision_mlp_fc1_bias_ptrs) + + + # fc2.weight + self.vision_mlp_fc2_weight_tensors = [ + state_dict[naming.vision_mlp_fc2_weight(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_mlp_fc2_weight_ptrs = [ + self.vision_mlp_fc2_weight_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_mlp_fc2_weight = (c_void_p * vision_nlayer)(*self.vision_mlp_fc2_weight_ptrs) + + # fc2.bias + self.vision_mlp_fc2_bias_tensors = [ + state_dict[naming.vision_mlp_fc2_bias(i)].to(torch_dt_mat) for i in range(vision_nlayer) + ] + self.vision_mlp_fc2_bias_ptrs = [ + self.vision_mlp_fc2_bias_tensors[i].data_ptr() for i in range(vision_nlayer) + ] + self.vision_mlp_fc2_bias = (c_void_p * vision_nlayer)(*self.vision_mlp_fc2_bias_ptrs) + + + + # === 多模态投影器权重 === + if naming.projector_weight_1() in state_dict: + self.projector_weight_1_tensor = state_dict[naming.projector_weight_1()].to(torch_dt_mat) + self.projector_weight_1 = self.projector_weight_1_tensor.data_ptr() + else: + self.projector_weight_1 = 0 + + if naming.projector_bias_1() in state_dict: + self.projector_bias_1_tensor = state_dict[naming.projector_bias_1()].to(torch_dt_mat) + self.projector_bias_1 = self.projector_bias_1_tensor.data_ptr() + else: + self.projector_bias_1 = 0 + + if naming.projector_weight_2() in state_dict: + self.projector_weight_2_tensor = state_dict[naming.projector_weight_2()].to(torch_dt_mat) + self.projector_weight_2 = self.projector_weight_2_tensor.data_ptr() + else: + self.projector_weight_2 = 0 + + if naming.projector_bias_2() in state_dict: + self.projector_bias_2_tensor = state_dict[naming.projector_bias_2()].to(torch_dt_mat) + self.projector_bias_2 = self.projector_bias_2_tensor.data_ptr() + else: + self.projector_bias_2 = 0 + + # === 语言模型权重 (按照Jiuge模式) === + # 输入输出嵌入 + self.input_embd_tensor = state_dict[naming.input_embd()].to(torch_dt_mat) + self.input_embd = self.input_embd_tensor.data_ptr() + + self.output_norm_tensor = state_dict[naming.output_norm()].to(torch_dt_mat) + self.output_norm = 
self.output_norm_tensor.data_ptr() + + self.output_embd_tensor = state_dict[naming.output_embd()].to(torch_dt_mat) + self.output_embd = self.output_embd_tensor.data_ptr() + + # 注意力权重数组 + self.attn_norm_tensors = [ + state_dict[naming.attn_norm(i)].to(torch_dt_mat) for i in range(nlayer) + ] + self.attn_norm_ptrs = [ + self.attn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs) + + # # QKV权重 - 对于LLaVA,Q、K、V是分开的,但我们可以按Jiuge的方式合并处理 + # def qkv_slices(_i): + # _Q = ( + # state_dict[naming.attn_q(_i)] + # .reshape([nh, 2, dh // 2, d]) + # .transpose(1, 2) + # ) + # _K = ( + # state_dict[naming.attn_k(_i)] + # .reshape([nkvh, 2, dh // 2, d]) + # .transpose(1, 2) + # ) + # _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d]) + # _result = [] + # _nh = nh // ndev + # _nkvh = nkvh // ndev + # for _idev in range(ndev): + # _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :]) + # _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :]) + # _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) + # return _result + + # self.qkv_tensor = [ + # torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer) + # ] + # if not transpose_weight: + # for i in range(nlayer): + # self.qkv_tensor[i] = ( + # self.qkv_tensor[i] + # .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d) + # .transpose(1, 2) + # .contiguous() + # ) + # self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)] + # self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs) + + # # QKV bias (LLaVA通常没有bias) + # self.attn_qkv_b = (c_void_p * nlayer)() + # for i in range(nlayer): + # self.attn_qkv_b[i] = 0 + + # # Q norm 和 K norm (LLaVA通常没有) + # self.attn_q_norm = (c_void_p * nlayer)() + # self.attn_k_norm = (c_void_p * nlayer)() + # for i in range(nlayer): + # self.attn_q_norm[i] = 0 + # self.attn_k_norm[i] = 0 + + # # Attention O权重 + # self.attn_o_tensor = [ + # ( + # state_dict[naming.attn_o(i)] + # .to(torch_dt_mat) + # .reshape([d, ndev, nh // ndev * dh]) + # .transpose(0, 1) + # .contiguous() + # if transpose_weight + # else state_dict[naming.attn_o(i)] + # .transpose(0, 1) + # .to(torch_dt_mat) + # .contiguous() + # ) + # for i in range(nlayer) + # ] + # self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] + # self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs) + + # # FFN权重 + # self.ffn_norm_tensors = [ + # state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + # ] + # self.ffn_norm_ptrs = [ + # self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer) + # ] + # self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs) + + # def gate_up_slices(_i): + # _result = [] + # _di = di // ndev + # for _idev in range(ndev): + # _start = _idev * _di + # _end = (_idev + 1) * _di + # _result.append(state_dict[naming.ffn_gate(_i)][_start:_end, :]) + # _result.append(state_dict[naming.ffn_up(_i)][_start:_end, :]) + # return _result + + # self.gate_up_tensors = [ + # torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer) + # ] + # if not transpose_weight: + # for i in range(nlayer): + # self.gate_up_tensors[i] = ( + # self.gate_up_tensors[i] + # .reshape(ndev, 2 * di // ndev, d) + # .transpose(1, 2) + # .contiguous() + # ) + # self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)] + # self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs) + + # self.ffn_down_tensor = [ + # ( + # state_dict[naming.ffn_down(i)] + # .to(torch_dt_mat) + # 
.reshape([d, ndev, di // ndev]) + # .transpose(0, 1) + # .contiguous() + # if transpose_weight + # else state_dict[naming.ffn_down(i)] + # .transpose(0, 1) + # .to(torch_dt_mat) + # .contiguous() + # ) + # for i in range(nlayer) + # ] + # self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)] + # self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs) + + # # === 视觉编码器权重数组 === + # vision_layer_size = meta.vision_meta.vision_num_layers + # self.vision_encoder_weights = (c_void_p * (vision_layer_size * 10))() + + # # 填充视觉编码器权重 (简化版,实际应该按Jiuge模式处理) + # for i in range(vision_layer_size): + # idx = i * 10 + # # 这里简化处理,实际应该像Jiuge那样创建tensor对象并保存 + # vision_pre_norm_key = naming.vision_pre_norm(i) + # if vision_pre_norm_key in state_dict: + # self.vision_encoder_weights[idx] = state_dict[vision_pre_norm_key].data_ptr() + # else: + # self.vision_encoder_weights[idx] = 0 + + # # 其他视觉权重类似处理... + # for j in range(1, 10): + # self.vision_encoder_weights[idx + j] = 0 + + # 初始化父类结构 + super().__init__() + + + +class LLaVAForCauslLM: + def __init__(self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None): + def load_all_safetensors_from_dir(dir_path_: str): + tensors_ = {} + dir_path_ = Path(dir_path_) + for file in sorted(dir_path_.glob("*.safetensors")): + data_ = safetensors.safe_open(file, "pt") + for name_ in data_.keys(): + tensors_[name_] = data_.get_tensor(name_) + + return tensors_ + + + # 内部三个组件 + self.preprocessor = AutoProcessor.from_pretrained(model_dir_path) + # self.vision_encoder = LLaVAVisionEncoder(model_dir_path, device_type, ndev) + # self.mm_projector = LLaVAMultiModalProjector(model_dir_path, device_type, ndev) + # self.language_model = JiugeForCauslLM(model_dir_path, device_type, ndev) # ✅ 复用 + #print("Loading model weights to host...") + load_start_time = time.time() + + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + config = json.load(f) + self.config = config + self.eos_token_id = [2] + # print(f"Model config: {self.config}") + # print(f"Model eos_token_id: {self.eos_token_id}") + + # transpose_weight = ( + # device != DeviceType.DEVICE_TYPE_ASCEND + # ) # y = xW is faster than y=xW^T on Ascend + + # print(f"device: {device}") + self.llava_model = LlavaModel() + + if "llava" == config["model_type"]: + #print("Loading LLaVA model...") + state_dict = load_all_safetensors_from_dir(model_dir_path) + #print(f"state_dict keys: {list(state_dict.keys())[:10]} ...") + self.meta = LlavaMetaFromLlava(config, max_tokens=max_tokens) + # print(f"meta type: {type(self.meta)}") # meta type: + # print(f"meta value: {self.meta}") # meta value: <__main__.LlavaMetaFromLlava object at 0x7fda3c5e91c0> + self.weights = LlavaWeightsImpl( + self.meta, + LlavaWeightsNaming(), + state_dict, + ndev=ndev, + ) + + transpose_weight = ( + device != DeviceType.DEVICE_TYPE_ASCEND + ) # y = xW is faster than y=xW^T on Ascend + + + self.language_meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.language_weights = JiugeWeightsImpl( + self.language_meta, + LlamaWeightsNaming(), + state_dict, + ndev=ndev, + transpose_weight=transpose_weight, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path + ) + + # print(f"weights type: {type(self.weights)}") # weights type: + # print(f"weights value: {self.weights}") # weights value: <__main__.LlavaWeightsImpl object at 0x7fda3c5e9340> + load_end_time = time.time() + # print(f"Time used: {load_end_time - load_start_time:.3f}s") + # print(f"Creating model on {ndev} 
devices...") + self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + self.ndev = ndev + self.device = device + + self.model_instance = self.llava_model.create_model( + byref(self.meta), + byref(self.weights), + device, + ndev, + self.dev_ids, + ) + + # Language model (Jiuge) instance for end-to-end generation (reuses WithOverrides injection). + self.jiuge_model = JiugeModel() + self.language_model_instance = self.jiuge_model.create_model( + byref(self.language_meta), + byref(self.language_weights), + device, + ndev, + self.dev_ids, + ) + + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + + def max_context_len(self): + return self.meta.language_meta.dctx + + def create_kv_cache(self): + """创建 LLaVA 的 KV Cache""" + # 调用 C++ 层的 createKVCache 函数 + # 参数:nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + return self.llava_model.create_kv_cache( + self.meta.language_meta.nlayer, # 语言模型层数 + self.meta.language_meta.dctx, # 最大上下文长度 + self.meta.language_meta.nkvh, # key-value head 数 + self.meta.language_meta.dh, # head 维度 + self.meta.language_meta.dh, # value 维度 (通常与dh相同) + self.meta.language_meta.dt_logits, # 数据类型 + self.device, # 设备类型 + self.dev_ids, # 设备ID列表 + self.ndev # 设备数量 + ) + + def debug_image(self, ptr, pixel_values, num): + print("tensor数组:", pixel_values.flatten()[:num].tolist()) + num_values = pixel_values.numel() + # 数值数组 + values_list = [] + # 二进制表示(uint16或二进制字符串) + binary_list = [] + + for i in range(num_values): + addr = ptr + i * 2 + raw_uint16 = ctypes.c_uint16.from_address(addr).value + + # 正确解读位模式为float16 + float16_val = np.array([raw_uint16], dtype=np.uint16).view(np.float16)[0] + + values_list.append(float(float16_val)) # 转为Python float + binary_list.append(f"{raw_uint16:016b}") + + print("数值数组:", values_list[:num]) + print("二进制数组:", binary_list[:num]) + + + def drop_kv_cache(self, kv_cache): + """删除 LLaVA 的 KV Cache""" + self.llava_model.drop_kv_cache(kv_cache) + + # === LLaVA四阶段推理方法 === + LLAVA_VISION_STAGE_PRE_LN = 0 + LLAVA_VISION_STAGE_SELECT_ALL = 1 + LLAVA_VISION_STAGE_SELECT_PATCH = 2 + LLAVA_VISION_STAGE_PROJECTOR = 3 + LLAVA_VISION_STAGE_PROJECTOR_ALL = 4 + + def _alloc_vision_stage_output(self, stage: int) -> torch.Tensor: + vision_seq = int(self.meta.vision_meta.num_patches) + 1 + vision_dim = int(self.meta.vision_meta.vision_embed_dim) + text_dim = int(self.meta.projector_meta.text_embed_dim) + if stage == self.LLAVA_VISION_STAGE_PRE_LN: + shape = (vision_seq, vision_dim) + elif stage == self.LLAVA_VISION_STAGE_SELECT_ALL: + shape = (vision_seq, vision_dim) + elif stage == self.LLAVA_VISION_STAGE_SELECT_PATCH: + shape = (vision_seq - 1, vision_dim) + elif stage == self.LLAVA_VISION_STAGE_PROJECTOR: + shape = (vision_seq - 1, text_dim) + elif stage == self.LLAVA_VISION_STAGE_PROJECTOR_ALL: + shape = (vision_seq, text_dim) + else: + raise ValueError(f"Unknown vision stage: {stage}") + return torch.empty(shape, dtype=torch.float16, device="cpu") + + def batch_infer_vision_stage(self, pixel_values, stage: int): + if pixel_values is None: + return None + if hasattr(pixel_values, "contiguous"): + pixel_values = pixel_values.contiguous() + if len(pixel_values.shape) != 4 or int(pixel_values.shape[0]) != 1: + raise ValueError(f"Only batch_size=1 supported, got shape={tuple(pixel_values.shape)}") + + image_data_fp16 = pixel_values.to(torch.float16).cpu() + image_data = image_data_fp16.data_ptr() + out = self._alloc_vision_stage_output(stage) + + self.llava_model.infer_batch_vision_stage( + self.model_instance, + 
image_data, + stage, + out.data_ptr(), + ) + return out + + def batch_infer_encode(self, pixel_values, input_tokens_list): + """阶段1: Vision Encoder - 将图像编码为视觉特征""" + return self.batch_infer_vision_stage(pixel_values, self.LLAVA_VISION_STAGE_PROJECTOR) + + + def batch_infer_compressor(self, features, kv_caches): + """阶段4: KV-Cache Compression - 压缩KV缓存以节省内存""" + if kv_caches is None: + print("=== KV-Cache Compression Skipped (No KV Caches) ===") + return kv_caches + + print("=== LLaVA KV-Cache Compression ===") + + # TODO: 集成Fastcache的压缩算法 + print("KV-Cache compression: (Future - Fastcache integration)") + + return kv_caches + + def _find_image_token_positions(self, input_ids: torch.Tensor) -> list[int]: + image_token_index = int(self.config.get("image_token_index", 32000)) + ids = input_ids[0].to(dtype=torch.int64) + return (ids == image_token_index).nonzero(as_tuple=False).flatten().tolist() + + def _prefill_with_overrides(self, input_ids: torch.Tensor, pixel_values: torch.Tensor, + temperature_: float, topk_: int, topp_: float, + logits: Optional[torch.Tensor] = None): + # 1) image embeds (projector output) + vision_start_time = time.time() + img_embeds = self.batch_infer_vision_stage( + pixel_values, self.LLAVA_VISION_STAGE_PROJECTOR + ).contiguous() + vision_end_time = time.time() + vision_time = float(vision_end_time - vision_start_time) + # 2) override positions: processor already expands to 576 image tokens for v1.5 + pos = self._find_image_token_positions(input_ids) + if len(pos) != int(img_embeds.shape[0]): + raise ValueError(f"image token count mismatch: pos={len(pos)} embeds={int(img_embeds.shape[0])}") + override_pos = (c_uint * len(pos))(*pos) + + # 3) tokens + tokens = input_ids[0].to(dtype=torch.int32).tolist() + ntok = len(tokens) + tokens_c = (c_uint * ntok)(*tokens) + req_lens = (c_uint * 1)(ntok) + req_pos = (c_uint * 1)(0) + + # 4) kv cache + kv = self.jiuge_model.create_kv_cache( + self.language_meta.nlayer, + self.language_meta.dctx, + self.language_meta.nkvh, + self.language_meta.dh, + self.language_meta.dh, + self.language_meta.dt_logits, + self.device, + self.dev_ids, + self.ndev, + ) + kv_caches = (POINTER(KVCacheCStruct) * 1)(kv) + + # 5) sampling + temperature = (c_float * 1)(float(temperature_)) + topk = (c_uint * 1)(int(topk_)) + topp = (c_float * 1)(float(topp_)) + out = (c_uint * 1)() + + prefill_start_time = time.time() + if logits is None: + self.jiuge_model.infer_batch_with_overrides( + self.language_model_instance, + tokens_c, + ntok, + req_lens, + 1, + req_pos, + kv_caches, + len(pos), + override_pos, + img_embeds.data_ptr(), + temperature, + topk, + topp, + out, + ) + else: + self.jiuge_model.infer_batch_with_overrides_with_logits( + self.language_model_instance, + tokens_c, + ntok, + req_lens, + 1, + req_pos, + kv_caches, + len(pos), + override_pos, + img_embeds.data_ptr(), + temperature, + topk, + topp, + out, + logits.data_ptr(), + ) + prefill_end_time = time.time() + prefill_time = float(prefill_end_time - prefill_start_time) + return int(out[0]), kv, kv_caches, ntok, vision_time, prefill_time + + def _decode_one(self, last_token_id: int, rope_pos: int, kv_caches, + temperature_: float, topk_: int, topp_: float, + kv_pos: Optional[int] = None, + logits: Optional[torch.Tensor] = None) -> int: + req_lens = (c_uint * 1)(1) + req_pos = (c_uint * 1)(rope_pos) + tokens_c = (c_uint * 1)(int(last_token_id)) + temperature = (c_float * 1)(float(temperature_)) + topk = (c_uint * 1)(int(topk_)) + topp = (c_float * 1)(float(topp_)) + out = (c_uint * 1)() + if 
kv_pos is None: + if logits is None: + self.jiuge_model.infer_batch( + self.language_model_instance, + tokens_c, + 1, + req_lens, + 1, + req_pos, + kv_caches, + temperature, + topk, + topp, + out, + ) + else: + self.jiuge_model.infer_batch_with_logits( + self.language_model_instance, + tokens_c, + 1, + req_lens, + 1, + req_pos, + kv_caches, + temperature, + topk, + topp, + out, + logits.data_ptr(), + ) + else: + kv_pos_c = (c_uint * 1)(int(kv_pos)) + if logits is None: + self.jiuge_model.infer_batch_ex( + self.language_model_instance, + tokens_c, + 1, + req_lens, + 1, + req_pos, + kv_pos_c, + kv_caches, + temperature, + topk, + topp, + out, + ) + else: + self.jiuge_model.infer_batch_ex_with_logits( + self.language_model_instance, + tokens_c, + 1, + req_lens, + 1, + req_pos, + kv_pos_c, + kv_caches, + temperature, + topk, + topp, + out, + logits.data_ptr(), + ) + return int(out[0]) + + def generate( + self, + messages, + max_new_tokens=128, + topp_=1.0, + topk_=1, + temperature_=1.0, + verbose=False, + kv_compress: bool = False, + kv_compress_bin: str = "", + kv_compress_factor: int = 5, + kv_compress_min_seq_len: int = 2, + perplexity: bool = False, + perplexity_verbose_steps: int = 5, + time_stats: bool = False): + import math + + def token_log_prob(logits_1d: torch.Tensor, token_id: int) -> float: + lp = torch.nn.functional.log_softmax(logits_1d.float(), dim=-1)[int(token_id)] + return float(lp.item()) + + total_nll = 0.0 + total_tokens = 0 + + mm_inputs = self.preprocessor.apply_chat_template( + messages, + add_generation_prompt=True, + tokenize=True, + return_dict=True, + return_tensors="pt", + ) + pixel_values = mm_inputs.pixel_values + attention_mask = mm_inputs.attention_mask + input_ids = mm_inputs.input_ids + #print(f"Input token IDs shape: {input_ids.shape}") + + # 将torch tensor转换为Python列表,就像jiuge.py那样 + if hasattr(input_ids, 'flatten'): + input_ids_list = input_ids.flatten().tolist() + else: + input_ids_list = input_ids.tolist() + + if verbose: + print("pixel_values.shape:", tuple(pixel_values.shape)) + print("attention_mask.shape:", tuple(attention_mask.shape)) + print("input_ids_len:", int(input_ids.shape[1])) + + # Prefill with image embedding overrides (+ optional logits capture) + prefill_logits = None + if perplexity: + dvoc = int(self.language_meta.dvoc) + ntok = int(input_ids.shape[1]) + prefill_logits = torch.empty( + (ntok, dvoc), + dtype=self.language_meta.torch_dtype_logits, + device="cpu", + ) + + first_token, kv, kv_caches, ntok, vision_time, prefill_time = self._prefill_with_overrides( + input_ids, + pixel_values, + temperature_, + topk_, + topp_, + logits=prefill_logits, + ) + + generated = [first_token] + rope_pos = ntok + #import pdb;pdb.set_trace() + kv_pos: Optional[int] = None + if kv_compress: + if self.ndev != 1: + raise ValueError("KV compression currently requires ndev=1 (compressKVCacheInplace is single-device).") + if not kv_compress_bin: + raise ValueError("kv_compress=True requires kv_compress_bin (path to llava_mlp.bin)") + + # Approx strategy: treat everything before the end of the image token block as "image prefix". + # This includes a small text prefix before the image tokens (e.g., 'USER:'), but keeps the + # API contract (image_kv_len is prefix length). 
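Two position counters are tracked from here on, which is why `_decode_one` accepts both: `kv_pos` is the number of valid entries left in the (compressed) KV cache, while `rope_pos` keeps counting in the original token coordinate so rotary embeddings stay consistent with the uncompressed prefill. A toy illustration under assumed numbers (the real `kv_pos` is whatever `compress_kv_cache_inplace` returns; the arithmetic below is only indicative):

```python
# Hypothetical prompt: 600 tokens total, of which the first 577 form the image prefix.
ntok, image_kv_len, factor = 600, 577, 5
text_kv = ntok - image_kv_len                       # 23 text entries left untouched
approx_kv_pos = image_kv_len // factor + text_kv    # ~138 KV slots if the prefix shrinks 5x
rope_pos = ntok                                     # RoPE positions still count the full prompt
# Both counters then advance by one per generated token in the decode loop below.
print(approx_kv_pos, rope_pos)
```

The prefix length itself is derived from the image token positions, as the next lines do.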
+ image_pos = self._find_image_token_positions(input_ids) + image_kv_len = int(max(image_pos) + 1) if image_pos else 0 + if verbose: + print("kv_compress:", {"image_kv_len": image_kv_len, "image_token_count": len(image_pos), "ntok": ntok}) + + cfg = KVCompressionConfigCStruct( + enable=1, + compression_factor=int(kv_compress_factor), + min_seq_len=int(kv_compress_min_seq_len), + image_kv_len=int(image_kv_len), + weight_path=kv_compress_bin.encode("utf-8"), + ) + kv_pos = int(self.jiuge_model.compress_kv_cache_inplace(kv, int(ntok), cfg)) + if verbose: + print("kv_compress_done:", {"kv_pos": kv_pos, "rope_pos": int(rope_pos)}) + + if perplexity: + if prefill_logits is None or int(prefill_logits.shape[0]) != int(ntok): + raise RuntimeError("prefill_logits missing or shape mismatch") + lp0 = token_log_prob(prefill_logits[int(ntok) - 1], first_token) + total_nll += -lp0 + total_tokens += 1 + if int(perplexity_verbose_steps) > 0: + tok_str = self.tokenizer.decode([int(first_token)], skip_special_tokens=False) + print(f"[ppl] step=0 token={int(first_token)} text={tok_str!r} log_prob={lp0:.6f}") + + decode_start_time = time.time() + for _ in range(int(max_new_tokens) - 1): + if generated[-1] in self.eos_token_id: + break + decode_logits = None + if perplexity: + decode_logits = torch.empty( + (1, int(self.language_meta.dvoc)), + dtype=self.language_meta.torch_dtype_logits, + device="cpu", + ) + nxt = self._decode_one( + generated[-1], + rope_pos, + kv_caches, + temperature_, + topk_, + topp_, + kv_pos=kv_pos, + logits=decode_logits, + ) + generated.append(nxt) + if perplexity: + if decode_logits is None: + raise RuntimeError("decode_logits missing") + lp = token_log_prob(decode_logits[0], nxt) + total_nll += -lp + total_tokens += 1 + if total_tokens <= int(perplexity_verbose_steps): + tok_str = self.tokenizer.decode([int(nxt)], skip_special_tokens=False) + print(f"[ppl] step={total_tokens-1} token={int(nxt)} text={tok_str!r} log_prob={lp:.6f}") + rope_pos += 1 + if kv_pos is not None: + kv_pos += 1 + decode_end_time = time.time() + + text = self.tokenizer.decode(generated, skip_special_tokens=False) + if verbose: + print("generated_token_ids:", generated) + print("decoded:", text) + + self.jiuge_model.drop_kv_cache(kv) + if time_stats: + steps = int(len(generated)) + decode_time = float(decode_end_time - decode_start_time) + llm_total_time = float(prefill_time + decode_time) + avg_time_per_step = ( + llm_total_time * 1000 / (steps - 1) + if steps > 1 + else llm_total_time * 1000 + ) + # Mirror scripts/jiuge.py's primary metric, but also expose vision time for multimodal runs. + print(f"Vision time: {vision_time * 1000:.3f}ms") + print(f"Prefill time: {prefill_time * 1000:.3f}ms") + print(f"Decode time: {decode_time * 1000:.3f}ms") + print(f"Time per step: {avg_time_per_step:.3f}ms") + if perplexity and total_tokens > 0: + ppl = math.exp(total_nll / total_tokens) + print(f"Perplexity: {ppl:.4f}") + return text + + + + + + + + + + + + + + + + + + if verbose: + print("LLaVAForConditionalGeneration.generate:") + print(f" pixel_values.shape: {pixel_values.shape}") + print(f" attention_mask.shape: {attention_mask.shape}") + print(f" input_ids.shape: {input_ids.shape}") + # TODO: 2. 视觉编码 + # vision_features = self.vision_encoder.encode(image_tensor) + + # TODO: 3. 多模态投影 + # image_tokens = self.mm_projector.project(vision_features) + + # TODO: 4. Token融合 + # combined_tokens = self._fuse_tokens(prompt, image_tokens) + + # TODO: 5. 
语言模型生成 (复用Jiuge) + # return self.language_model.generate_tokens(combined_tokens, max_new_tokens, verbose) + + + + +def test(): + if len(sys.argv) < 3: + print( + "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] [n_device] [--verbose]" + ) + sys.exit(1) + + # Parse command line arguments + model_path = sys.argv[2] + device_type = DeviceType.DEVICE_TYPE_CPU + verbose = False + + # Check for verbose flag + for arg in sys.argv: + if arg == "--verbose": + verbose = True + break + + if sys.argv[1] == "--cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif sys.argv[1] == "--nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif sys.argv[1] == "--cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif sys.argv[1] == "--ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif sys.argv[1] == "--metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif sys.argv[1] == "--moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif sys.argv[1] == "--iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + elif sys.argv[1] == "--kunlun": + device_type = DeviceType.DEVICE_TYPE_KUNLUN + elif sys.argv[1] == "--hygon": + device_type = DeviceType.DEVICE_TYPE_HYGON + else: + print( + "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore | --iluvatar | --kunlun | --hygon] [n_device] [--verbose]" + ) + sys.exit(1) + + # Find n_device argument (skip --verbose) + ndev_args = [arg for arg in sys.argv[3:] if arg != "--verbose"] + ndev = int(ndev_args[0]) if ndev_args else 1 + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": "scripts/img/47_42.jpg"}, + {"type": "text", "text": "Describe this image."} + ] + }, + ] + + model = LLaVAForCauslLM(model_path, device_type, ndev) + model.generate(messages, verbose=verbose) + # model.destroy_model_instance() + + +if __name__ == "__main__": + test() + + + + +# compress = Compress() + +# compress.compress(kv_caches, [(i_start, i_end), .......]) diff --git a/scripts/llava_chat.py b/scripts/llava_chat.py new file mode 100644 index 00000000..90681f77 --- /dev/null +++ b/scripts/llava_chat.py @@ -0,0 +1,79 @@ +import argparse +import sys + +from libinfinicore_infer import DeviceType +from llava import LLaVAForCauslLM + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument( + "--dev", + choices=["cpu", "nvidia", "hygon", "moore"], + default="cpu", + help="Device backend for inference.", + ) + ap.add_argument("--ndev", type=int, default=1) + ap.add_argument("--model-dir", required=True) + ap.add_argument("--image", required=True) + ap.add_argument("--question", default="Describe this image.") + ap.add_argument("--max-new-tokens", type=int, default=60) + ap.add_argument("--topk", type=int, default=1) + ap.add_argument("--topp", type=float, default=1.0) + ap.add_argument("--temperature", type=float, default=1.0) + ap.add_argument("--verbose", action="store_true") + ap.add_argument("--kv-compress", action="store_true", help="Enable in-place KV cache compression after prefill.") + ap.add_argument("--kv-compress-bin", default="", help="Path to llava_mlp.bin compressor weights.") + ap.add_argument("--kv-compress-factor", type=int, default=5) + ap.add_argument("--kv-compress-min-seq-len", type=int, default=2) + ap.add_argument("--perplexity", action="store_true", help="Collect logits for perplexity calculation") + ap.add_argument("--time", action="store_true", help="Print timing metrics (time per step, etc.)") + 
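With all flags defined, the remainder of `main` below simply forwards them to `LLaVAForCauslLM.generate`. For reference, the equivalent direct call from Python looks roughly like this (the model directory is a placeholder; the image path reuses the sample from the test above):

```python
from libinfinicore_infer import DeviceType
from llava import LLaVAForCauslLM

model = LLaVAForCauslLM("/path/to/llava-1.5-7b-hf", device=DeviceType.DEVICE_TYPE_CPU, ndev=1)
text = model.generate(
    [{"role": "user", "content": [
        {"type": "image", "url": "scripts/img/47_42.jpg"},
        {"type": "text", "text": "Describe this image."},
    ]}],
    max_new_tokens=60,
    topk_=1,
    topp_=1.0,
    temperature_=1.0,
    time_stats=True,
)
print(text)
```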
args = ap.parse_args() + + if args.kv_compress: + if args.ndev != 1: + ap.error("--kv-compress currently requires --ndev 1") + if not args.kv_compress_bin: + ap.error("--kv-compress requires --kv-compress-bin") + + messages = [ + { + "role": "user", + "content": [ + {"type": "image", "url": args.image}, + {"type": "text", "text": args.question}, + ], + } + ] + + device_map = { + "cpu": DeviceType.DEVICE_TYPE_CPU, + "nvidia": DeviceType.DEVICE_TYPE_NVIDIA, + "hygon": DeviceType.DEVICE_TYPE_HYGON, + "moore": DeviceType.DEVICE_TYPE_MOORE, + } + device_type = device_map[args.dev] + model = LLaVAForCauslLM( + args.model_dir, + device=device_type, + ndev=args.ndev, + ) + text = model.generate( + messages, + max_new_tokens=args.max_new_tokens, + topk_=args.topk, + topp_=args.topp, + temperature_=args.temperature, + verbose=args.verbose, + kv_compress=bool(args.kv_compress), + kv_compress_bin=str(args.kv_compress_bin), + kv_compress_factor=int(args.kv_compress_factor), + kv_compress_min_seq_len=int(args.kv_compress_min_seq_len), + perplexity=bool(args.perplexity), + time_stats=bool(args.time), + ) + sys.stdout.write(text + "\n") + + +if __name__ == "__main__": + main() diff --git a/scripts/minicpmv.py b/scripts/minicpmv.py new file mode 100644 index 00000000..a0d56c75 --- /dev/null +++ b/scripts/minicpmv.py @@ -0,0 +1,482 @@ +import argparse +import json +import os +from ctypes import POINTER, c_float, c_int, c_uint +from pathlib import Path + +import torch +from PIL import Image +from safetensors.torch import safe_open + +from libinfinicore_infer import ( + DataType, + DeviceType, + KVCacheCStruct, + MiniCPMVLanguageMetaCStruct, + MiniCPMVMetaCStruct, + MiniCPMVModel, + MiniCPMVResamplerMetaCStruct, + MiniCPMVVisionMetaCStruct, + MiniCPMVWeightsCStruct, + MiniCPMVSiglipLayerWeightsCStruct, +) + + +def _load_tensor(model_dir: Path, weight_map: dict, key: str) -> torch.Tensor: + if key not in weight_map: + if key.endswith(".weight") and key[: -len(".weight")] in weight_map: + key = key[: -len(".weight")] + elif (key + ".weight") in weight_map: + key = key + ".weight" + filename = weight_map[key] + full = model_dir / filename + with safe_open(str(full), framework="pt", device="cpu") as f: + return f.get_tensor(key) + + +def _make_siglip_layer_struct( + model_dir: Path, weight_map: dict, layer_idx: int, torch_dt +) -> tuple[MiniCPMVSiglipLayerWeightsCStruct, dict]: + keepalive: dict[str, torch.Tensor] = {} + + def to_dt(x: torch.Tensor) -> torch.Tensor: + return x.detach().to(dtype=torch_dt).contiguous() + + def t_weight(key: str) -> torch.Tensor: + w = _load_tensor(model_dir, weight_map, key) + return to_dt(w.transpose(0, 1)) + + lw = MiniCPMVSiglipLayerWeightsCStruct() + keepalive["ln1_w"] = to_dt( + _load_tensor( + model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm1.weight" + ) + ) + keepalive["ln1_b"] = to_dt( + _load_tensor( + model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm1.bias" + ) + ) + keepalive["ln2_w"] = to_dt( + _load_tensor( + model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm2.weight" + ) + ) + keepalive["ln2_b"] = to_dt( + _load_tensor( + model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm2.bias" + ) + ) + lw.layer_norm1_weight = keepalive["ln1_w"].data_ptr() + lw.layer_norm1_bias = keepalive["ln1_b"].data_ptr() + lw.layer_norm2_weight = keepalive["ln2_w"].data_ptr() + lw.layer_norm2_bias = keepalive["ln2_b"].data_ptr() + + keepalive["q_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.q_proj.weight") + 
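`t_weight` hands the C++ side each linear weight transposed to `[in_features, out_features]`, the opposite of PyTorch's `nn.Linear` layout. That layout choice is inferred from the transpose here, not from engine documentation; the quick check below shows the two layouts compute the same projection:

```python
import torch

x = torch.randn(4, 8)
W = torch.randn(16, 8)                       # PyTorch nn.Linear layout: [out, in]
y_ref = torch.nn.functional.linear(x, W)     # x @ W.T
y_alt = x @ W.transpose(0, 1)                # layout exported by t_weight(): [in, out]
assert torch.allclose(y_ref, y_alt)
```

The remaining attention and MLP projections below are exported the same way, while layer-norm vectors and biases are passed through untransposed.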
keepalive["k_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.k_proj.weight") + keepalive["v_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.v_proj.weight") + keepalive["o_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.out_proj.weight") + keepalive["q_b"] = to_dt( + _load_tensor( + model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.q_proj.bias" + ) + ) + keepalive["k_b"] = to_dt( + _load_tensor( + model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.k_proj.bias" + ) + ) + keepalive["v_b"] = to_dt( + _load_tensor( + model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.v_proj.bias" + ) + ) + keepalive["o_b"] = to_dt( + _load_tensor( + model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.out_proj.bias" + ) + ) + lw.q_weight = keepalive["q_w_t"].data_ptr() + lw.q_bias = keepalive["q_b"].data_ptr() + lw.k_weight = keepalive["k_w_t"].data_ptr() + lw.k_bias = keepalive["k_b"].data_ptr() + lw.v_weight = keepalive["v_w_t"].data_ptr() + lw.v_bias = keepalive["v_b"].data_ptr() + lw.out_weight = keepalive["o_w_t"].data_ptr() + lw.out_bias = keepalive["o_b"].data_ptr() + + keepalive["fc1_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.mlp.fc1.weight") + keepalive["fc2_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.mlp.fc2.weight") + keepalive["fc1_b"] = to_dt( + _load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.mlp.fc1.bias") + ) + keepalive["fc2_b"] = to_dt( + _load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.mlp.fc2.bias") + ) + lw.fc1_weight = keepalive["fc1_w_t"].data_ptr() + lw.fc1_bias = keepalive["fc1_b"].data_ptr() + lw.fc2_weight = keepalive["fc2_w_t"].data_ptr() + lw.fc2_bias = keepalive["fc2_b"].data_ptr() + + return lw, keepalive + + +def _build_vision_model(model_dir: Path, torch_dt_logits, dt_logits: DataType, device: DeviceType): + config = json.loads((model_dir / "config.json").read_text()) + index = json.loads((model_dir / "model.safetensors.index.json").read_text()) + weight_map = index["weight_map"] + + vision_cfg = config["vision_config"] + patch = int(vision_cfg["patch_size"]) + d_v = int(vision_cfg["hidden_size"]) + nh_v = int(vision_cfg["num_attention_heads"]) + di_v = int(vision_cfg["intermediate_size"]) + nlayer = int(vision_cfg["num_hidden_layers"]) + + language_meta = MiniCPMVLanguageMetaCStruct( + dt_logits=dt_logits, + nlayer=int(config["num_hidden_layers"]), + d=int(config["hidden_size"]), + nh=int(config["num_attention_heads"]), + nkvh=int(config["num_key_value_heads"]), + dh=int(config["hidden_size"] // config["num_attention_heads"]), + di=int(config["intermediate_size"]), + dctx=int(config["max_position_embeddings"]), + dvoc=int(config["vocab_size"]), + epsilon=float(config["rms_norm_eps"]), + theta=float(config["rope_theta"]), + end_token=int(config["eos_token_id"]), + ) + vision_meta = MiniCPMVVisionMetaCStruct( + patch_size=patch, + vision_embed_dim=d_v, + vision_num_layers=nlayer, + vision_num_heads=nh_v, + vision_intermediate_size=di_v, + vision_layer_norm_eps=1e-6, + vision_image_size=int(vision_cfg["image_size"]), + vision_num_positions=4900, + ) + resampler_meta = MiniCPMVResamplerMetaCStruct( + num_queries=int(config["query_num"]), + embed_dim=int(config["hidden_size"]), + num_heads=int(config["num_attention_heads"]), + kv_dim=d_v, + layer_norm_eps=1e-6, + max_patches_h=70, + max_patches_w=70, + ) + meta = MiniCPMVMetaCStruct( + vision_meta=vision_meta, resampler_meta=resampler_meta, language_meta=language_meta + ) + + keepalive: 
dict[str, object] = {} + + keepalive["patch_w"] = _load_tensor(model_dir, weight_map, "vpm.embeddings.patch_embedding.weight").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["patch_b"] = _load_tensor(model_dir, weight_map, "vpm.embeddings.patch_embedding.bias").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["pos_emb"] = _load_tensor(model_dir, weight_map, "vpm.embeddings.position_embedding.weight").detach().to(dtype=torch_dt_logits).contiguous() + + layers = [] + for i in range(nlayer): + lw, ka = _make_siglip_layer_struct(model_dir, weight_map, i, torch_dt_logits) + layers.append(lw) + for k, v in ka.items(): + keepalive[f"l{i}_{k}"] = v + layers_arr = (MiniCPMVSiglipLayerWeightsCStruct * nlayer)(*layers) + + keepalive["post_ln_w"] = _load_tensor(model_dir, weight_map, "vpm.post_layernorm.weight").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["post_ln_b"] = _load_tensor(model_dir, weight_map, "vpm.post_layernorm.bias").detach().to(dtype=torch_dt_logits).contiguous() + + def t(key: str) -> torch.Tensor: + return _load_tensor(model_dir, weight_map, key).detach().to(dtype=torch_dt_logits).transpose(0, 1).contiguous() + + keepalive["res_kv_proj_w_t"] = t("resampler.kv_proj.weight") + keepalive["res_in_w_t"] = t("resampler.attn.in_proj_weight") + keepalive["res_out_w_t"] = t("resampler.attn.out_proj.weight") + keepalive["res_in_b"] = _load_tensor(model_dir, weight_map, "resampler.attn.in_proj_bias").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["res_out_b"] = _load_tensor(model_dir, weight_map, "resampler.attn.out_proj.bias").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["res_query"] = _load_tensor(model_dir, weight_map, "resampler.query").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["res_proj"] = _load_tensor(model_dir, weight_map, "resampler.proj").detach().to(dtype=torch_dt_logits).contiguous() + for name in ["ln_q", "ln_kv", "ln_post"]: + keepalive[f"{name}_w"] = _load_tensor(model_dir, weight_map, f"resampler.{name}.weight").detach().to(dtype=torch_dt_logits).contiguous() + keepalive[f"{name}_b"] = _load_tensor(model_dir, weight_map, f"resampler.{name}.bias").detach().to(dtype=torch_dt_logits).contiguous() + + weights = MiniCPMVWeightsCStruct() + weights.vpm_patch_embedding_weight = keepalive["patch_w"].data_ptr() + weights.vpm_patch_embedding_bias = keepalive["patch_b"].data_ptr() + weights.vpm_position_embedding = keepalive["pos_emb"].data_ptr() + weights.vpm_layers = layers_arr + weights.vpm_post_layernorm_weight = keepalive["post_ln_w"].data_ptr() + weights.vpm_post_layernorm_bias = keepalive["post_ln_b"].data_ptr() + + weights.resampler_query = keepalive["res_query"].data_ptr() + weights.resampler_kv_proj_weight = keepalive["res_kv_proj_w_t"].data_ptr() + weights.resampler_attn_in_proj_weight = keepalive["res_in_w_t"].data_ptr() + weights.resampler_attn_in_proj_bias = keepalive["res_in_b"].data_ptr() + weights.resampler_attn_out_proj_weight = keepalive["res_out_w_t"].data_ptr() + weights.resampler_attn_out_proj_bias = keepalive["res_out_b"].data_ptr() + weights.resampler_ln_q_weight = keepalive["ln_q_w"].data_ptr() + weights.resampler_ln_q_bias = keepalive["ln_q_b"].data_ptr() + weights.resampler_ln_kv_weight = keepalive["ln_kv_w"].data_ptr() + weights.resampler_ln_kv_bias = keepalive["ln_kv_b"].data_ptr() + weights.resampler_ln_post_weight = keepalive["ln_post_w"].data_ptr() + weights.resampler_ln_post_bias = keepalive["ln_post_b"].data_ptr() + weights.resampler_proj = keepalive["res_proj"].data_ptr() 
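Taken together, the resampler pointers wired above describe a Perceiver-style resampler: `num_queries` learned query vectors cross-attend over the SigLIP features (after `kv_proj` lifts them to the LLM width), are normalized, and are projected, so every image slice becomes a fixed-length block of LLM-sized embeddings. A shape-level schematic with placeholder dimensions (the real computation runs in C++ via `infer_vision_resampler`; nothing here is read from the checkpoint):

```python
import torch

num_queries, embed_dim, kv_dim, seq_len, num_heads = 64, 4096, 1152, 1024, 32
query = torch.randn(num_queries, embed_dim)            # resampler.query (learned)
feats = torch.randn(seq_len, kv_dim)                   # SigLIP output for one slice
kv = feats @ torch.randn(kv_dim, embed_dim)            # resampler.kv_proj, [in, out] layout
attn = torch.nn.MultiheadAttention(embed_dim, num_heads)
out, _ = attn(query.unsqueeze(1), kv.unsqueeze(1), kv.unsqueeze(1))  # (L, N=1, E) inputs
out = out.squeeze(1) @ torch.randn(embed_dim, embed_dim)             # resampler.proj
print(out.shape)   # torch.Size([64, 4096]): one fixed-length block per slice
```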
+ + # Language weights unused + weights.nlayer = 0 + weights.dt_norm = dt_logits + weights.dt_mat = dt_logits + weights.transpose_linear_weights = 0 + weights.input_embd = 0 + weights.output_norm = 0 + weights.output_embd = 0 + weights.attn_norm = None + weights.attn_qkv = None + weights.attn_qkv_b = None + weights.attn_q_norm = None + weights.attn_k_norm = None + weights.attn_o = None + weights.ffn_norm = None + weights.ffn_gate_up = None + weights.ffn_down = None + + # Keep ctypes objects alive (C++ holds pointers to them). + keepalive["layers_arr"] = layers_arr + keepalive["weights_struct"] = weights + + m = MiniCPMVModel() + dev_ids = (c_int * 1)(0) + handle = m.create_model(meta, weights, device, 1, dev_ids) + return m, handle, meta, keepalive + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--model-dir", required=True) + ap.add_argument("--image", action="append", required=True, help="Repeatable image path") + ap.add_argument("--question", default="请描述图片内容。") + ap.add_argument("--max-steps", type=int, default=64) + ap.add_argument("--max-tokens", type=int, default=4096) + ap.add_argument("--temperature", type=float, default=1.0) + ap.add_argument("--topk", type=int, default=1) + ap.add_argument("--topp", type=float, default=1.0) + ap.add_argument("--max-slice-nums", type=int, default=None) + ap.add_argument("--vision-f32", action="store_true", help="Compute vision in FP32 then cast to LLM dtype") + ap.add_argument("--hygon", action="store_true", help="Run on Hygon device") + ap.add_argument("--debug", action="store_true") + args = ap.parse_args() + + debug = args.debug + + model_dir = Path(args.model_dir) + + # LLM loader (Jiuge) + from jiuge import JiugeForCauslLM + + device = DeviceType.DEVICE_TYPE_HYGON if args.hygon else DeviceType.DEVICE_TYPE_CPU + + dtype_override = torch.float16 if args.hygon else None + llm = JiugeForCauslLM( + str(model_dir), + device=device, + ndev=1, + max_tokens=args.max_tokens, + dtype_override=dtype_override, + ) + + # HF processor + preproc_cfg = json.loads((model_dir / "preprocessor_config.json").read_text()) + from minicpmv_config.image_processing_minicpmv import MiniCPMVImageProcessor + from minicpmv_config.processing_minicpmv import MiniCPMVProcessor + + image_processor = MiniCPMVImageProcessor(**preproc_cfg) + processor = MiniCPMVProcessor(image_processor=image_processor, tokenizer=llm.tokenizer) + + # Build user content with one tag per image (pattern required by processor: `(./)`) + tags = "\n".join(["./" for _ in args.image]) + user_content = f"{tags}\n{args.question}" + prompt = llm.tokenizer.apply_chat_template( + conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": user_content}, + ], + add_generation_prompt=True, + tokenize=False, + ) + + images = [Image.open(p).convert("RGB") for p in args.image] + batch = processor( + text=prompt, + images=images, + max_slice_nums=args.max_slice_nums, + return_tensors="pt", + ) + + input_ids = batch["input_ids"][0].to(dtype=torch.int64) + attn = batch["attention_mask"][0].to(dtype=torch.bool) + pad_left = int((~attn).sum().item()) + tokens = input_ids[pad_left:].to(dtype=torch.int32) + + bounds_all = (batch["image_bound"][0].to(dtype=torch.int64) - pad_left) + pixel_values_slices = batch["pixel_values"][0] + tgt_sizes = batch["tgt_sizes"][0] + + feature_len = int(preproc_cfg.get("image_feature_size", 64)) + bounds = torch.stack( + [b for b in bounds_all if int((b[1] - b[0]).item()) == feature_len], dim=0 + ) + + if debug: + 
print("pad_left:", pad_left, "tokens_len:", int(tokens.numel())) + print("image_bound_all:", bounds_all.tolist()) + print("image_bound_kept:", bounds.tolist()) + print("num_slices:", len(pixel_values_slices)) + + if len(pixel_values_slices) != bounds.shape[0]: + n = min(len(pixel_values_slices), int(bounds.shape[0])) + bounds = bounds[:n] + pixel_values_slices = pixel_values_slices[:n] + tgt_sizes = tgt_sizes[:n] + if debug: + print("WARNING: truncated to", n, "slices to match bounds") + + if len(pixel_values_slices) == 0: + raise SystemExit("No image slices to run vision.") + + # Vision dtype: optionally compute in f32 for stability. + llm_torch_dt = llm.meta.torch_dtype_logits + llm_dt = llm.meta.dt_logits + vision_f32 = bool(args.vision_f32) and not args.hygon + vision_torch_dt = torch.float32 if vision_f32 else llm_torch_dt + vision_dt = DataType.INFINI_DTYPE_F32 if vision_f32 else llm_dt + + vision_model, vision_handle, vision_meta, vision_keepalive = _build_vision_model( + model_dir, vision_torch_dt, vision_dt, device + ) + + # Compute per-slice vision embeddings + patch = int(preproc_cfg.get("patch_size", 14)) + slice_embeds = [] + for i, x in enumerate(pixel_values_slices): + th, tw = int(tgt_sizes[i][0].item()), int(tgt_sizes[i][1].item()) + seq_len = th * tw + x = x.to(dtype=vision_torch_dt).contiguous() + packed = x.unsqueeze(0).contiguous() + if packed.shape != (1, 3, patch, seq_len * patch): + raise SystemExit(f"bad packed shape: {tuple(packed.shape)} for slice {i}") + + out = torch.empty( + (vision_meta.resampler_meta.num_queries, vision_meta.resampler_meta.embed_dim), + dtype=vision_torch_dt, + ) + vision_model.infer_vision_resampler( + vision_handle, packed.data_ptr(), seq_len, th, tw, out.data_ptr() + ) + if torch.isnan(out).any(): + raise SystemExit(f"vision output contains NaNs (slice {i})") + if out.dtype != llm_torch_dt: + out = out.to(dtype=llm_torch_dt) + slice_embeds.append(out.contiguous()) + + # Build overrides (positions + embeddings) + override_pos_list: list[int] = [] + override_embed_list: list[torch.Tensor] = [] + for i in range(bounds.shape[0]): + s = int(bounds[i][0].item()) + e = int(bounds[i][1].item()) + if e - s != int(vision_meta.resampler_meta.num_queries): + raise SystemExit(f"unexpected bound length: {e-s}") + override_pos_list.extend(list(range(s, e))) + override_embed_list.append(slice_embeds[i]) + override_embeds = torch.cat(override_embed_list, dim=0).contiguous() + override_pos = (c_uint * len(override_pos_list))(*override_pos_list) + + # Sanity: override positions should be tokens. 
+ unk_id = getattr(llm.tokenizer, "unk_token_id", None) + if unk_id is not None: + override_tok = tokens[torch.tensor(override_pos_list, dtype=torch.long)] + uniq = torch.unique(override_tok).tolist() + if len(uniq) != 1 or int(uniq[0]) != int(unk_id): + if debug: + print("WARNING: override positions are not all tokens.") + print(" unk_id:", int(unk_id)) + print(" override_token_ids_unique:", [int(x) for x in uniq[:16]]) + + # Prefill + decode + ntok = int(tokens.numel()) + tokens_c = (c_uint * ntok)(*tokens.tolist()) + req_lens = (c_uint * 1)(ntok) + req_pos = (c_uint * 1)(0) + dev_ids = (c_int * 1)(0) + + kv = llm.jiuge_model.create_kv_cache( + llm.meta.nlayer, + llm.meta.dctx, + llm.meta.nkvh, + llm.meta.dh, + llm.meta.dh, + llm.meta.dt_logits, + device, + dev_ids, + 1, + ) + kv_caches = (POINTER(KVCacheCStruct) * 1)(kv) + + temperature = (c_float * 1)(float(args.temperature)) + topk = (c_uint * 1)(int(args.topk)) + topp = (c_float * 1)(float(args.topp)) + out = (c_uint * 1)() + + llm.jiuge_model.infer_batch_with_overrides( + llm.model_instance, + tokens_c, + ntok, + req_lens, + 1, + req_pos, + kv_caches, + len(override_pos_list), + override_pos, + override_embeds.data_ptr(), + temperature, + topk, + topp, + out, + ) + + generated = [int(out[0])] + cur_pos = ntok + eos_ids = set(llm.eos_token_id) + for _ in range(int(args.max_steps) - 1): + if generated[-1] in eos_ids: + break + req_lens = (c_uint * 1)(1) + req_pos = (c_uint * 1)(cur_pos) + tokens_c = (c_uint * 1)(generated[-1]) + llm.jiuge_model.infer_batch( + llm.model_instance, + tokens_c, + 1, + req_lens, + 1, + req_pos, + kv_caches, + temperature, + topk, + topp, + out, + ) + generated.append(int(out[0])) + cur_pos += 1 + + text = llm.tokenizer.decode(generated, skip_special_tokens=False) + print(text) + + llm.jiuge_model.drop_kv_cache(kv) + vision_model.destroy_model(vision_handle) + llm.jiuge_model.destroy_model(llm.model_instance) + + +if __name__ == "__main__": + main() diff --git a/scripts/minicpmv_chat.py b/scripts/minicpmv_chat.py new file mode 100644 index 00000000..b6cbf4ef --- /dev/null +++ b/scripts/minicpmv_chat.py @@ -0,0 +1,722 @@ +import argparse +import json +import os +import time +from ctypes import POINTER, c_float, c_int, c_uint +from pathlib import Path + +import torch +from PIL import Image +from safetensors.torch import safe_open + +from libinfinicore_infer import ( + DataType, + DeviceType, + JiugeModel, + KVCacheCStruct, + KVCompressionConfigCStruct, + MiniCPMVLanguageMetaCStruct, + MiniCPMVMetaCStruct, + MiniCPMVModel, + MiniCPMVResamplerMetaCStruct, + MiniCPMVVisionMetaCStruct, + MiniCPMVWeightsCStruct, + MiniCPMVSiglipLayerWeightsCStruct, +) + + +def _dtype_from_dt_logits(dt_logits: DataType): + if dt_logits == DataType.INFINI_DTYPE_F32: + return torch.float32 + if dt_logits == DataType.INFINI_DTYPE_BF16: + return torch.bfloat16 + if dt_logits == DataType.INFINI_DTYPE_F16: + return torch.float16 + raise ValueError(f"Unsupported dt_logits: {dt_logits}") + + +def _load_tensor(model_dir: Path, weight_map: dict, key: str) -> torch.Tensor: + if key not in weight_map: + if key.endswith(".weight") and key[: -len(".weight")] in weight_map: + key = key[: -len(".weight")] + elif (key + ".weight") in weight_map: + key = key + ".weight" + filename = weight_map[key] + full = model_dir / filename + with safe_open(str(full), framework="pt", device="cpu") as f: + return f.get_tensor(key) + + +def _make_siglip_layer_struct(model_dir: Path, weight_map: dict, layer_idx: int, torch_dt) -> tuple: + keepalive = {} + + def 
to_dt(x: torch.Tensor) -> torch.Tensor: + return x.detach().to(dtype=torch_dt).contiguous() + + def t_weight(key: str) -> torch.Tensor: + w = _load_tensor(model_dir, weight_map, key) + return to_dt(w.transpose(0, 1)) + + lw = MiniCPMVSiglipLayerWeightsCStruct() + keepalive["ln1_w"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm1.weight")) + keepalive["ln1_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm1.bias")) + keepalive["ln2_w"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm2.weight")) + keepalive["ln2_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm2.bias")) + lw.layer_norm1_weight = keepalive["ln1_w"].data_ptr() + lw.layer_norm1_bias = keepalive["ln1_b"].data_ptr() + lw.layer_norm2_weight = keepalive["ln2_w"].data_ptr() + lw.layer_norm2_bias = keepalive["ln2_b"].data_ptr() + + keepalive["q_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.q_proj.weight") + keepalive["k_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.k_proj.weight") + keepalive["v_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.v_proj.weight") + keepalive["o_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.out_proj.weight") + keepalive["q_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.q_proj.bias")) + keepalive["k_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.k_proj.bias")) + keepalive["v_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.v_proj.bias")) + keepalive["o_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.out_proj.bias")) + lw.q_weight = keepalive["q_w_t"].data_ptr() + lw.q_bias = keepalive["q_b"].data_ptr() + lw.k_weight = keepalive["k_w_t"].data_ptr() + lw.k_bias = keepalive["k_b"].data_ptr() + lw.v_weight = keepalive["v_w_t"].data_ptr() + lw.v_bias = keepalive["v_b"].data_ptr() + lw.out_weight = keepalive["o_w_t"].data_ptr() + lw.out_bias = keepalive["o_b"].data_ptr() + + keepalive["fc1_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.mlp.fc1.weight") + keepalive["fc2_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.mlp.fc2.weight") + keepalive["fc1_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.mlp.fc1.bias")) + keepalive["fc2_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.mlp.fc2.bias")) + lw.fc1_weight = keepalive["fc1_w_t"].data_ptr() + lw.fc1_bias = keepalive["fc1_b"].data_ptr() + lw.fc2_weight = keepalive["fc2_w_t"].data_ptr() + lw.fc2_bias = keepalive["fc2_b"].data_ptr() + + return lw, keepalive + + +def _build_minicpmv_vision_model(model_dir: Path, torch_dt_logits, dt_logits: DataType, device: DeviceType): + config = json.loads((model_dir / "config.json").read_text()) + index = json.loads((model_dir / "model.safetensors.index.json").read_text()) + weight_map = index["weight_map"] + + vision_cfg = config["vision_config"] + patch = int(vision_cfg["patch_size"]) + d_v = int(vision_cfg["hidden_size"]) + nh_v = int(vision_cfg["num_attention_heads"]) + di_v = int(vision_cfg["intermediate_size"]) + nlayer = int(vision_cfg["num_hidden_layers"]) + + language_meta = MiniCPMVLanguageMetaCStruct( + dt_logits=dt_logits, + nlayer=int(config["num_hidden_layers"]), + d=int(config["hidden_size"]), + nh=int(config["num_attention_heads"]), + 
nkvh=int(config["num_key_value_heads"]), + dh=int(config["hidden_size"] // config["num_attention_heads"]), + di=int(config["intermediate_size"]), + dctx=int(config["max_position_embeddings"]), + dvoc=int(config["vocab_size"]), + epsilon=float(config["rms_norm_eps"]), + theta=float(config["rope_theta"]), + end_token=int(config["eos_token_id"]), + ) + vision_meta = MiniCPMVVisionMetaCStruct( + patch_size=patch, + vision_embed_dim=d_v, + vision_num_layers=nlayer, + vision_num_heads=nh_v, + vision_intermediate_size=di_v, + vision_layer_norm_eps=1e-6, + vision_image_size=int(vision_cfg["image_size"]), + vision_num_positions=4900, + ) + resampler_meta = MiniCPMVResamplerMetaCStruct( + num_queries=int(config["query_num"]), + embed_dim=int(config["hidden_size"]), + num_heads=int(config["num_attention_heads"]), + kv_dim=d_v, + layer_norm_eps=1e-6, + max_patches_h=70, + max_patches_w=70, + ) + meta = MiniCPMVMetaCStruct( + vision_meta=vision_meta, resampler_meta=resampler_meta, language_meta=language_meta + ) + + keepalive = {} + + # Vision weights + keepalive["patch_w"] = _load_tensor(model_dir, weight_map, "vpm.embeddings.patch_embedding.weight").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["patch_b"] = _load_tensor(model_dir, weight_map, "vpm.embeddings.patch_embedding.bias").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["pos_emb"] = _load_tensor(model_dir, weight_map, "vpm.embeddings.position_embedding.weight").detach().to(dtype=torch_dt_logits).contiguous() + + layers = [] + for i in range(nlayer): + lw, ka = _make_siglip_layer_struct(model_dir, weight_map, i, torch_dt_logits) + layers.append(lw) + for k, v in ka.items(): + keepalive[f"l{i}_{k}"] = v + layers_arr = (MiniCPMVSiglipLayerWeightsCStruct * nlayer)(*layers) + + keepalive["post_ln_w"] = _load_tensor(model_dir, weight_map, "vpm.post_layernorm.weight").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["post_ln_b"] = _load_tensor(model_dir, weight_map, "vpm.post_layernorm.bias").detach().to(dtype=torch_dt_logits).contiguous() + + # Resampler weights (linear weights must be transposed to [in, out]) + def t(key: str) -> torch.Tensor: + return _load_tensor(model_dir, weight_map, key).detach().to(dtype=torch_dt_logits).transpose(0, 1).contiguous() + + keepalive["res_kv_proj_w_t"] = t("resampler.kv_proj.weight") + keepalive["res_in_w_t"] = t("resampler.attn.in_proj_weight") + keepalive["res_out_w_t"] = t("resampler.attn.out_proj.weight") + keepalive["res_in_b"] = _load_tensor(model_dir, weight_map, "resampler.attn.in_proj_bias").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["res_out_b"] = _load_tensor(model_dir, weight_map, "resampler.attn.out_proj.bias").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["res_query"] = _load_tensor(model_dir, weight_map, "resampler.query").detach().to(dtype=torch_dt_logits).contiguous() + keepalive["res_proj"] = _load_tensor(model_dir, weight_map, "resampler.proj").detach().to(dtype=torch_dt_logits).contiguous() + for name in ["ln_q", "ln_kv", "ln_post"]: + keepalive[f"{name}_w"] = _load_tensor(model_dir, weight_map, f"resampler.{name}.weight").detach().to(dtype=torch_dt_logits).contiguous() + keepalive[f"{name}_b"] = _load_tensor(model_dir, weight_map, f"resampler.{name}.bias").detach().to(dtype=torch_dt_logits).contiguous() + + weights = MiniCPMVWeightsCStruct() + weights.vpm_patch_embedding_weight = keepalive["patch_w"].data_ptr() + weights.vpm_patch_embedding_bias = keepalive["patch_b"].data_ptr() + weights.vpm_position_embedding = 
keepalive["pos_emb"].data_ptr() + weights.vpm_layers = layers_arr + weights.vpm_post_layernorm_weight = keepalive["post_ln_w"].data_ptr() + weights.vpm_post_layernorm_bias = keepalive["post_ln_b"].data_ptr() + + weights.resampler_query = keepalive["res_query"].data_ptr() + weights.resampler_kv_proj_weight = keepalive["res_kv_proj_w_t"].data_ptr() + weights.resampler_attn_in_proj_weight = keepalive["res_in_w_t"].data_ptr() + weights.resampler_attn_in_proj_bias = keepalive["res_in_b"].data_ptr() + weights.resampler_attn_out_proj_weight = keepalive["res_out_w_t"].data_ptr() + weights.resampler_attn_out_proj_bias = keepalive["res_out_b"].data_ptr() + weights.resampler_ln_q_weight = keepalive["ln_q_w"].data_ptr() + weights.resampler_ln_q_bias = keepalive["ln_q_b"].data_ptr() + weights.resampler_ln_kv_weight = keepalive["ln_kv_w"].data_ptr() + weights.resampler_ln_kv_bias = keepalive["ln_kv_b"].data_ptr() + weights.resampler_ln_post_weight = keepalive["ln_post_w"].data_ptr() + weights.resampler_ln_post_bias = keepalive["ln_post_b"].data_ptr() + weights.resampler_proj = keepalive["res_proj"].data_ptr() + + # Language weights unused here + weights.nlayer = 0 + weights.dt_norm = dt_logits + weights.dt_mat = dt_logits + weights.transpose_linear_weights = 0 + weights.input_embd = 0 + weights.output_norm = 0 + weights.output_embd = 0 + weights.attn_norm = None + weights.attn_qkv = None + weights.attn_qkv_b = None + weights.attn_q_norm = None + weights.attn_k_norm = None + weights.attn_o = None + weights.ffn_norm = None + weights.ffn_gate_up = None + weights.ffn_down = None + + # Keep ctypes objects alive: MiniCPMVModel stores pointers to `weights` and `vpm_layers`. + keepalive["layers_arr"] = layers_arr + keepalive["weights_struct"] = weights + + m = MiniCPMVModel() + dev_ids = (c_int * 1)(0) + handle = m.create_model(meta, weights, device, 1, dev_ids) + return m, handle, meta, keepalive + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument( + "--dev", + choices=["cpu", "nvidia", "hygon", "moore"], + default="cpu", + help="Device backend for inference.", + ) + ap.add_argument("--model-dir", required=True) + ap.add_argument("--image", required=True) + ap.add_argument("--question", default="图片是什么?") + ap.add_argument("--max-steps", type=int, default=128) + ap.add_argument("--max-tokens", type=int, default=2048) + ap.add_argument("--debug", action="store_true") + ap.add_argument("--kv-compress", action="store_true", help="Enable in-place KV cache compression after prefill.") + ap.add_argument("--kv-compress-bin", default="", help="Path to compressor .bin weights.") + ap.add_argument("--kv-compress-factor", type=int, default=5) + ap.add_argument("--kv-compress-min-seq-len", type=int, default=2) + ap.add_argument("--kv-compress-image-len", type=int, default=0, help="Prefix tokens treated as image KV (0 for Hybrid text-only).") + ap.add_argument("--perplexity", action="store_true", help="Collect logits for perplexity calculation") + ap.add_argument("--time", action="store_true", help="Print timing metrics (time per step, etc.)") + args = ap.parse_args() + debug = args.debug or os.environ.get("MINICPMV_DEBUG", "0") == "1" + + model_dir = Path(args.model_dir) + + # LLM (Jiuge) loader + from jiuge import JiugeForCauslLM + + device_map = { + "cpu": DeviceType.DEVICE_TYPE_CPU, + "nvidia": DeviceType.DEVICE_TYPE_NVIDIA, + "hygon": DeviceType.DEVICE_TYPE_HYGON, + "moore": DeviceType.DEVICE_TYPE_MOORE, + } + device = device_map[args.dev] + dtype_override = ( + torch.float16 + if device + in { + 
DeviceType.DEVICE_TYPE_HYGON, + DeviceType.DEVICE_TYPE_MOORE, + DeviceType.DEVICE_TYPE_NVIDIA, + } + else None + ) + + llm = JiugeForCauslLM( + str(model_dir), + device=device, + ndev=1, + max_tokens=args.max_tokens, + dtype_override=dtype_override, + ) + + + # Build processor using the same tokenizer + preproc_cfg = json.loads((model_dir / "preprocessor_config.json").read_text()) + from image_processing_minicpmv import MiniCPMVImageProcessor + from processing_minicpmv import MiniCPMVProcessor + + image_processor = MiniCPMVImageProcessor(**preproc_cfg) + processor = MiniCPMVProcessor(image_processor=image_processor, tokenizer=llm.tokenizer) + + # The vendored HF processor searches for the literal pattern `(./)`, + # so we must include exactly one char + '/' inside the image tag. + user_content = f"./\n{args.question}" + prompt = llm.tokenizer.apply_chat_template( + conversation=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": user_content}, + ], + add_generation_prompt=True, + tokenize=False, + ) + + img = Image.open(args.image).convert("RGB") + batch = processor(text=prompt, images=[img], return_tensors="pt") + + input_ids = batch["input_ids"][0].to(dtype=torch.int64) + attn = batch["attention_mask"][0].to(dtype=torch.bool) + pad_left = int((~attn).sum().item()) + tokens = input_ids[pad_left:].to(dtype=torch.int32) + + bounds = batch["image_bound"][0].to(dtype=torch.int64) + bounds = bounds - pad_left + if bounds.shape[0] > 0: + if debug: + print("DEBUG pad_left:", pad_left) + print("DEBUG tokens_len:", int(tokens.numel())) + print("DEBUG bounds_all:", bounds.tolist()) + + pixel_values_slices = batch["pixel_values"][0] + tgt_sizes = batch["tgt_sizes"][0] + + + # `image_bound` may include non-vision spans (e.g., ...), which are not 64-token features. + feature_len = int(preproc_cfg.get("image_feature_size", 64)) + bounds_all = bounds + bounds = torch.stack([b for b in bounds_all if int((b[1] - b[0]).item()) == feature_len], dim=0) + + if bounds.shape[0] != bounds_all.shape[0]: + if debug: + print( + f"INFO: filtered image_bound: total={bounds_all.shape[0]} feature_len={feature_len} kept={bounds.shape[0]}" + ) + print(" image_bound_all (after left-pad adjust):", bounds_all.tolist()) + print(" image_bound_kept:", bounds.tolist()) + + if len(pixel_values_slices) != bounds.shape[0]: + if debug: + print(f"WARNING: slice count mismatch: slices={len(pixel_values_slices)} bounds={bounds.shape[0]}") + # Proceed by truncating to the common prefix (processor constructs placeholders in slice order). + n = min(len(pixel_values_slices), int(bounds.shape[0])) + bounds = bounds[:n] + pixel_values_slices = pixel_values_slices[:n] + tgt_sizes = tgt_sizes[:n] + + if len(pixel_values_slices) == 0: + raise SystemExit("No image slices to run vision.") + + + # Vision can be computed in f32 for numerical stability, then cast to LLM dtype for injection. 
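+    # When MINICPMV_VISION_FORCE_F32=1, the SigLIP encoder and resampler below run in
+    # fp32 for numerical stability; each per-slice output is cast back to the LLM dtype
+    # (fp16/bf16) before it is injected at the image placeholder positions.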
+ llm_torch_dt = llm.meta.torch_dtype_logits + llm_dt = llm.meta.dt_logits + vision_force_f32 = os.environ.get("MINICPMV_VISION_FORCE_F32", "0") == "1" + vision_torch_dt = torch.float32 if vision_force_f32 else llm_torch_dt + vision_dt = DataType.INFINI_DTYPE_F32 if vision_force_f32 else llm_dt + + vision_model, vision_handle, vision_meta, vision_keepalive = _build_minicpmv_vision_model( + model_dir, vision_torch_dt, vision_dt, device + ) + + # Compute per-slice vision embeddings [num_slices, 64, 3584] + slice_embeds = [] + patch = int(preproc_cfg.get("patch_size", 14)) + vision_infer_start_time = time.time() + for i, x in enumerate(pixel_values_slices): + th, tw = int(tgt_sizes[i][0].item()), int(tgt_sizes[i][1].item()) + seq_len = th * tw + x = x.to(dtype=vision_torch_dt).contiguous() + packed = x.unsqueeze(0).contiguous() + if packed.shape != (1, 3, patch, seq_len * patch): + raise SystemExit(f"bad packed shape: {tuple(packed.shape)} for slice {i}") + + out = torch.empty( + (vision_meta.resampler_meta.num_queries, vision_meta.resampler_meta.embed_dim), + dtype=vision_torch_dt, + ) + vision_model.infer_vision_resampler(vision_handle, packed.data_ptr(), seq_len, th, tw, out.data_ptr()) + if torch.isnan(out).any(): + nan_cnt = int(torch.isnan(out).sum().item()) + print(f"ERROR: vision out has NaN: slice={i} tgt_h={th} tgt_w={tw} nan_cnt={nan_cnt}") + print( + " vision_out_abs_max/mean:", + float(out.float().abs().max().item()), + float(out.float().abs().mean().item()), + ) + raise SystemExit("vision output contains NaNs") + if out.dtype != llm_torch_dt: + out = out.to(dtype=llm_torch_dt) + slice_embeds.append(out.contiguous()) + vision_infer_end_time = time.time() + vision_infer_time = float(vision_infer_end_time - vision_infer_start_time) + + # Flatten override positions and embeddings according to image_bound. + override_pos_list = [] + override_embed_list = [] + for i in range(bounds.shape[0]): + s = int(bounds[i][0].item()) + e = int(bounds[i][1].item()) + if e - s != vision_meta.resampler_meta.num_queries: + raise SystemExit(f"unexpected bound length: {e-s} (expected {vision_meta.resampler_meta.num_queries})") + override_pos_list.extend(list(range(s, e))) + override_embed_list.append(slice_embeds[i]) + override_embeds = torch.cat(override_embed_list, dim=0).contiguous() + if debug: + print( + "DEBUG override_embeds stats:", + float(override_embeds.float().abs().max().item()), + float(override_embeds.float().abs().mean().item()), + override_embeds.dtype, + ) + + # Sanity: all override positions should correspond to `` tokens. 
+ unk_id = getattr(llm.tokenizer, "unk_token_id", None) + if unk_id is not None: + override_tok = tokens[torch.tensor(override_pos_list, dtype=torch.long)] + uniq = torch.unique(override_tok).tolist() + if len(uniq) != 1 or int(uniq[0]) != int(unk_id): + if debug: + print("WARNING: override positions are not all tokens.") + print(" unk_id:", int(unk_id)) + print(" override_token_ids_unique:", [int(x) for x in uniq[:16]]) + + override_pos = (c_uint * len(override_pos_list))(*override_pos_list) + + # Prefill with overrides + ntok = int(tokens.numel()) + tokens_c = (c_uint * ntok)(*tokens.tolist()) + req_lens = (c_uint * 1)(ntok) + req_pos = (c_uint * 1)(0) + dev_ids = (c_int * 1)(0) + + kv = llm.jiuge_model.create_kv_cache( + llm.meta.nlayer, + llm.meta.dctx, + llm.meta.nkvh, + llm.meta.dh, + llm.meta.dh, + llm.meta.dt_logits, + device, + dev_ids, + 1, + ) + kv_caches = (POINTER(KVCacheCStruct) * 1)(kv) + + temperature = (c_float * 1)(1.0) + topk = (c_uint * 1)(1) + topp = (c_float * 1)(1.0) + + prefill_logits = None + all_logits = [] + + out = (c_uint * 1)() + + if args.perplexity: + prefill_logits = torch.zeros((ntok, llm.meta.dvoc), dtype=llm.meta.torch_dtype_logits) + print(f"准备收集 prefill logits: shape {prefill_logits.shape}") + + prefill_start_time = time.time() + # 使用 infer_batch_with_overrides_with_logits 传递 logits + llm.jiuge_model.infer_batch_with_overrides_with_logits( + llm.model_instance, + tokens_c, + ntok, + req_lens, + 1, + req_pos, + kv_caches, + len(override_pos_list), + override_pos, + override_embeds.data_ptr(), + temperature, + topk, + topp, + out, + prefill_logits.data_ptr(), # 传递 logits 指针 + ) + prefill_end_time = time.time() + prefill_time = float(prefill_end_time - prefill_start_time) + + # 保存 prefill logits + all_logits.append(prefill_logits.clone()) + print(f"Collected prefill logits: shape {prefill_logits.shape}") + else: + prefill_start_time = time.time() + llm.jiuge_model.infer_batch_with_overrides( + llm.model_instance, + tokens_c, + ntok, + req_lens, + 1, + req_pos, + kv_caches, + len(override_pos_list), + override_pos, + override_embeds.data_ptr(), + temperature, + topk, + topp, + out, + ) + prefill_end_time = time.time() + prefill_time = float(prefill_end_time - prefill_start_time) + if debug: + print("DEBUG prefill next_token:", int(out[0])) + + generated = [int(out[0])] + rope_pos = ntok + kv_pos = ntok + eos_ids = set(llm.eos_token_id) + + if args.kv_compress: + if not args.kv_compress_bin: + raise SystemExit("--kv-compress requires --kv-compress-bin") + cfg = KVCompressionConfigCStruct( + enable=1, + compression_factor=int(args.kv_compress_factor), + min_seq_len=int(args.kv_compress_min_seq_len), + image_kv_len=int(args.kv_compress_image_len), + weight_path=args.kv_compress_bin.encode("utf-8"), + ) + kv_pos = int(llm.jiuge_model.compress_kv_cache_inplace(kv, ntok, cfg)) + if debug: + print("DEBUG kv_compress:", {"rope_pos": int(rope_pos), "kv_pos": int(kv_pos)}) + + decode_start_time = time.time() + for _ in range(args.max_steps - 1): + if generated[-1] in eos_ids: + break + req_lens = (c_uint * 1)(1) + req_pos = (c_uint * 1)(rope_pos) + kv_pos_c = (c_uint * 1)(kv_pos) + tokens_c = (c_uint * 1)(generated[-1]) + # if args.kv_compress: + # llm.jiuge_model.infer_batch_ex( + # llm.model_instance, + # tokens_c, + # 1, + # req_lens, + # 1, + # req_pos, + # kv_pos_c, + # kv_caches, + # temperature, + # topk, + # topp, + # out, + # ) + + if args.perplexity: + # 收集 decode 阶段的 logits + decode_logits = torch.zeros((1, llm.meta.dvoc), 
dtype=llm.meta.torch_dtype_logits) + + if args.kv_compress: + # 使用 infer_batch_ex_with_logits 收集logits(KV压缩模式) + llm.jiuge_model.infer_batch_ex_with_logits( + llm.model_instance, + tokens_c, + 1, + req_lens, + 1, + req_pos, + kv_pos_c, + kv_caches, + temperature, + topk, + topp, + out, + decode_logits.data_ptr(), # 传递 logits 指针 + ) + else: + # 使用 infer_batch_with_logits 一次性完成推理和 logits 收集 + llm.jiuge_model.infer_batch_with_logits( + llm.model_instance, + tokens_c, + 1, + req_lens, + 1, + req_pos, + kv_caches, + temperature, + topk, + topp, + out, + decode_logits.data_ptr(), # 传递 logits 指针 + ) + + # 保存 decode logits(两种模式都保存) + all_logits.append(decode_logits.clone()) + # print(f"Collected decode logits step {_+1}: shape {decode_logits.shape}") + + else: + # 原有的推理方式 + if args.kv_compress: + llm.jiuge_model.infer_batch_ex( + llm.model_instance, + tokens_c, + 1, + req_lens, + 1, + req_pos, + kv_pos_c, + kv_caches, + temperature, + topk, + topp, + out, + ) + else: + llm.jiuge_model.infer_batch( + llm.model_instance, + tokens_c, + 1, + req_lens, + 1, + req_pos, + kv_caches, + temperature, + topk, + topp, + out, + ) + + generated.append(int(out[0])) + rope_pos += 1 + kv_pos += 1 + decode_end_time = time.time() + decode_time = float(decode_end_time - decode_start_time) + + if debug: + print("DEBUG generated_ids:", generated) + text = llm.tokenizer.decode(generated, skip_special_tokens=False) + print(text) + + + # 计算困惑度(如果启用了logits收集) + if args.perplexity and len(all_logits) > 0: + print("\n" + "="*60) + print("Computing perplexity...") + + import math + + total_nll = 0.0 + total_tokens = 0 + + # 处理 prefill logits + if len(all_logits) > 0 and len(all_logits[0].shape) == 2: + prefill_logits = all_logits[0] # [ntok, vocab_size] + # prefill阶段:只计算输入序列最后一个位置对第一个生成token的预测 + # 这与llava.py的实现一致:line 1596 + input_seq_len = prefill_logits.shape[0] + last_position_logits = prefill_logits[input_seq_len - 1] # 最后一个位置的logits + target_token_id = generated[0] # 预测第一个生成的token + + # 计算log概率 + log_probs = torch.nn.functional.log_softmax(last_position_logits, dim=-1) + token_log_prob = log_probs[target_token_id].item() + + total_nll += -token_log_prob + total_tokens += 1 + + print(f" Prefill pos {input_seq_len-1}: token={llm.tokenizer.decode([target_token_id])} log_prob={token_log_prob:.4f}") + + # 处理 decode logits + # decode阶段:第step_idx步的logits应该预测generated[step_idx+1] + # 这与llava.py的实现一致:line 1628 + decode_start_idx = 1 # 跳过 prefill logits + for step_idx, logits in enumerate(all_logits[decode_start_idx:]): + if len(logits.shape) == 2: + decode_logits = logits[0] # [vocab_size] + else: + decode_logits = logits # [vocab_size] + + # decode阶段:第step_idx步预测generated[step_idx+1] + # 因为generated[0]已经在prefill阶段被预测了 + if step_idx + 1 < len(generated): + target_token_id = generated[step_idx + 1] + + # 计算log概率 + log_probs = torch.nn.functional.log_softmax(decode_logits, dim=-1) + token_log_prob = log_probs[target_token_id].item() + + total_nll += -token_log_prob + total_tokens += 1 + + # 显示前3步的详细信息 + if step_idx < 3: + print(f" Decode step {step_idx+1}: token={llm.tokenizer.decode([target_token_id])} log_prob={token_log_prob:.4f}") + + if total_tokens > 0: + # 计算困惑度 + avg_nll = total_nll / total_tokens + perplexity = math.exp(avg_nll) + + print(f"\nTotal tokens: {total_tokens}") + print(f"Total NLL: {total_nll:.4f}") + print(f"Perplexity: {perplexity:.4f}") + else: + print("No tokens computed for perplexity") + + print("="*60) + + + llm.jiuge_model.drop_kv_cache(kv) + vision_model.destroy_model(vision_handle) + 
llm.jiuge_model.destroy_model(llm.model_instance) + if args.time: + steps = int(len(generated)) + llm_total_time = float(prefill_time + decode_time) + avg_time_per_step = ( + llm_total_time * 1000 / (steps - 1) if steps > 1 else llm_total_time * 1000 + ) + print(f"Vision time: {vision_infer_time * 1000:.3f}ms") + print(f"Prefill time: {prefill_time * 1000:.3f}ms") + print(f"Decode time: {decode_time * 1000:.3f}ms") + print(f"Time per step: {avg_time_per_step:.3f}ms") + + +if __name__ == "__main__": + main() diff --git a/scripts/minicpmv_vision_hf_smoke.py b/scripts/minicpmv_vision_hf_smoke.py new file mode 100644 index 00000000..b6cfe487 --- /dev/null +++ b/scripts/minicpmv_vision_hf_smoke.py @@ -0,0 +1,350 @@ +import argparse +import json +import os +from pathlib import Path + +import torch +from PIL import Image +from safetensors.torch import safe_open + +from libinfinicore_infer import ( + DataType, + DeviceType, + MiniCPMVLanguageMetaCStruct, + MiniCPMVMetaCStruct, + MiniCPMVModel, + MiniCPMVResamplerMetaCStruct, + MiniCPMVVisionMetaCStruct, + MiniCPMVWeightsCStruct, + MiniCPMVSiglipLayerWeightsCStruct, +) + + +def _dtype_from_config(torch_dtype: str): + if torch_dtype == "bfloat16": + return torch.bfloat16, DataType.INFINI_DTYPE_BF16 + if torch_dtype == "float16": + return torch.float16, DataType.INFINI_DTYPE_F16 + if torch_dtype == "float32": + return torch.float32, DataType.INFINI_DTYPE_F32 + return torch.bfloat16, DataType.INFINI_DTYPE_BF16 + + +def _load_tensor(model_dir: Path, weight_map: dict, key: str) -> torch.Tensor: + if key not in weight_map: + if key.endswith(".weight") and key[: -len(".weight")] in weight_map: + key = key[: -len(".weight")] + elif (key + ".weight") in weight_map: + key = key + ".weight" + filename = weight_map[key] + full = model_dir / filename + with safe_open(str(full), framework="pt", device="cpu") as f: + return f.get_tensor(key) + + +def _make_siglip_layer_struct(model_dir: Path, weight_map: dict, layer_idx: int, torch_dt) -> tuple: + keepalive = {} + + def to_dt(x: torch.Tensor) -> torch.Tensor: + return x.detach().to(dtype=torch_dt).contiguous() + + def t_weight(key: str) -> torch.Tensor: + w = _load_tensor(model_dir, weight_map, key) + return to_dt(w.transpose(0, 1)) + + lw = MiniCPMVSiglipLayerWeightsCStruct() + keepalive["ln1_w"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm1.weight")) + keepalive["ln1_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm1.bias")) + keepalive["ln2_w"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm2.weight")) + keepalive["ln2_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.layer_norm2.bias")) + lw.layer_norm1_weight = keepalive["ln1_w"].data_ptr() + lw.layer_norm1_bias = keepalive["ln1_b"].data_ptr() + lw.layer_norm2_weight = keepalive["ln2_w"].data_ptr() + lw.layer_norm2_bias = keepalive["ln2_b"].data_ptr() + + keepalive["q_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.q_proj.weight") + keepalive["k_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.k_proj.weight") + keepalive["v_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.v_proj.weight") + keepalive["o_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.self_attn.out_proj.weight") + keepalive["q_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.q_proj.bias")) + keepalive["k_b"] = to_dt(_load_tensor(model_dir, weight_map, 
f"vpm.encoder.layers.{layer_idx}.self_attn.k_proj.bias")) + keepalive["v_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.v_proj.bias")) + keepalive["o_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.self_attn.out_proj.bias")) + lw.q_weight = keepalive["q_w_t"].data_ptr() + lw.q_bias = keepalive["q_b"].data_ptr() + lw.k_weight = keepalive["k_w_t"].data_ptr() + lw.k_bias = keepalive["k_b"].data_ptr() + lw.v_weight = keepalive["v_w_t"].data_ptr() + lw.v_bias = keepalive["v_b"].data_ptr() + lw.out_weight = keepalive["o_w_t"].data_ptr() + lw.out_bias = keepalive["o_b"].data_ptr() + + keepalive["fc1_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.mlp.fc1.weight") + keepalive["fc2_w_t"] = t_weight(f"vpm.encoder.layers.{layer_idx}.mlp.fc2.weight") + keepalive["fc1_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.mlp.fc1.bias")) + keepalive["fc2_b"] = to_dt(_load_tensor(model_dir, weight_map, f"vpm.encoder.layers.{layer_idx}.mlp.fc2.bias")) + lw.fc1_weight = keepalive["fc1_w_t"].data_ptr() + lw.fc1_bias = keepalive["fc1_b"].data_ptr() + lw.fc2_weight = keepalive["fc2_w_t"].data_ptr() + lw.fc2_bias = keepalive["fc2_b"].data_ptr() + + return lw, keepalive + + +def main(): + ap = argparse.ArgumentParser() + ap.add_argument("--image", required=True, help="Path to an input image") + ap.add_argument("--slice-idx", type=int, default=0) + ap.add_argument("--all-slices", action="store_true") + ap.add_argument("--max-slices", type=int, default=None) + ap.add_argument("--max-slice-nums", type=int, default=None) + args = ap.parse_args() + + model_dir = Path(os.environ.get("MINICPMV_MODEL_DIR", "")) + if not model_dir: + raise SystemExit("Set MINICPMV_MODEL_DIR to the HF model directory.") + + config = json.loads((model_dir / "config.json").read_text()) + index = json.loads((model_dir / "model.safetensors.index.json").read_text()) + weight_map = index["weight_map"] + + force_f32 = os.environ.get("MINICPMV_FORCE_F32", "0") == "1" + torch_dt, dt = _dtype_from_config(config.get("torch_dtype", "bfloat16")) + if force_f32: + torch_dt, dt = torch.float32, DataType.INFINI_DTYPE_F32 + + from minicpmv_config.image_processing_minicpmv import MiniCPMVImageProcessor + + preproc_cfg = json.loads(((model_dir / "preprocessor_config.json").read_text())) + ip = MiniCPMVImageProcessor(**preproc_cfg) + + img = Image.open(args.image).convert("RGB") + batch = ip.preprocess(img, do_pad=True, max_slice_nums=args.max_slice_nums, return_tensors="pt") + pixel_values = batch["pixel_values"][0] + tgt_sizes = batch["tgt_sizes"][0] + + if not pixel_values: + raise SystemExit("No slices produced by image processor.") + + if args.all_slices: + slice_indices = list(range(len(pixel_values))) + if args.max_slices is not None: + slice_indices = slice_indices[: int(args.max_slices)] + else: + slice_idx = int(args.slice_idx) + if slice_idx < 0 or slice_idx >= len(pixel_values): + raise SystemExit(f"slice_idx out of range: {slice_idx} (num_slices={len(pixel_values)})") + slice_indices = [slice_idx] + + patch = int(preproc_cfg.get("patch_size", 14)) + + # ---------- Build C++ model (vision+resampler weights only) ---------- + vision_cfg = config["vision_config"] + nlayer = int(vision_cfg["num_hidden_layers"]) + d_v = int(vision_cfg["hidden_size"]) + nh_v = int(vision_cfg["num_attention_heads"]) + di_v = int(vision_cfg["intermediate_size"]) + + language_meta = MiniCPMVLanguageMetaCStruct( + dt_logits=dt, + nlayer=0, + 
d=int(config["hidden_size"]), + nh=int(config["num_attention_heads"]), + nkvh=int(config["num_key_value_heads"]), + dh=int(config["hidden_size"] // config["num_attention_heads"]), + di=int(config["intermediate_size"]), + dctx=int(config["max_position_embeddings"]), + dvoc=int(config["vocab_size"]), + epsilon=float(config["rms_norm_eps"]), + theta=float(config["rope_theta"]), + end_token=int(config["eos_token_id"]), + ) + vision_meta = MiniCPMVVisionMetaCStruct( + patch_size=patch, + vision_embed_dim=d_v, + vision_num_layers=nlayer, + vision_num_heads=nh_v, + vision_intermediate_size=di_v, + vision_layer_norm_eps=1e-6, + vision_image_size=int(vision_cfg["image_size"]), + vision_num_positions=4900, + ) + resampler_meta = MiniCPMVResamplerMetaCStruct( + num_queries=int(config["query_num"]), + embed_dim=int(config["hidden_size"]), + num_heads=int(config["num_attention_heads"]), + kv_dim=d_v, + layer_norm_eps=1e-6, + max_patches_h=70, + max_patches_w=70, + ) + meta = MiniCPMVMetaCStruct( + vision_meta=vision_meta, resampler_meta=resampler_meta, language_meta=language_meta + ) + + keepalive = {} + + keepalive["patch_w"] = _load_tensor(model_dir, weight_map, "vpm.embeddings.patch_embedding.weight").detach().to(dtype=torch_dt).contiguous() + keepalive["patch_b"] = _load_tensor(model_dir, weight_map, "vpm.embeddings.patch_embedding.bias").detach().to(dtype=torch_dt).contiguous() + keepalive["pos_emb"] = _load_tensor(model_dir, weight_map, "vpm.embeddings.position_embedding.weight").detach().to(dtype=torch_dt).contiguous() + + layers = [] + for i in range(nlayer): + lw, ka = _make_siglip_layer_struct(model_dir, weight_map, i, torch_dt) + layers.append(lw) + for k, v in ka.items(): + keepalive[f"l{i}_{k}"] = v + layers_arr = (MiniCPMVSiglipLayerWeightsCStruct * nlayer)(*layers) + + keepalive["post_ln_w"] = _load_tensor(model_dir, weight_map, "vpm.post_layernorm.weight").detach().to(dtype=torch_dt).contiguous() + keepalive["post_ln_b"] = _load_tensor(model_dir, weight_map, "vpm.post_layernorm.bias").detach().to(dtype=torch_dt).contiguous() + + def t(key: str) -> torch.Tensor: + return _load_tensor(model_dir, weight_map, key).detach().to(dtype=torch_dt).transpose(0, 1).contiguous() + + keepalive["res_kv_proj_w_t"] = t("resampler.kv_proj.weight") + keepalive["res_in_w_t"] = t("resampler.attn.in_proj_weight") + keepalive["res_out_w_t"] = t("resampler.attn.out_proj.weight") + keepalive["res_in_b"] = _load_tensor(model_dir, weight_map, "resampler.attn.in_proj_bias").detach().to(dtype=torch_dt).contiguous() + keepalive["res_out_b"] = _load_tensor(model_dir, weight_map, "resampler.attn.out_proj.bias").detach().to(dtype=torch_dt).contiguous() + keepalive["res_query"] = _load_tensor(model_dir, weight_map, "resampler.query").detach().to(dtype=torch_dt).contiguous() + # proj is used as `x @ proj` in HF (no transpose). 
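+    # Unlike the nn.Linear weights above (stored [out, in] and transposed by t()),
+    # resampler.proj is a raw nn.Parameter consumed as a right-hand matmul, so it is
+    # kept in its original orientation.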
+ keepalive["res_proj"] = _load_tensor(model_dir, weight_map, "resampler.proj").detach().to(dtype=torch_dt).contiguous() + + for name in ["ln_q", "ln_kv", "ln_post"]: + keepalive[f"{name}_w"] = _load_tensor(model_dir, weight_map, f"resampler.{name}.weight").detach().to(dtype=torch_dt).contiguous() + keepalive[f"{name}_b"] = _load_tensor(model_dir, weight_map, f"resampler.{name}.bias").detach().to(dtype=torch_dt).contiguous() + + weights = MiniCPMVWeightsCStruct() + weights.vpm_patch_embedding_weight = keepalive["patch_w"].data_ptr() + weights.vpm_patch_embedding_bias = keepalive["patch_b"].data_ptr() + weights.vpm_position_embedding = keepalive["pos_emb"].data_ptr() + weights.vpm_layers = layers_arr + weights.vpm_post_layernorm_weight = keepalive["post_ln_w"].data_ptr() + weights.vpm_post_layernorm_bias = keepalive["post_ln_b"].data_ptr() + + weights.resampler_query = keepalive["res_query"].data_ptr() + weights.resampler_kv_proj_weight = keepalive["res_kv_proj_w_t"].data_ptr() + weights.resampler_attn_in_proj_weight = keepalive["res_in_w_t"].data_ptr() + weights.resampler_attn_in_proj_bias = keepalive["res_in_b"].data_ptr() + weights.resampler_attn_out_proj_weight = keepalive["res_out_w_t"].data_ptr() + weights.resampler_attn_out_proj_bias = keepalive["res_out_b"].data_ptr() + weights.resampler_ln_q_weight = keepalive["ln_q_w"].data_ptr() + weights.resampler_ln_q_bias = keepalive["ln_q_b"].data_ptr() + weights.resampler_ln_kv_weight = keepalive["ln_kv_w"].data_ptr() + weights.resampler_ln_kv_bias = keepalive["ln_kv_b"].data_ptr() + weights.resampler_ln_post_weight = keepalive["ln_post_w"].data_ptr() + weights.resampler_ln_post_bias = keepalive["ln_post_b"].data_ptr() + weights.resampler_proj = keepalive["res_proj"].data_ptr() + + # Unused language weights + weights.nlayer = 0 + weights.dt_norm = dt + weights.dt_mat = dt + weights.transpose_linear_weights = 0 + weights.input_embd = 0 + weights.output_norm = 0 + weights.output_embd = 0 + weights.attn_norm = None + weights.attn_qkv = None + weights.attn_qkv_b = None + weights.attn_q_norm = None + weights.attn_k_norm = None + weights.attn_o = None + weights.ffn_norm = None + weights.ffn_gate_up = None + weights.ffn_down = None + + model = MiniCPMVModel() + from ctypes import c_int + + dev_ids = (c_int * 1)(0) + model_handle = model.create_model(meta, weights, DeviceType.DEVICE_TYPE_CPU, 1, dev_ids) + + # ---------- Torch reference (vpm + resampler) ---------- + from minicpmv_config.modeling_navit_siglip import SiglipVisionConfig, SiglipVisionTransformer + from minicpmv_config.resampler import Resampler + + vcfg = SiglipVisionConfig( + hidden_size=d_v, + intermediate_size=di_v, + num_hidden_layers=nlayer, + num_attention_heads=nh_v, + num_channels=3, + image_size=int(vision_cfg["image_size"]), + patch_size=patch, + hidden_act="gelu_pytorch_tanh", + layer_norm_eps=1e-6, + attention_dropout=0.0, + ) + vcfg._attn_implementation = "eager" + vpm = SiglipVisionTransformer(vcfg).to(dtype=torch_dt).eval() + + # Load vpm weights by stripping the "vpm." prefix. 
+ vpm_sd = {} + for k in weight_map.keys(): + if k.startswith("vpm."): + vpm_sd[k[len("vpm.") :]] = _load_tensor(model_dir, weight_map, k).to(dtype=torch_dt) + vpm.load_state_dict(vpm_sd, strict=True) + + resampler = Resampler( + num_queries=int(config["query_num"]), + embed_dim=int(config["hidden_size"]), + num_heads=int(config["hidden_size"] // 128), + kv_dim=d_v, + adaptive=True, + ).to(dtype=torch_dt).eval() + res_sd = {} + for k in weight_map.keys(): + if k.startswith("resampler."): + res_sd[k[len("resampler.") :]] = _load_tensor(model_dir, weight_map, k).to(dtype=torch_dt) + resampler.load_state_dict(res_sd, strict=True) + + overall_max = 0.0 + overall_sum = 0.0 + overall_n = 0 + + for slice_idx in slice_indices: + x = pixel_values[slice_idx].to(dtype=torch_dt).contiguous() # [3, patch, L] + th, tw = int(tgt_sizes[slice_idx][0].item()), int(tgt_sizes[slice_idx][1].item()) + seq_len = th * tw + assert x.shape[0] == 3 and x.shape[1] == patch + assert x.shape[2] == seq_len * patch + packed = x.unsqueeze(0).contiguous() # [1, 3, patch, L] + + out_cpp = torch.empty((resampler_meta.num_queries, resampler_meta.embed_dim), dtype=torch_dt) + model.infer_vision_resampler(model_handle, packed.data_ptr(), seq_len, th, tw, out_cpp.data_ptr()) + + patch_attn_mask = torch.ones((1, 1, seq_len), dtype=torch.bool) + tgt = torch.tensor([[th, tw]], dtype=torch.int32) + with torch.no_grad(): + hs = vpm(packed, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt).last_hidden_state + out_ref = resampler(hs, tgt)[0] + + diff = (out_cpp.float() - out_ref.float()).abs() + max_abs = float(diff.max().item()) + mean_abs = float(diff.mean().item()) + + overall_max = max(overall_max, max_abs) + overall_sum += float(diff.sum().item()) + overall_n += diff.numel() + + print("vision hf smoke:") + print(" slice_idx:", slice_idx) + print(" tgt_h:", th, "tgt_w:", tw, "seq_len:", seq_len) + print(" dtype:", torch_dt) + print(" out:", tuple(out_cpp.shape)) + print(" max_abs:", max_abs) + print(" mean_abs:", mean_abs) + + if len(slice_indices) > 1: + print("vision hf smoke summary:") + print(" slices:", len(slice_indices)) + print(" overall_max_abs:", overall_max) + print(" overall_mean_abs:", overall_sum / max(1, overall_n)) + + model.destroy_model(model_handle) + + +if __name__ == "__main__": + main() diff --git a/scripts/processing_minicpmv.py b/scripts/processing_minicpmv.py new file mode 100755 index 00000000..7db21712 --- /dev/null +++ b/scripts/processing_minicpmv.py @@ -0,0 +1,240 @@ +# coding=utf-8 +# Copyright 2024 The HuggingFace Inc. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +""" +Processor class for MiniCPMV. 
+""" + +from typing import List, Optional, Union, Dict, Any +import torch +import re + +from transformers.image_processing_utils import BatchFeature +from transformers.image_utils import ImageInput +from transformers.processing_utils import ProcessorMixin +from transformers.tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy +from transformers.utils import TensorType, requires_backends, is_torch_dtype, is_torch_device + +from image_processing_minicpmv import MiniCPMVBatchFeature + + +class MiniCPMVProcessor(ProcessorMixin): + r""" + Constructs a MiniCPMV processor which wraps a MiniCPMV image processor and a MiniCPMV tokenizer into a single processor. + + [`MiniCPMVProcessor`] offers all the functionalities of [`MiniCPMVImageProcessor`] and [`LlamaTokenizerWrapper`]. See the + [`~MiniCPMVProcessor.__call__`] and [`~MiniCPMVProcessor.decode`] for more information. + + Args: + image_processor ([`MiniCPMVImageProcessor`], *optional*): + The image processor is a required input. + tokenizer ([`LlamaTokenizerWrapper`], *optional*): + The tokenizer is a required input. + """ + attributes = ["image_processor", "tokenizer"] + image_processor_class = "AutoImageProcessor" + tokenizer_class = "AutoTokenizer" + + def __init__(self, image_processor=None, tokenizer=None): + super().__init__(image_processor, tokenizer) + self.version = image_processor.version + + def __call__( + self, + text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]], + images: ImageInput = None, + max_length: Optional[int] = None, + do_pad: Optional[bool] = True, + max_slice_nums: int = None, + use_image_id: bool = None, + return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH, + **kwargs + ) -> MiniCPMVBatchFeature: + + if images is not None: + image_inputs = self.image_processor(images, do_pad=do_pad, max_slice_nums=max_slice_nums, return_tensors=return_tensors) + return self._convert_images_texts_to_inputs(image_inputs, text, max_slice_nums=max_slice_nums, use_image_id=use_image_id, max_length=max_length, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama + def batch_decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please + refer to the docstring of this method for more information. + """ + output_ids = args[0] + result_text = [] + for result in output_ids: + result = result[result != 0] + if result[0] == self.tokenizer.bos_id: + result = result[1:] + if result[-1] == self.tokenizer.eos_id: + result = result[:-1] + result_text.append(self.tokenizer.decode(result, *args[1:], **kwargs).strip()) + return result_text + # return self.tokenizer.batch_decode(*args, **kwargs) + + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.decode with CLIP->Llama + def decode(self, *args, **kwargs): + """ + This method forwards all its arguments to LlamaTokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to + the docstring of this method for more information. 
+ """ + result = args[0] + result = result[result != 0] + if result[0] == self.tokenizer.bos_id: + result = result[1:] + if result[-1] == self.tokenizer.eos_id or (hasattr(self.tokenizer, "eot_id") and result[-1] == self.tokenizer.eot_id): + result = result[:-1] + return self.tokenizer.decode(result, *args[1:], **kwargs).strip() + + def _convert( + self, input_str, max_inp_length: Optional[int] = None + ): + if self.version > 2.5 or not getattr(self.tokenizer, "add_bos_token", False): + input_ids = self.tokenizer.encode(input_str) + else: + input_ids = [self.tokenizer.bos_id] + self.tokenizer.encode(input_str) + if max_inp_length is not None: + input_ids = input_ids[:max_inp_length] + input_ids = torch.tensor(input_ids, dtype=torch.int32) + + start_cond = (input_ids == self.tokenizer.im_start_id) | (input_ids == self.tokenizer.slice_start_id) + end_cond = (input_ids == self.tokenizer.im_end_id) | (input_ids == self.tokenizer.slice_end_id) + + image_start_tokens = torch.where(start_cond)[0] + image_start_tokens += 1 + image_end_tokens = torch.where(end_cond)[0] + + valid_image_nums = max(len(image_start_tokens), len(image_end_tokens)) + + image_bounds = torch.hstack( + [ + image_start_tokens[:valid_image_nums].unsqueeze(-1), + image_end_tokens[:valid_image_nums].unsqueeze(-1), + ] + ) + return input_ids, image_bounds + + def _convert_images_texts_to_inputs( + self, + images, + texts: Union[str, List[str]], + truncation=None, + max_length=None, + max_slice_nums=None, + use_image_id=None, + return_tensors=None, + **kwargs + ): + if images is None or not len(images): + model_inputs = self.tokenizer(texts, return_tensors=return_tensors, truncation=truncation, max_length=max_length, **kwargs) + return MiniCPMVBatchFeature(data={**model_inputs}) + + pattern = "(./)" + images, image_sizes, tgt_sizes = images["pixel_values"], images["image_sizes"], images["tgt_sizes"] + + if isinstance(texts, str): + texts = [texts] + input_ids_list = [] + image_bounds_list = [] + for index, text in enumerate(texts): + image_tags = re.findall(pattern, text) + assert len(image_tags) == len(image_sizes[index]) + text_chunks = text.split(pattern) + final_text = "" + for i in range(len(image_tags)): + final_text = final_text + text_chunks[i] + \ + self.image_processor.get_slice_image_placeholder( + image_sizes[index][i], + i, + max_slice_nums, + use_image_id + ) + final_text += text_chunks[-1] + input_ids, image_bounds = self._convert(final_text, max_length) + input_ids_list.append(input_ids) + image_bounds_list.append(image_bounds) + padded_input_ids, padding_lengths = self.pad( + input_ids_list, + padding_side="left" + ) + for i, length in enumerate(padding_lengths): + image_bounds_list[i] = image_bounds_list[i] + length + attention_mask = padded_input_ids.ne(0) + + return MiniCPMVBatchFeature(data={ + "input_ids": padded_input_ids, + "attention_mask": attention_mask, + "pixel_values": images, + "image_sizes": image_sizes, + "image_bound": image_bounds_list, + "tgt_sizes": tgt_sizes + }) + + @property + # Copied from transformers.models.clip.processing_clip.CLIPProcessor.model_input_names + def model_input_names(self): + tokenizer_input_names = self.tokenizer.model_input_names + image_processor_input_names = self.image_processor.model_input_names + return list(dict.fromkeys(tokenizer_input_names + image_processor_input_names)) + + + def pad(self, inputs, max_length=None, padding_value=0, padding_side="left"): + items = [] + if isinstance(inputs[0], list): + assert isinstance(inputs[0][0], torch.Tensor) + for it in 
inputs: + for tr in it: + items.append(tr) + else: + assert isinstance(inputs[0], torch.Tensor) + items = inputs + + batch_size = len(items) + shape = items[0].shape + dim = len(shape) + assert dim <= 2 + if max_length is None: + max_length = 0 + max_length = max(max_length, max(item.shape[-1] for item in items)) + min_length = min(item.shape[-1] for item in items) + dtype = items[0].dtype + + if dim == 0: + return torch.stack([item for item in items], dim=0), [0] + elif dim == 1: + if max_length == min_length: + return torch.stack([item for item in items], dim=0), [0] * batch_size + tensor = torch.zeros((batch_size, max_length), dtype=dtype) + padding_value + else: + tensor = ( + torch.zeros((batch_size, max_length, shape[-1]), dtype=dtype) + + padding_value + ) + + padding_length = [] + for i, item in enumerate(items): + if dim == 1: + if padding_side == "left": + tensor[i, -len(item) :] = item.clone() + else: + tensor[i, : len(item)] = item.clone() + elif dim == 2: + if padding_side == "left": + tensor[i, -len(item) :, :] = item.clone() + else: + tensor[i, : len(item), :] = item.clone() + padding_length.append(tensor.shape[-1] - len(item)) + + return tensor, padding_length diff --git a/scripts/test_ceval.py b/scripts/test_ceval.py index 749f15c5..83365f4a 100644 --- a/scripts/test_ceval.py +++ b/scripts/test_ceval.py @@ -43,11 +43,7 @@ def generate(self, conversation, max_steps, topp_=1.0, topk_=1, temperature_=1.0 output_tokens = self.batch_infer_one_round([infer_task]) end_time = time.time() steps += 1 - output_str = ( - self.tokenizer._tokenizer.id_to_token(output_tokens[0]) - .replace("▁", " ") - .replace("<0x0A>", "\n") - ) + output_str = self.tokenizer.decode(output_tokens[0]) output_content += output_str print(output_str, end="", flush=True) if output_tokens[0] in self.eos_token_id: diff --git a/src/cache_manager/kv_compression.cpp b/src/cache_manager/kv_compression.cpp new file mode 100644 index 00000000..8f9f7a76 --- /dev/null +++ b/src/cache_manager/kv_compression.cpp @@ -0,0 +1,237 @@ +#include "kv_compression.hpp" +#include "../utils.hpp" + +#include +#include + +namespace { +// Magic: "KV C M" little endian = 0x4b56434d +constexpr uint32_t kMagic = 0x4b56434d; + +struct Header { + uint32_t magic; + uint32_t version; + uint16_t dtype_code; + uint16_t reserved; + uint32_t num_layers; + uint32_t num_heads; + uint32_t head_dim; + uint32_t hidden_size; + uint32_t compression_factor; + uint32_t min_seq_len; + uint32_t weight_count_per_layer; + uint32_t metadata_size_bytes; +}; +static_assert(sizeof(Header) == 44, "Header size mismatch"); + +enum class DTypeCode : uint16_t { FP16 = 0, BF16 = 1, FP32 = 2 }; + +infiniDtype_t toInfiniDtype(DTypeCode code) { + switch (code) { + case DTypeCode::FP16: + return INFINI_DTYPE_F16; + case DTypeCode::BF16: + return INFINI_DTYPE_BF16; + case DTypeCode::FP32: + return INFINI_DTYPE_F32; + default: + return INFINI_DTYPE_INVALID; + } +} + +struct WeightMeta { + uint32_t rows; + uint32_t cols; + uint32_t has_bias; +}; +static_assert(sizeof(WeightMeta) == 12, "WeightMeta size mismatch"); +} // namespace + +Compressor::Compressor(const CompressionConfig &cfg) : config_(cfg) {} + +bool Compressor::loadWeights() { + // Guard: require config enabled and weight path set. + if (!config_.enable) { + return false; + } + if (config_.weight_path.empty()) { + return false; + } + // PyTorch .pth is not supported directly. 
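+    // .pth checkpoints must be converted offline to the flat binary layout parsed below
+    // (see docs/KVCacheCompressionWeightFormat.md); the extension check rejects them early.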
+ if (config_.weight_path.size() >= 4 && + config_.weight_path.substr(config_.weight_path.size() - 4) == ".pth") { + std::stringstream ss; + ss << "Unsupported weight format (.pth) in " + << config_.weight_path + << "; convert to binary format described in docs/KVCacheCompressionWeightFormat.md"; + std::cerr << ss.str() << std::endl; + return false; + } + + std::ifstream fin(config_.weight_path, std::ios::binary); + if (!fin) { + std::stringstream ss; + ss << "Failed to open weight file: " << config_.weight_path; + std::cerr << ss.str() << std::endl; + return false; + } + + auto read_u32 = [&](uint32_t &out) -> bool { + char buf[4]; + fin.read(buf, 4); + if (!fin) return false; + out = static_cast(static_cast(buf[0])) | + (static_cast(static_cast(buf[1])) << 8) | + (static_cast(static_cast(buf[2])) << 16) | + (static_cast(static_cast(buf[3])) << 24); + return true; + }; + auto read_u16 = [&](uint16_t &out) -> bool { + char buf[2]; + fin.read(buf, 2); + if (!fin) return false; + out = static_cast(static_cast(buf[0])) | + (static_cast(static_cast(buf[1])) << 8); + return true; + }; + + Header hdr{}; + if (!read_u32(hdr.magic) || !read_u32(hdr.version) || + !read_u16(hdr.dtype_code) || !read_u16(hdr.reserved) || + !read_u32(hdr.num_layers) || !read_u32(hdr.num_heads) || + !read_u32(hdr.head_dim) || !read_u32(hdr.hidden_size) || + !read_u32(hdr.compression_factor) || !read_u32(hdr.min_seq_len) || + !read_u32(hdr.weight_count_per_layer) || !read_u32(hdr.metadata_size_bytes)) { + std::cerr << "Failed to read compression weight header" << std::endl; + return false; + } + std::cerr << "Header: magic=" << std::hex << hdr.magic << std::dec + << " version=" << hdr.version + << " dtype_code=" << hdr.dtype_code + << " num_layers=" << hdr.num_layers + << " weight_count_per_layer=" << hdr.weight_count_per_layer + << " meta_size=" << hdr.metadata_size_bytes + << std::endl; + if (hdr.magic != kMagic || hdr.version != 1) { + std::cerr << "Invalid compression weight header" << std::endl; + return false; + } + // Basic sanity checks on header fields. + if (hdr.num_layers == 0 || hdr.num_layers > 10000 || + hdr.weight_count_per_layer == 0 || hdr.weight_count_per_layer > 4096) { + std::cerr << "Invalid header values (num_layers/weight_count_per_layer)" << std::endl; + return false; + } + auto dtype = toInfiniDtype(static_cast(hdr.dtype_code)); + if (dtype == INFINI_DTYPE_INVALID) { + std::cerr << "Unsupported dtype in compression weight file" << std::endl; + return false; + } + // Sync config with header if not set + if (config_.compression_factor == 0 || config_.compression_factor == 1) { + config_.compression_factor = hdr.compression_factor; + } + if (config_.min_seq_len == 0) { + config_.min_seq_len = hdr.min_seq_len; + } + + // Skip metadata if present. + if (hdr.metadata_size_bytes > 0) { + fin.seekg(hdr.metadata_size_bytes, std::ios::cur); + } + + weights_.clear(); + prefix_offsets_.clear(); + layered_weights_.clear(); + weights_.reserve(static_cast(hdr.num_layers) * hdr.weight_count_per_layer); + layered_weights_.resize(hdr.num_layers); + + // Record layer offsets to map (layer, index) -> weights_ position. 
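+    // Per-layer block layout read by the loop below: weight_count_per_layer entries, each a
+    // 12-byte (rows, cols, has_bias) record followed by rows*cols weight elements and, when
+    // has_bias != 0, rows bias elements, all stored in the file dtype.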
+ for (uint32_t layer = 0; layer < hdr.num_layers; ++layer) { + prefix_offsets_.push_back(static_cast(weights_.size())); + layered_weights_[layer].reserve(hdr.weight_count_per_layer); + for (uint32_t w = 0; w < hdr.weight_count_per_layer; ++w) { + WeightMeta meta{}; + char meta_buf[sizeof(WeightMeta)]; + fin.read(meta_buf, sizeof(WeightMeta)); + if (!fin) { + std::cerr << "Unexpected EOF while reading weight meta" << std::endl; + return false; + } + meta.rows = static_cast(static_cast(meta_buf[0])) | + (static_cast(static_cast(meta_buf[1])) << 8) | + (static_cast(static_cast(meta_buf[2])) << 16) | + (static_cast(static_cast(meta_buf[3])) << 24); + meta.cols = static_cast(static_cast(meta_buf[4])) | + (static_cast(static_cast(meta_buf[5])) << 8) | + (static_cast(static_cast(meta_buf[6])) << 16) | + (static_cast(static_cast(meta_buf[7])) << 24); + meta.has_bias = static_cast(static_cast(meta_buf[8])) | + (static_cast(static_cast(meta_buf[9])) << 8) | + (static_cast(static_cast(meta_buf[10])) << 16) | + (static_cast(static_cast(meta_buf[11])) << 24); + + const size_t weight_elems = static_cast(meta.rows) * meta.cols; + // Guard against unreasonable sizes to avoid allocation overflow. + const size_t max_elems = static_cast(1e8); // ~200MB for fp16 + if (meta.rows == 0 || meta.cols == 0 || weight_elems > max_elems) { + std::cerr << "Unreasonable weight shape: rows=" << meta.rows + << " cols=" << meta.cols << std::endl; + return false; + } + const size_t weight_bytes = weight_elems * dsize(dtype); + std::vector buf(weight_bytes); + fin.read(buf.data(), static_cast(weight_bytes)); + if (!fin) { + std::cerr << "Unexpected EOF while reading weight data" << std::endl; + return false; + } + + // Create Tensor on device. + auto weight_tensor = Tensor::weight(buf.data(), dtype, {meta.rows, meta.cols}); + weights_.push_back(weight_tensor); + + std::shared_ptr bias_tensor = nullptr; + if (meta.has_bias) { + const size_t bias_bytes = static_cast(meta.rows) * dsize(dtype); + std::vector bias_buf(bias_bytes); + fin.read(bias_buf.data(), static_cast(bias_bytes)); + if (!fin) { + std::cerr << "Unexpected EOF while reading bias" << std::endl; + return false; + } + bias_tensor = Tensor::weight(bias_buf.data(), dtype, {meta.rows}); + weights_.push_back(bias_tensor); + } + layered_weights_[layer].push_back(LinearWeight{weight_tensor, bias_tensor}); + } + } + + return true; +} + +std::shared_ptr Compressor::getWeight(uint32_t layer, uint32_t idx) const { + if (layer >= prefix_offsets_.size()) { + return nullptr; + } + uint32_t base = prefix_offsets_[layer]; + uint32_t pos = base + idx; + if (pos >= weights_.size()) { + return nullptr; + } + return weights_[pos]; +} + +std::pair, std::shared_ptr> +Compressor::getLinearWithBias(uint32_t layer, uint32_t prefix_idx, uint32_t slot) const { + // prefix_idx in [0, kPrefixCount), slot in [0, kWeightsPerPrefix) + const uint32_t idx = prefix_idx * kWeightsPerPrefix + slot; + if (layer >= layered_weights_.size()) { + return {nullptr, nullptr}; + } + const auto &vec = layered_weights_[layer]; + if (idx >= vec.size()) { + return {nullptr, nullptr}; + } + return {vec[idx].weight, vec[idx].bias}; +} diff --git a/src/cache_manager/kv_compression.hpp b/src/cache_manager/kv_compression.hpp new file mode 100644 index 00000000..f56ec766 --- /dev/null +++ b/src/cache_manager/kv_compression.hpp @@ -0,0 +1,78 @@ +#pragma once + +#include "../cache.hpp" +#include +#include +#include +#include + +// Compression configuration set at model creation. 
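+// Mirrors the fields of the C-API KVCompressionConfig struct (KVCompressionConfigCStruct on
+// the Python side) that the driver script fills in.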
+struct CompressionConfig {
+    bool enable = false;
+    uint32_t compression_factor = 1; // e.g., 4 or 5
+    uint32_t min_seq_len = 0;        // threshold to trigger compression
+    std::string weight_path;         // path to binary weights
+    uint32_t image_kv_len = 0;       // optional prefix length (tokens) treated as image KV
+    // Future: per-layer override, algorithm type, dtype override.
+};
+
+// Metadata and storage for compressed KV per device.
+struct CompressedKV {
+    struct LayerKV {
+        std::shared_ptr<Tensor> k_comp; // compressed key
+        std::shared_ptr<Tensor> v_comp; // compressed value
+        uint32_t orig_seq_len = 0;
+        uint32_t comp_seq_len = 0;
+        std::shared_ptr<Tensor> indices; // optional: mapping indices (int32)
+        std::shared_ptr<Tensor> scales;  // optional: scaling factors
+    };
+
+    // Device-local layout: [layer] order matches original KV.
+    std::vector<LayerKV> layers;
+};
+
+// Compressor interface (implemented in kv_compression.cpp / kv_compression_impl.cpp).
+class Compressor {
+public:
+    explicit Compressor(const CompressionConfig &cfg);
+
+    // Load weights from binary file; returns false on failure.
+    bool loadWeights();
+
+    // Compress device-local KV; returns compressed structure or nullptr on failure.
+    std::unique_ptr<CompressedKV> compress(const KVCache &kv, uint32_t seq_len);
+
+    // Compress into temporary buffers and write back into `kv`'s preallocated storage (prefix [0,new_len)).
+    // Returns the new logical KV length (<= seq_len); returns seq_len on no-op or failure.
+    uint32_t compressInplace(KVCache &kv, uint32_t seq_len);
+
+    // Decompress into temporary buffers for attention use; returns false on failure.
+    bool decompress(const CompressedKV &ckv,
+                    std::vector<std::shared_ptr<Tensor>> &k_out,
+                    std::vector<std::shared_ptr<Tensor>> &v_out);
+
+private:
+    struct LinearWeight {
+        std::shared_ptr<Tensor> weight;
+        std::shared_ptr<Tensor> bias;
+    };
+
+    CompressionConfig config_;
+    // Store per-layer weights; concrete types depend on algorithm design.
+    std::vector<std::shared_ptr<Tensor>> weights_;
+    // Offsets into weights_ for each layer (start index of that layer).
+    std::vector<uint32_t> prefix_offsets_;
+    // Structured access: [layer][block] -> (weight, bias)
+    std::vector<std::vector<LinearWeight>> layered_weights_;
+
+    // Helper to fetch weight tensor by (layer, index offset within layer).
+    std::shared_ptr<Tensor> getWeight(uint32_t layer, uint32_t idx) const;
+
+    // Mapping utilities: each layer is expected to have weights ordered by prefix.
+    // Order: compress_tk[3], compress_tv[3], compress_ik[3], compress_iv[3], attention[*] (if present).
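+    // Example: with kWeightsPerPrefix = 3, (prefix_idx = 2 /* compress_ik */, slot = 1)
+    // resolves to block index 2 * 3 + 1 = 7 within a layer (see getLinearWithBias).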
+ static constexpr uint32_t kWeightsPerPrefix = 3; + static constexpr uint32_t kPrefixCount = 4; // compress_tk, compress_tv, compress_ik, compress_iv + + std::pair, std::shared_ptr> + getLinearWithBias(uint32_t layer, uint32_t prefix_idx, uint32_t slot) const; +}; diff --git a/src/cache_manager/kv_compression_capi.cpp b/src/cache_manager/kv_compression_capi.cpp new file mode 100644 index 00000000..386a0256 --- /dev/null +++ b/src/cache_manager/kv_compression_capi.cpp @@ -0,0 +1,46 @@ +#include "kv_compression.hpp" + +#include "../cache.hpp" +#include "../utils.hpp" + +#include "infinicore_infer/kv_compression.h" + +#include + +__C __export uint32_t +compressKVCacheInplace(struct KVCache *kv_cache, uint32_t seq_len, const KVCompressionConfig *cfg) { + if (kv_cache == nullptr || cfg == nullptr || cfg->enable == 0) { + return seq_len; + } + if (seq_len == 0) { + return 0; + } + if (kv_cache->k.empty() || kv_cache->k[0].empty() || kv_cache->v.empty() || kv_cache->v[0].empty()) { + return seq_len; + } + if (!kv_cache->k[0][0]) { + return seq_len; + } + + CompressionConfig cxx_cfg; + cxx_cfg.enable = true; + cxx_cfg.compression_factor = cfg->compression_factor; + cxx_cfg.min_seq_len = cfg->min_seq_len; + cxx_cfg.image_kv_len = cfg->image_kv_len; + if (cfg->weight_path != nullptr) { + cxx_cfg.weight_path = cfg->weight_path; + } + + const uint32_t max_seq = static_cast(kv_cache->k[0][0]->shape()[0]); + seq_len = std::min(seq_len, max_seq); + + // Ensure weights are created on the same device as KV. + RUN_INFINI(infinirtSetDevice(kv_cache->k[0][0]->deviceType(), kv_cache->k[0][0]->deviceId())); + + Compressor compressor(cxx_cfg); + if (!compressor.loadWeights()) { + return seq_len; + } + return compressor.compressInplace(*kv_cache, seq_len); +} + diff --git a/src/cache_manager/kv_compression_impl.cpp b/src/cache_manager/kv_compression_impl.cpp new file mode 100644 index 00000000..24ba2d78 --- /dev/null +++ b/src/cache_manager/kv_compression_impl.cpp @@ -0,0 +1,483 @@ +#include "kv_compression.hpp" +#include "../utils.hpp" +#include "../tensor.hpp" +#include "../models/inference_context.hpp" +#include "infinicore_infer.h" +#include + +#include +#include + +namespace { +// Transpose a 2D weight (out, in) -> (in, out) into a contiguous buffer. 
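+// permute() only swaps strides; the copyFrom into a freshly allocated buffer materializes the
+// [in, out] layout so later GEMMs can consume it as a contiguous row-major matrix.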
+std::shared_ptr make_transposed(std::shared_ptr w, InferenceContext *ctx) { + if (!w || w->ndim() != 2) return w; + auto shape = w->shape(); // [out, in] + auto view_t = w->permute({1, 0}); // view with swapped strides + auto out = Tensor::buffer(w->dtype(), {shape[1], shape[0]}, ctx->memory_pool); + out->copyFrom(view_t, ctx->op_handle, ctx->stream); + return out; +} + +std::shared_ptr cast_weight_cpu(std::shared_ptr w, infiniDtype_t target_dtype) { + if (!w) return nullptr; + if (w->dtype() == target_dtype) return w; + if (w->deviceType() != INFINI_DEVICE_CPU) { + return nullptr; + } + const size_t n = w->numel(); + std::vector tmp(n); + + if (w->dtype() == INFINI_DTYPE_F16) { + auto *p = reinterpret_cast(w->data()); + for (size_t i = 0; i < n; ++i) tmp[i] = f16_to_f32(p[i]); + } else if (w->dtype() == INFINI_DTYPE_BF16) { + auto *p = reinterpret_cast(w->data()); + for (size_t i = 0; i < n; ++i) tmp[i] = bf16_to_f32(p[i]); + } else if (w->dtype() == INFINI_DTYPE_F32) { + auto *p = reinterpret_cast(w->data()); + std::copy(p, p + n, tmp.begin()); + } else { + return nullptr; + } + + if (target_dtype == INFINI_DTYPE_F16) { + std::vector out(n); + for (size_t i = 0; i < n; ++i) out[i] = f32_to_f16(tmp[i]); + return Tensor::weight(out.data(), target_dtype, w->shape()); + } + if (target_dtype == INFINI_DTYPE_BF16) { + std::vector out(n); + for (size_t i = 0; i < n; ++i) out[i] = f32_to_bf16(tmp[i]); + return Tensor::weight(out.data(), target_dtype, w->shape()); + } + if (target_dtype == INFINI_DTYPE_F32) { + return Tensor::weight(tmp.data(), target_dtype, w->shape()); + } + return nullptr; +} +} // namespace + +std::unique_ptr Compressor::compress(const KVCache &kv, uint32_t seq_len) { + if (!config_.enable) { + return nullptr; + } + if (weights_.empty()) { + std::cerr << "Compressor::compress: weights are empty" << std::endl; + return nullptr; + } + if (seq_len == 0) { + return nullptr; + } + + auto compressed = std::make_unique(); + if (kv.k.empty()) { + return nullptr; + } + const size_t ndev = kv.k.size(); + const size_t nlayers = kv.k[0].size(); + compressed->layers.resize(nlayers); + + // Only handle device 0 for now. + if (ndev == 0) { + return nullptr; + } + + // Validate / auto-initialize inference context for the current device. + auto ensure_ctx = [&]() -> InferenceContext * { + auto *ctx_ptr = maybe_get_context(); + if (ctx_ptr && ctx_ptr->op_handle != nullptr && ctx_ptr->memory_pool != nullptr) { + return ctx_ptr; + } + // Auto create a lightweight context bound to the KV device to allow tests to run. + static CacheManager auto_cache_mgr(32); + static std::shared_ptr auto_pool; + static infiniopHandle_t auto_handle = nullptr; + static infinirtStream_t auto_stream = nullptr; + static InferenceContext *auto_ctx = nullptr; + + if (auto_ctx == nullptr) { + // Bind to device 0 (first shard) since compressor currently assumes single device. 
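// Annotation: the auto-created fallback context below is process-lifetime static state
// (one 128 MB MemoryPool, one infiniop handle, one stream, shared by all later calls).
// It exists so the compressor can run in isolation (e.g. unit tests) when the model's own
// InferenceContext has not been set on this thread.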
+ auto device_type = kv.k[0][0]->deviceType(); + auto device_id = kv.k[0][0]->deviceId(); + RUN_INFINI(infinirtSetDevice(device_type, device_id)); + auto_pool = std::make_shared(128 * 1024 * 1024); + RUN_INFINI(infiniopCreateHandle(&auto_handle)); + RUN_INFINI(infinirtStreamCreate(&auto_stream)); + static InferenceContext ctx(auto_handle, auto_pool, &auto_cache_mgr, auto_stream); + auto_ctx = &ctx; + } + setInferenceContext(auto_ctx); + return auto_ctx; + }; + + auto *ctx_ptr = ensure_ctx(); + if (!ctx_ptr || ctx_ptr->op_handle == nullptr || ctx_ptr->memory_pool == nullptr) { + std::cerr << "compress: inference context not initialized (op_handle/memory_pool), fallback to no-compress copy" << std::endl; + // Fallback: return original KV (device 0) without compression. + const size_t nlayers = kv.k[0].size(); + auto fallback = std::make_unique(); + fallback->layers.resize(nlayers); + for (size_t layer = 0; layer < nlayers; ++layer) { + auto k_tensor = kv.k[0][layer]; + auto v_tensor = kv.v[0][layer]; + if (!k_tensor || !v_tensor) { + return nullptr; + } + const uint32_t max_seq = static_cast(k_tensor->shape()[0]); + const uint32_t seq = std::min(seq_len, max_seq); + fallback->layers[layer].k_comp = k_tensor->slice(0, 0, seq); + fallback->layers[layer].v_comp = v_tensor->slice(0, 0, seq); + fallback->layers[layer].orig_seq_len = seq; + fallback->layers[layer].comp_seq_len = seq; + } + return fallback; + } + const uint32_t factor = config_.compression_factor > 0 ? config_.compression_factor : 1; + if (factor <= 1) { + return nullptr; + } + + // Ensure weight/bias dtypes match KV dtype on CPU to avoid Rearrange/Gemm dtype mismatches. + const auto kv_dtype = kv.k[0][0]->dtype(); + if (!layered_weights_.empty() && !layered_weights_[0].empty()) { + const auto w_dtype = layered_weights_[0][0].weight ? 
layered_weights_[0][0].weight->dtype() : kv_dtype; + if (w_dtype != kv_dtype) { + if (kv.k[0][0]->deviceType() != INFINI_DEVICE_CPU) { + std::cerr << "compress: weight dtype != kv dtype on non-CPU device; disable compression" << std::endl; + return nullptr; + } + for (auto &layer : layered_weights_) { + for (auto &lw : layer) { + auto casted_w = cast_weight_cpu(lw.weight, kv_dtype); + if (!casted_w) { + std::cerr << "compress: failed to cast weights to kv dtype" << std::endl; + return nullptr; + } + lw.weight = casted_w; + if (lw.bias) { + auto casted_b = cast_weight_cpu(lw.bias, kv_dtype); + if (!casted_b) { + std::cerr << "compress: failed to cast bias to kv dtype" << std::endl; + return nullptr; + } + lw.bias = casted_b; + } + } + } + } + } + + auto has_prefix_mlp = [&](uint32_t prefix_idx) -> bool { + if (layered_weights_.empty()) return false; + for (uint32_t slot = 0; slot < 3; ++slot) { + auto wb = getLinearWithBias(0, prefix_idx, slot); + if (!wb.first) return false; + } + return true; + }; + const bool has_text_mlp = has_prefix_mlp(0) && has_prefix_mlp(1); + const bool has_image_mlp = has_prefix_mlp(2) && has_prefix_mlp(3); + if (!has_text_mlp) { + std::cerr << "compress: missing text MLP weights (compress_tk/tv)" << std::endl; + return nullptr; + } + + for (size_t layer = 0; layer < nlayers; ++layer) { + auto k_tensor = kv.k[0][layer]; + auto v_tensor = kv.v[0][layer]; + if (!k_tensor || !v_tensor) { + return nullptr; + } + const auto &shape = k_tensor->shape(); + if (shape.size() != 3) { + return nullptr; + } + const uint32_t max_seq = static_cast(shape[0]); + const uint32_t seq = std::min(seq_len, max_seq); + const uint32_t nkvh = static_cast(shape[1]); + const uint32_t dk = static_cast(shape[2]); + + auto k_view = k_tensor->slice(0, 0, seq); + auto v_view = v_tensor->slice(0, 0, seq); + + auto fetch = [&](uint32_t prefix, uint32_t slot) -> std::pair, std::shared_ptr> { + return getLinearWithBias(static_cast(layer), prefix, slot); + }; + + auto run_pipeline = [&](std::shared_ptr input2d, uint32_t prefix) -> std::shared_ptr { + auto l0 = fetch(prefix, 0); + auto l1 = fetch(prefix, 1); + auto l2 = fetch(prefix, 2); + if (!l0.first || !l1.first || !l2.first) { + return nullptr; + } + if (l0.first->shape()[1] != factor * dk || + l1.first->shape()[1] != l0.first->shape()[0] || + l2.first->shape()[1] != l1.first->shape()[0]) { + std::cerr << "compress: weight/input shape mismatch at prefix " << prefix + << " layer " << layer << std::endl; + return nullptr; + } + auto w0 = make_transposed(l0.first, ctx_ptr); + auto w1 = make_transposed(l1.first, ctx_ptr); + auto w2 = make_transposed(l2.first, ctx_ptr); + + // auto w0 = l0.first; + // auto w1 = l1.first; + // auto w2 = l2.first; + + const size_t rows_linear = input2d->shape()[0]; + auto out0 = Tensor::buffer(input2d->dtype(), {rows_linear, l0.first->shape()[0]}, ctx_ptr->memory_pool); + auto out1 = Tensor::buffer(input2d->dtype(), {rows_linear, l1.first->shape()[0]}, ctx_ptr->memory_pool); + auto out2 = Tensor::buffer(input2d->dtype(), {rows_linear, l2.first->shape()[0]}, ctx_ptr->memory_pool); + + linear(out0, input2d, w0, 1.0f, 0.0f, nullptr, l0.second); + relu(out0, out0); + + linear(out1, out0, w1, 1.0f, 0.0f, nullptr, l1.second); + relu(out1, out1); + + linear(out2, out1, w2, 1.0f, 0.0f, nullptr, l2.second); + // NOTE: The compression path uses MemoryPool for temporary tensors. + // Infini op kernels are enqueued asynchronously on `ctx_ptr->stream`, and MemoryPool blocks + // are immediately reusable on tensor destruction. 
Synchronize here to ensure all kernels
+            // finish before intermediate buffers (w0/w1/w2/out0/out1) get released back to the pool.
+            RUN_INFINI(infinirtStreamSynchronize(ctx_ptr->stream));
+            return out2;
+        };
+
+        auto compress_segment = [&](std::shared_ptr<Tensor> k_seg,
+                                    std::shared_ptr<Tensor> v_seg,
+                                    uint32_t prefix_base) -> std::pair<std::shared_ptr<Tensor>, std::shared_ptr<Tensor>> {
+            if (!k_seg || !v_seg) return std::make_pair(nullptr, nullptr);
+            const auto seg_shape = k_seg->shape();
+            uint32_t seg_len = static_cast<uint32_t>(seg_shape[0]);
+            uint32_t compressed_seq_len = (seg_len / factor);
+            if (compressed_seq_len < config_.min_seq_len) {
+                return {k_seg, v_seg};
+            }
+            uint32_t compress_len = compressed_seq_len * factor;
+            uint32_t remainder_len = seg_len - compress_len;
+
+            auto k_head = k_seg->slice(0, 0, compress_len);
+            auto v_head = v_seg->slice(0, 0, compress_len);
+
+            auto k_head_buf = Tensor::buffer(k_seg->dtype(), {compress_len, nkvh, dk}, ctx_ptr->memory_pool);
+            k_head_buf->copyFrom(k_head, ctx_ptr->op_handle, ctx_ptr->stream);
+            auto v_head_buf = Tensor::buffer(v_seg->dtype(), {compress_len, nkvh, dk}, ctx_ptr->memory_pool);
+            v_head_buf->copyFrom(v_head, ctx_ptr->op_handle, ctx_ptr->stream);
+
+            // auto k_grouped = k_head_buf->view({compress_len / factor, nkvh, factor, dk});
+            // auto v_grouped = v_head_buf->view({compress_len / factor, nkvh, factor, dk});
+
+            auto k_perm = k_head_buf->permute({1, 0, 2}); // view only, non-contiguous
+            auto k_contig = Tensor::buffer(k_tensor->dtype(), {nkvh, compress_len, dk}, ctx_ptr->memory_pool);
+            k_contig->copyFrom(k_perm, ctx_ptr->op_handle, ctx_ptr->stream);
+
+            auto v_perm = v_head_buf->permute({1, 0, 2});
+            auto v_contig = Tensor::buffer(v_tensor->dtype(), {nkvh, compress_len, dk}, ctx_ptr->memory_pool);
+            v_contig->copyFrom(v_perm, ctx_ptr->op_handle, ctx_ptr->stream);
+
+            auto k_grouped = k_contig->view({nkvh, compress_len / factor, factor * dk});
+            auto v_grouped = v_contig->view({nkvh, compress_len / factor, factor * dk});
+
+            const size_t rows_linear = static_cast<size_t>(compress_len / factor) * nkvh;
+            auto k_in2d = k_grouped->view({rows_linear, factor * dk});
+            auto v_in2d = v_grouped->view({rows_linear, factor * dk});
+
+            auto k_comp2d = run_pipeline(k_in2d, prefix_base);
+            auto v_comp2d = run_pipeline(v_in2d, prefix_base + 1);
+            if (!k_comp2d || !v_comp2d) {
+                return {nullptr, nullptr};
+            }
+
+            // k_comp2d/v_comp2d rows are laid out as [nkvh, compressed_seq_len] (head-major),
+            // but KV cache storage expects [compressed_seq_len, nkvh, dk] (seq-major).
+            // Reshape to [nkvh, compressed_seq_len, dk] then permute to [compressed_seq_len, nkvh, dk].
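// Worked example (illustrative): seg_len = 103, factor = 4
//   compressed_seq_len = 103 / 4 = 25, compress_len = 100, remainder_len = 3
//   k_in2d / v_in2d: [25 * nkvh, 4 * dk]  ->  run_pipeline  ->  [25 * nkvh, dk]
//   segment written back: 25 compressed positions + 3 uncompressed tail positions = 28
//   (if 25 < min_seq_len, the segment is returned unchanged instead)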
+ auto k_comp_head = k_comp2d->view({nkvh, compress_len / factor, dk})->permute({1, 0, 2}); + auto v_comp_head = v_comp2d->view({nkvh, compress_len / factor, dk})->permute({1, 0, 2}); + + if (remainder_len == 0) { + return {k_comp_head, v_comp_head}; + } + + std::shared_ptr k_comp; + std::shared_ptr v_comp; + k_comp = Tensor::buffer(k_tensor->dtype(), + {compressed_seq_len + remainder_len, nkvh, dk}, + ctx_ptr->memory_pool); + + v_comp = Tensor::buffer(v_tensor->dtype(), + {compressed_seq_len + remainder_len, nkvh, dk}, + ctx_ptr->memory_pool); + + auto k_tail = k_seg->slice(0, compress_len, remainder_len); + auto v_tail = v_seg->slice(0, compress_len, remainder_len); + // 目标前半段 [0, compressed_seq_len) 放压缩后的 head + auto k_dst_head = k_comp->slice(0, 0, compressed_seq_len); + auto v_dst_head = v_comp->slice(0, 0, compressed_seq_len); + rearrange(k_dst_head, k_comp_head); // [compressed_seq_len, nkvh, dk] + rearrange(v_dst_head, v_comp_head); + + // 目标后半段 [compressed_seq_len, compressed_seq_len + remainder_len) 放 tail + auto k_dst_tail = k_comp->slice(0, compressed_seq_len, remainder_len); + auto v_dst_tail = v_comp->slice(0, compressed_seq_len, remainder_len); + rearrange(k_dst_tail, k_tail); // [remainder_len, nkvh, dk] + rearrange(v_dst_tail, v_tail); + // auto k_out = Tensor::buffer(k_seg->dtype(), {compress_len / factor + remainder_len, nkvh, dk}, ctx_ptr->memory_pool); + // auto v_out = Tensor::buffer(v_seg->dtype(), {compress_len / factor + remainder_len, nkvh, dk}, ctx_ptr->memory_pool); + + // RUN_INFINI(infinirtMemcpy(k_out->data(), k_comp_head->data(), + // k_comp_head->numel() * dsize(k_comp_head->dtype()), + // INFINIRT_MEMCPY_D2D)); + // RUN_INFINI(infinirtMemcpy(v_out->data(), v_comp_head->data(), + // v_comp_head->numel() * dsize(v_comp_head->dtype()), + // INFINIRT_MEMCPY_D2D)); + // auto head_elems = k_comp_head->numel(); + // RUN_INFINI(infinirtMemcpy(k_out->data(head_elems * dsize(k_out->dtype())), + // k_tail->data(), + // k_tail->numel() * dsize(k_tail->dtype()), + // INFINIRT_MEMCPY_D2D)); + // RUN_INFINI(infinirtMemcpy(v_out->data(head_elems * dsize(v_out->dtype())), + // v_tail->data(), + // v_tail->numel() * dsize(v_tail->dtype()), + // INFINIRT_MEMCPY_D2D)); + return {k_comp, v_comp}; + }; + + uint32_t img_len = std::min(config_.image_kv_len, seq); + uint32_t txt_len = seq - img_len; + + + //这里可能有坑 + auto k_img = img_len > 0 ? k_view->slice(0, 0, img_len) : nullptr; + auto v_img = img_len > 0 ? v_view->slice(0, 0, img_len) : nullptr; + auto k_txt = txt_len > 0 ? k_view->slice(0, img_len, txt_len) : nullptr; + auto v_txt = txt_len > 0 ? v_view->slice(0, img_len, txt_len) : nullptr; + + std::shared_ptr k_img_comp, v_img_comp, k_txt_comp, v_txt_comp; + if (img_len > 0) { + if (has_image_mlp) { + auto res = compress_segment(k_img, v_img, 2); // compress_ik/iv + k_img_comp = res.first; + v_img_comp = res.second; + } else { + // Hybrid (text-only) weights: retain image KV prefix uncompressed. 
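// Illustrative numbers: with image_kv_len = 576, seq = 976 and factor = 4, the 576 image
// positions are carried through untouched and only the 400 text positions go through the
// text MLP, shrinking this layer's KV from 976 to 576 + 100 = 676 entries.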
+ k_img_comp = k_img; + v_img_comp = v_img; + } + } + if (txt_len > 0) { + auto res = compress_segment(k_txt, v_txt, 0); // compress_tk/tv + k_txt_comp = res.first; + v_txt_comp = res.second; + } + + std::shared_ptr k_comp, v_comp; + if (k_img_comp && k_txt_comp) { + auto total_len = k_img_comp->shape()[0] + k_txt_comp->shape()[0]; + k_comp = Tensor::buffer(k_tensor->dtype(), {total_len, nkvh, dk}, ctx_ptr->memory_pool); + v_comp = Tensor::buffer(v_tensor->dtype(), {total_len, nkvh, dk}, ctx_ptr->memory_pool); + // concat along seq dim using slice+rearrange + auto k_dst_img = k_comp->slice(0, 0, k_img_comp->shape()[0]); + auto k_dst_txt = k_comp->slice(0, k_img_comp->shape()[0], k_txt_comp->shape()[0]); + auto v_dst_img = v_comp->slice(0, 0, v_img_comp->shape()[0]); + auto v_dst_txt = v_comp->slice(0, v_img_comp->shape()[0], v_txt_comp->shape()[0]); + rearrange(k_dst_img, k_img_comp); + rearrange(k_dst_txt, k_txt_comp); + rearrange(v_dst_img, v_img_comp); + rearrange(v_dst_txt, v_txt_comp); + } else { + k_comp = k_img_comp ? k_img_comp : k_txt_comp; + v_comp = v_img_comp ? v_img_comp : v_txt_comp; + } + + compressed->layers[layer].k_comp = k_comp; + compressed->layers[layer].v_comp = v_comp; + compressed->layers[layer].orig_seq_len = seq; + compressed->layers[layer].comp_seq_len = k_comp ? static_cast(k_comp->shape()[0]) : 0; + } + + return compressed; +} + +uint32_t Compressor::compressInplace(KVCache &kv, uint32_t seq_len) { + if (!config_.enable) { + return seq_len; + } + if (seq_len == 0) { + return 0; + } + if (kv.k.empty() || kv.v.empty()) { + return seq_len; + } + if (kv.k.size() != 1 || kv.v.size() != 1) { + std::cerr << "compressInplace: only single-device KVCache is supported for now" << std::endl; + return seq_len; + } + + auto ckv = compress(kv, seq_len); + if (!ckv) { + return seq_len; + } + if (ckv->layers.empty()) { + return seq_len; + } + + auto *ctx_ptr = maybe_get_context(); + if (!ctx_ptr || ctx_ptr->op_handle == nullptr) { + std::cerr << "compressInplace: inference context not initialized" << std::endl; + return seq_len; + } + + const size_t nlayers = kv.k[0].size(); + if (ckv->layers.size() != nlayers) { + std::cerr << "compressInplace: layer count mismatch" << std::endl; + return seq_len; + } + + uint32_t new_len = ckv->layers[0].comp_seq_len; + for (size_t layer = 0; layer < nlayers; ++layer) { + if (!ckv->layers[layer].k_comp || !ckv->layers[layer].v_comp) { + std::cerr << "compressInplace: missing compressed tensor at layer " << layer << std::endl; + return seq_len; + } + if (ckv->layers[layer].comp_seq_len != new_len) { + std::cerr << "compressInplace: inconsistent compressed length across layers" << std::endl; + return seq_len; + } + } + + for (size_t layer = 0; layer < nlayers; ++layer) { + auto k_dst = kv.k[0][layer]->slice(0, 0, new_len); + auto v_dst = kv.v[0][layer]->slice(0, 0, new_len); + k_dst->copyFrom(ckv->layers[layer].k_comp, ctx_ptr->op_handle, ctx_ptr->stream); + v_dst->copyFrom(ckv->layers[layer].v_comp, ctx_ptr->op_handle, ctx_ptr->stream); + } + // Ensure the in-place KV writes are visible to subsequent decoding on other streams/threads. + RUN_INFINI(infinirtStreamSynchronize(ctx_ptr->stream)); + + return new_len; +} + + + + + +bool Compressor::decompress(const CompressedKV &ckv, + std::vector> &k_out, + std::vector> &v_out) { + // Placeholder: no real decompression; just fail to signal unimplemented. + (void)ckv; + (void)k_out; + (void)v_out; + return false; +} + +// Optional helper: create a placeholder compressor that is disabled. 
+std::unique_ptr createDisabledCompressor() { + CompressionConfig cfg; + cfg.enable = false; + return std::make_unique(cfg); +} diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp index 4c49e961..512e32a1 100644 --- a/src/cache_manager/opcache_manager.hpp +++ b/src/cache_manager/opcache_manager.hpp @@ -158,10 +158,18 @@ class CacheManager { DECLARE_OP_CACHE(RoPE) DECLARE_OP_CACHE(Rearrange) DECLARE_OP_CACHE(CausalSoftmax) + DECLARE_OP_CACHE(Softmax) DECLARE_OP_CACHE(Topkrouter) DECLARE_OP_CACHE(SwiGLU) DECLARE_OP_CACHE(RandomSample) DECLARE_OP_CACHE(DequantizeAWQ) + DECLARE_OP_CACHE(Conv) + DECLARE_OP_CACHE(LayerNorm) + DECLARE_OP_CACHE(Relu) + DECLARE_OP_CACHE(GeluTanh) + DECLARE_OP_CACHE(QuickGelu) + DECLARE_OP_CACHE(Sigmoid) + DECLARE_OP_CACHE(Gelu) CacheManager(size_t capacity = 100) : Add_cache(capacity, DESTROY_FUNC(Add)), @@ -170,10 +178,18 @@ class CacheManager { RoPE_cache(capacity, DESTROY_FUNC(RoPE)), Rearrange_cache(capacity, DESTROY_FUNC(Rearrange)), CausalSoftmax_cache(capacity, DESTROY_FUNC(CausalSoftmax)), + Softmax_cache(capacity, DESTROY_FUNC(Softmax)), Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)), SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)), RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)), - DequantizeAWQ_cache(capacity, DESTROY_FUNC(DequantizeAWQ)) {} + DequantizeAWQ_cache(capacity, DESTROY_FUNC(DequantizeAWQ)), + Conv_cache(capacity, DESTROY_FUNC(Conv)), + LayerNorm_cache(capacity, DESTROY_FUNC(LayerNorm)), + Relu_cache(capacity, DESTROY_FUNC(Relu)), + GeluTanh_cache(capacity, DESTROY_FUNC(GeluTanh)), + QuickGelu_cache(capacity, DESTROY_FUNC(QuickGelu)), + Sigmoid_cache(capacity, DESTROY_FUNC(Sigmoid)), + Gelu_cache(capacity, DESTROY_FUNC(Gelu)) {} template static size_t createDescriptorKey(Tensors... tensors) { diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index db5fda11..b7060c23 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -1,6 +1,9 @@ #include "inference_context.hpp" #include "../tensor.hpp" #include "../utils.hpp" +#include + +thread_local InferenceContext *tls_inference_context = nullptr; InferenceContext::InferenceContext(infiniopHandle_t op_handle_, std::shared_ptr memory_pool_, CacheManager *cache_manager, infinirtStream_t stream) : op_handle(op_handle_), memory_pool(memory_pool_), cache_manager(cache_manager), stream(stream) {} @@ -56,10 +59,62 @@ void InferenceContext::rmsnorm(std::shared_ptr y, y->data(), x->data(), w->data(), stream)); } +void InferenceContext::layernorm(std::shared_ptr output, + std::shared_ptr input_standardization, + std::shared_ptr input_std_deviation, + std::shared_ptr input, + std::shared_ptr weight, + std::shared_ptr bias, + float eps) { + // 构造 descriptor key(把所有相关 tensor 都参与 key) + size_t key = CacheManager::createDescriptorKey( + output, input_standardization, input_std_deviation, input, weight, bias); + + infiniopLayerNormDescriptor_t desc; + if (!cache_manager->getLayerNormDescriptor(key, desc)) { + // create descriptor: 注意 eps 必须是最后一个参数(与绑定的 C API 一致) + RUN_INFINI(infiniopCreateLayerNormDescriptor( + op_handle, + &desc, + output->desc(), // output_desc + input_standardization->desc(), // input_standardization_desc + input_std_deviation->desc(), // input_std_deviation_desc + input->desc(), // input_desc + weight ? weight->desc() : nullptr, // weight_desc (gamma) + bias ? 
bias->desc() : nullptr, // bias_desc (beta) or nullptr + eps // epsilon (最后) + )); + cache_manager->putLayerNormDescriptor(key, desc); + } + + // 获取 workspace 大小并确保 workspace 足够 + size_t workspace_size = 0; + RUN_INFINI(infiniopGetLayerNormWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + // 调用 kernel(最后一个参数是 stream) + RUN_INFINI(infiniopLayerNorm( + desc, + workspace, + workspace_size, + output->data(), + input_standardization->data(), + input_std_deviation->data(), + input->data(), + weight ? weight->data() : nullptr, + bias ? bias->data() : nullptr, + stream)); +} + void InferenceContext::gemm(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta) { + // printf("--------------------------------------------\n"); + // printf("%s\n", a->info().c_str()); // for debug + // printf("%s\n", b->info().c_str()); // for debug + // printf("%s\n", c->info().c_str()); // for debug size_t key = CacheManager::createDescriptorKey(c, a, b); infiniopGemmDescriptor_t desc; @@ -143,6 +198,27 @@ void InferenceContext::causalSoftmax(std::shared_ptr y, y->data(), x->data(), stream)); } + +void InferenceContext::Softmax(std::shared_ptr y, std::shared_ptr x, int axis) { + size_t key = CacheManager::createDescriptorKey(y, x); + hash_combine(key, std::hash()(axis)); + + infiniopSoftmaxDescriptor_t desc; + if (!cache_manager->getSoftmaxDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateSoftmaxDescriptor( + op_handle, &desc, y->desc(), x->desc(), axis)); + cache_manager->putSoftmaxDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetSoftmaxWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopSoftmax(desc, workspace, workspace_size, + y->data(), x->data(), stream)); +} + void InferenceContext::topkrouter(std::shared_ptr values, // F32 std::shared_ptr indices, // I32 std::shared_ptr x, @@ -189,6 +265,146 @@ void InferenceContext::swiglu(std::shared_ptr out, out->data(), up->data(), gate->data(), stream)); } + + +void InferenceContext::relu(std::shared_ptr y, + std::shared_ptr x) { + size_t key = CacheManager::createDescriptorKey(y, x); + + infiniopReluDescriptor_t desc; + if (!cache_manager->getReluDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateReluDescriptor(op_handle, &desc, y->desc(), x->desc())); + cache_manager->putReluDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetReluWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopRelu(desc, + workspace, + workspace_size, + y->data(), x->data(), stream)); +} + +void InferenceContext::geluTanh(std::shared_ptr y, + std::shared_ptr x) { + size_t key = CacheManager::createDescriptorKey(y, x); + + infiniopGeluTanhDescriptor_t desc; + if (!cache_manager->getGeluTanhDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateGeluTanhDescriptor(op_handle, &desc, y->desc(), x->desc())); + cache_manager->putGeluTanhDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetGeluTanhWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopGeluTanh(desc, + workspace, + workspace_size, + y->data(), x->data(), stream)); +} + +void InferenceContext::layerNorm(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + 
std::shared_ptr beta, + float epsilon) { + ASSERT_VALID_PTR(y); + ASSERT_VALID_PTR(x); + ASSERT_VALID_PTR(w); + ASSERT_VALID_PTR(beta); + + // Some implementations do not support in-place LayerNorm (output aliases input). + // Keep call sites simple by handling it here. + std::shared_ptr y_out = y; + std::shared_ptr y_tmp; + if (y.get() == x.get() || y->data() == x->data()) { + y_tmp = Tensor::buffer(y->dtype(), y->shape(), memory_pool); + y_out = y_tmp; + } + + // LayerNorm produces two extra outputs (standardization + std deviation). We don't + // expose them to callers, but descriptors require them, so we allocate temporaries. + // + // Keep intermediates in the same dtype as input to support device execution (e.g. Hygon) + // and avoid dtype-specific assumptions in backend implementations. + const infiniDtype_t inter_dt = x->dtype(); + + // CPU LayerNorm kernel assumes 3D input [B, L, D]. Adapt common 2D tensors [L, D] + // into [1, L, D] via views to avoid out-of-bounds access. + std::shared_ptr x_desc = x; + std::shared_ptr y_desc = y_out; + std::shared_ptr input_standardization; + std::shared_ptr input_std_deviation; + + if (x->deviceType() == INFINI_DEVICE_CPU && x->ndim() == 2) { + const auto &sh = x->shape(); + const auto &st = x->strides(); + const size_t L = sh[0]; + const size_t D = sh[1]; + const ptrdiff_t s0 = st[0]; + const ptrdiff_t s1 = st[1]; + x_desc = x->view_as({1, L, D}, {static_cast(L) * s0, s0, s1}); + y_desc = y_out->view_as({1, L, D}, {static_cast(L) * s0, s0, s1}); + input_standardization = Tensor::buffer(inter_dt, {1, L, D}, memory_pool); + input_std_deviation = Tensor::buffer(inter_dt, {1, L}, memory_pool); + } else { + input_standardization = Tensor::buffer(inter_dt, x->shape(), memory_pool); + std::vector std_shape = x->shape(); + if (!std_shape.empty()) { + std_shape.pop_back(); // stddev drops the normalized (last) dimension + } + if (std_shape.empty()) { + std_shape.push_back(1); + } + input_std_deviation = Tensor::buffer(inter_dt, std_shape, memory_pool); + } + + size_t key = CacheManager::createDescriptorKey(y_desc, x_desc, w, beta); + uint32_t eps_bits = 0; + std::memcpy(&eps_bits, &epsilon, sizeof(eps_bits)); + hash_combine(key, std::hash()(eps_bits)); + + infiniopLayerNormDescriptor_t desc; + if (!cache_manager->getLayerNormDescriptor(key, desc)) { + RUN_INFINI(infiniopCreateLayerNormDescriptor( + op_handle, &desc, + y_desc->desc(), + input_standardization->desc(), + input_std_deviation->desc(), + x_desc->desc(), + w->desc(), + beta->desc(), + epsilon)); + cache_manager->putLayerNormDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetLayerNormWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopLayerNorm( + desc, workspace, workspace_size, + y_desc->data(), + input_standardization->data(), + input_std_deviation->data(), + x_desc->data(), + w->data(), + beta->data(), + stream)); + + if (y_tmp) { + rearrange(y, y_tmp); + } +} + + void InferenceContext::randomSample(std::shared_ptr out, std::shared_ptr prob, float random_val, float top_p, uint32_t top_k, float temperature) { @@ -219,8 +435,12 @@ void InferenceContext::linear(std::shared_ptr c, float alpha, float beta, std::shared_ptr residual, std::shared_ptr bias) { - bool residual_flag = residual != nullptr; + bool residual_flag = residual != nullptr; +// bias && !residual +// residual +// residual->data() == c->data() +// beta == 0.0 if (bias && !residual) { int ndim_diff = 
c->ndim() - 1; ASSERT_EQ(bias->ndim(), 1); @@ -234,18 +454,22 @@ void InferenceContext::linear(std::shared_ptr c, if (residual) { if (residual->data() == c->data()) { if (beta == 0.0) { + // std::cout << "1"; gemm(c, a, b, alpha, 1.0); } else { auto c_copy = Tensor::buffer(c->dtype(), c->shape(), memory_pool); c_copy->copyFrom(c, op_handle, stream); + // std::cout << "2"; gemm(c, a, b, alpha, beta); add(c, c, c_copy); } } else { + // std::cout << "3"; gemm(c, a, b, alpha, beta); add(c, c, residual); } } else { + // std::cout << "4"; gemm(c, a, b, alpha, beta); } @@ -281,3 +505,181 @@ void InferenceContext::dequant(std::shared_ptr weight, desc, workspace, workspace_size, weight->data(), in_w->data(), in_s->data(), in_z->data(), stream)); } + +void InferenceContext::conv2d(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + std::shared_ptr b, + std::vector pads, + std::vector strides, + std::vector dilations) { + // 步骤1: 创建缓存键 - 包含所有影响算子行为的参数 + size_t key = CacheManager::createDescriptorKey(y, x, w, b); + + // 将卷积参数也纳入缓存键计算 + void *b_data = b ? b->data() : nullptr; + for (size_t pad : pads) { + hash_combine(key, std::hash()(pad)); + } + for (size_t stride : strides) { + hash_combine(key, std::hash()(stride)); + } + for (size_t dilation : dilations) { + hash_combine(key, std::hash()(dilation)); + } + + // 步骤2: 查找描述符缓存 + infiniopConvDescriptor_t desc; + auto b_desc = b ? b->desc() : nullptr; + if (!cache_manager->getConvDescriptor(key, desc)) { + + // std::cout << "X DESC = " << x->info() << std::endl; + // std::cout << "W DESC = " << w->info() << std::endl; + // std::cout << "Y DESC = " << y->info() << std::endl; + // std::cout << "pads: " << pads[0] << ", " << pads[1] << "\n"; + // std::cout << "strides: " << strides[0] << ", " << strides[1] << "\n"; + // std::cout << "dilations: " << dilations[0] << ", " << dilations[1] << "\n"; + // 步骤3: 创建新描述符并缓存 + RUN_INFINI(infiniopCreateConvDescriptor( + op_handle, &desc, y->desc(), x->desc(), w->desc(), b_desc, + pads.data(), strides.data(), dilations.data(), pads.size())); + cache_manager->putConvDescriptor(key, desc); + } + // 步骤4: 获取工作空间大小 + size_t workspace_size = 0; + RUN_INFINI(infiniopGetConvWorkspaceSize(desc, &workspace_size)); + + // 步骤5: 确保工作空间足够 + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + // 步骤6: 执行卷积算子 + RUN_INFINI(infiniopConv( + desc, workspace, workspace_size, + y->data(), x->data(), w->data(), b_data, stream)); +} + +void InferenceContext::quickGelu(std::shared_ptr y, + std::shared_ptr x) { + // 步骤1: 创建缓存键 - QuickGelu只依赖输入输出张量 + size_t key = CacheManager::createDescriptorKey(y, x); + + // 步骤2: 尝试从缓存中获取描述符 + infiniopQuickGeluDescriptor_t desc; + if (!cache_manager->getQuickGeluDescriptor(key, desc)) { + // 步骤3: 创建新的描述符 + RUN_INFINI(infiniopCreateQuickGeluDescriptor( + op_handle, &desc, y->desc(), x->desc())); + + // 缓存描述符以便复用 + cache_manager->putQuickGeluDescriptor(key, desc); + } + + // 步骤4: 获取工作空间大小 + size_t workspace_size = 0; + RUN_INFINI(infiniopGetQuickGeluWorkspaceSize(desc, &workspace_size)); + + // 步骤5: 确保工作空间充足 + ensure_workspace(workspace_size); + void* workspace = workspace_storage->memory(); + + // 步骤6: 执行 QuickGelu 算子 + RUN_INFINI(infiniopQuickGelu( + desc, workspace, workspace_size, + y->data(), x->data(), stream)); +} + + +void InferenceContext::softmax(std::shared_ptr y, + std::shared_ptr x, + int axis) { + // 步骤1: 创建缓存键 - 包含影响算子行为的参数(y, x, axis) + size_t key = CacheManager::createDescriptorKey(y, x); + hash_combine(key, std::hash()(axis)); // 将 axis 也纳入 key + + 
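// Annotation: the axis is hashed into the key on top of the tensor descriptors, so two softmax
// calls on identically shaped tensors but different reduction axes resolve to different cached
// descriptors instead of silently reusing the wrong one.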
// 步骤2: 查找 Softmax 描述符缓存 + infiniopSoftmaxDescriptor_t desc; + if (!cache_manager->getSoftmaxDescriptor(key, desc)) { + // 步骤3: 创建新描述符 + RUN_INFINI(infiniopCreateSoftmaxDescriptor( + op_handle, &desc, y->desc(), x->desc(), axis)); + // 可以选择缓存 + // cache_manager->putSoftmaxDescriptor(key, desc); + } + + // 步骤4: 获取工作空间大小 + size_t workspace_size = 0; + RUN_INFINI(infiniopGetSoftmaxWorkspaceSize(desc, &workspace_size)); + + // 步骤5: 确保工作空间充足 + ensure_workspace(workspace_size); + void* workspace = workspace_storage->memory(); + + // 步骤6: 执行 Softmax 算子 + RUN_INFINI(infiniopSoftmax( + desc, workspace, workspace_size, + y->data(), x->data(), stream)); +} + +void InferenceContext::sigmoid(std::shared_ptr y, + std::shared_ptr x) { + // 步骤1: 创建缓存键(y 和 x 决定算子行为) + size_t key = CacheManager::createDescriptorKey(y, x); + + // 步骤2: 尝试从缓存获取描述符 + infiniopSigmoidDescriptor_t desc; + if (!cache_manager->getSigmoidDescriptor(key, desc)) { + // 步骤3: 创建新的描述符 + RUN_INFINI(infiniopCreateSigmoidDescriptor( + op_handle, &desc, y->desc(), x->desc())); + + // 缓存以供复用 + cache_manager->putSigmoidDescriptor(key, desc); + } + + // 步骤4: 获取工作空间大小 + size_t workspace_size = 0; + RUN_INFINI(infiniopGetSigmoidWorkspaceSize(desc, &workspace_size)); + + // 步骤5: 确保工作空间充足 + ensure_workspace(workspace_size); + void* workspace = workspace_storage->memory(); + + // 步骤6: 执行 Sigmoid 算子 + RUN_INFINI(infiniopSigmoid( + desc, workspace, workspace_size, + y->data(), x->data(), stream)); +} + +void InferenceContext::gelu(std::shared_ptr output, + std::shared_ptr input) { + // 构造 descriptor key(只需要用 output 和 input 参与 key) + size_t key = CacheManager::createDescriptorKey(output, input); + + infiniopGeluDescriptor_t desc; + if (!cache_manager->getGeluDescriptor(key, desc)) { + // 创建 GELU descriptor + RUN_INFINI(infiniopCreateGeluDescriptor( + op_handle, + &desc, + output->desc(), // output_desc + input->desc() // input_desc + )); + cache_manager->putGeluDescriptor(key, desc); + } + + // 获取 workspace 大小并确保 workspace 足够 + size_t workspace_size = 0; + RUN_INFINI(infiniopGetGeluWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + // 调用 GELU kernel + RUN_INFINI(infiniopGelu( + desc, + workspace, + workspace_size, + output->data(), + input->data(), + stream)); +} diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index 0cf93f6f..134b03ca 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -3,6 +3,7 @@ #include "../cache_manager/opcache_manager.hpp" #include +#include struct InferenceContext { infiniopHandle_t op_handle; @@ -23,6 +24,13 @@ struct InferenceContext { std::shared_ptr x, std::shared_ptr w, float epsilon); + void layernorm(std::shared_ptr output, + std::shared_ptr input_standardization, + std::shared_ptr input_std_deviation, + std::shared_ptr input, + std::shared_ptr weight, + std::shared_ptr bias, + float eps); void gemm(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, @@ -37,6 +45,8 @@ struct InferenceContext { infiniopRoPEAlgo_t algo); void causalSoftmax(std::shared_ptr y, std::shared_ptr x); + + void Softmax(std::shared_ptr y, std::shared_ptr x, int axis); void topkrouter(std::shared_ptr values, // F32 std::shared_ptr indices, // I32 @@ -48,6 +58,18 @@ struct InferenceContext { void swiglu(std::shared_ptr out, std::shared_ptr up, std::shared_ptr gate); + void relu(std::shared_ptr y, + std::shared_ptr x); + + void layerNorm(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + 
std::shared_ptr beta, + float epsilon); + + void geluTanh(std::shared_ptr y, + std::shared_ptr x); + void randomSample(std::shared_ptr out, std::shared_ptr prob, float random_val, float top_p, uint32_t top_k, float temperature); @@ -62,11 +84,22 @@ struct InferenceContext { std::shared_ptr in_w, std::shared_ptr in_s, std::shared_ptr in_z); + void conv2d(std::shared_ptr y, std::shared_ptr x, + std::shared_ptr w, std::shared_ptr b, + std::vector pads, std::vector strides, + std::vector dilations); + + void sigmoid(std::shared_ptr y, std::shared_ptr x); + + void softmax(std::shared_ptr y, std::shared_ptr x, int axis); + + void quickGelu(std::shared_ptr y, std::shared_ptr x); + + void gelu(std::shared_ptr output, std::shared_ptr input); }; -namespace { -thread_local InferenceContext *tls_inference_context = nullptr; -} +extern thread_local InferenceContext *tls_inference_context; +inline InferenceContext *maybe_get_context() { return tls_inference_context; } inline InferenceContext &getInferenceContext() { assert(tls_inference_context != nullptr && "InferenceContext not set for this thread"); @@ -86,6 +119,16 @@ inline void rmsnorm(std::shared_ptr y, std::shared_ptr x, getInferenceContext().rmsnorm(y, x, w, epsilon); } +inline void layernorm(std::shared_ptr output, + std::shared_ptr input_standardization, + std::shared_ptr input_std_deviation, + std::shared_ptr input, + std::shared_ptr weight, + std::shared_ptr bias, + float eps) { + getInferenceContext().layernorm(output, input_standardization, input_std_deviation, input, weight, bias, eps); +} + inline void gemm(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta) { getInferenceContext().gemm(c, a, b, alpha, beta); @@ -111,6 +154,10 @@ inline void causalSoftmax(std::shared_ptr y, std::shared_ptr x) getInferenceContext().causalSoftmax(y, x); } +inline void Softmax(std::shared_ptr y, std::shared_ptr x, int axis) { + getInferenceContext().Softmax(y, x, axis); +} + inline void topkrouter(std::shared_ptr values, // F32 std::shared_ptr indices, // I32 std::shared_ptr x, @@ -131,6 +178,20 @@ inline void swiglu(std::shared_ptr out, std::shared_ptr up, getInferenceContext().swiglu(out, up, gate); } +inline void relu(std::shared_ptr y, std::shared_ptr x) { + getInferenceContext().relu(y, x); +} + +inline void layerNorm(std::shared_ptr y, std::shared_ptr x, + std::shared_ptr w, std::shared_ptr beta, + float epsilon) { + getInferenceContext().layerNorm(y, x, w, beta, epsilon); +} + +inline void geluTanh(std::shared_ptr y, std::shared_ptr x) { + getInferenceContext().geluTanh(y, x); +} + inline void randomSample(std::shared_ptr out, std::shared_ptr prob, float random_val, float top_p, uint32_t top_k, float temperature) { getInferenceContext().randomSample(out, prob, random_val, top_p, top_k, temperature); @@ -149,3 +210,30 @@ inline void dequant_linear(std::shared_ptr out, std::shared_ptr getInferenceContext().dequant(w, w_w, w_s, w_z); getInferenceContext().linear(out, x, w, alpha, beta, residual, bias); } + +// 新增Conv2d的便捷函数 +inline void conv2d(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr w, + std::shared_ptr b, + std::vector& pads, + std::vector& strides, + std::vector& dilations) { + getInferenceContext().conv2d(y, x, w, b, pads, strides, dilations); +} + +inline void quickGelu(std::shared_ptr out, std::shared_ptr in){ + getInferenceContext().quickGelu(out, in); +} + +inline void sigmoid(std::shared_ptr out, std::shared_ptr in){ + getInferenceContext().sigmoid(out, in); +} + +inline void 
softmax(std::shared_ptr out, std::shared_ptr in, int axis){ + getInferenceContext().softmax(out, in, axis); +} + +inline void gelu(std::shared_ptr out, std::shared_ptr in){ + getInferenceContext().gelu(out, in); +} diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index 059842cc..25f7df47 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -6,6 +6,7 @@ #include "../inference_context.hpp" #include "infinicore_infer.h" +#include #include #include #include @@ -21,7 +22,7 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta, infinirtStream_t stream; infinirtStreamCreate(&stream); - std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out, + std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out, w_ffn_norm, w_ffn_gate_up, w_ffn_down; for (size_t layer = 0; layer < meta->nlayer; layer++) { w_attn_norm.push_back( @@ -32,6 +33,13 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta, b_attn_qkv.push_back( getAttnQKVBias(meta, weights, layer, idev, ndev)); } + + if (weights->attn_q_norm != nullptr) { + w_attn_q_norm.push_back( + getAttnQNorm(meta, weights, layer)); + w_attn_k_norm.push_back( + getAttnKNorm(meta, weights, layer)); + } w_attn_out.push_back( getAttnO(meta, weights, layer, idev, ndev)); w_ffn_norm.push_back( @@ -56,6 +64,8 @@ void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta, w_attn_norm, w_attn_qkv, b_attn_qkv, + w_attn_q_norm, + w_attn_k_norm, w_attn_out, w_ffn_norm, w_ffn_gate_up, @@ -111,13 +121,18 @@ void releaseDeviceResource(JiugeDeviceResource &res) { res.comm = nullptr; } -void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, - uint32_t idev, uint32_t ndev, - const uint32_t *tokens, uint32_t ntok, - const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct KVCache **kv_caches, - const float *temperature, const uint32_t *topk, const float *topp, - uint32_t *output, void *last_logits) { +static void inferDeviceBatchEx(const JiugeMeta &meta, JiugeDeviceResource &rsrc, + uint32_t idev, uint32_t ndev, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *last_logits) { auto nlayer = meta.nlayer; auto nkvh = meta.nkvh / ndev; auto nh = meta.nh / ndev; @@ -130,6 +145,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto dvoc = meta.dvoc; auto stream = rsrc.stream; bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0; + bool has_qk_norm = rsrc.w_attn_q_norm.size() > 0 && rsrc.w_attn_k_norm.size() > 0; // Allocate buffers auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); @@ -142,6 +158,8 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto result_cpu = std::vector(nreq); auto qkv_rope = qkv_buf->view({ntok, nh + nkvh * 2, dh}); + auto q_buf = qkv_rope->slice(1, 0, nh); + auto k_buf = qkv_rope->slice(1, nh, nkvh); // Prepare inputs auto batch_pos_ids = std::vector(ntok); @@ -161,10 +179,28 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, RUN_INFINI(infinirtMemcpyAsync(pos_ids_buf->data(), batch_pos_ids.data(), sizeof(uint32_t) * ntok, INFINIRT_MEMCPY_H2D, stream)); } + + const char *override_ptr = 
reinterpret_cast(override_embeds); + const size_t unit = dsize(dt_logits); + uint32_t override_idx = 0; for (uint32_t i = 0; i < ntok; i++) { - RUN_INFINI(infinirtMemcpyAsync(logits_in->data(i * d), - rsrc.w_in_embd->data(tokens[i] * d), - dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); + if (override_ptr != nullptr && override_idx < n_override && override_pos[override_idx] == i) { + void *dst = logits_in->data(i * d); + const void *src = override_ptr + static_cast(override_idx) * d * unit; + if (rsrc.device == INFINI_DEVICE_CPU) { + std::memcpy(dst, src, unit * d); + } else { + RUN_INFINI(infinirtMemcpyAsync(dst, src, unit * d, INFINIRT_MEMCPY_H2D, stream)); + } + override_idx++; + continue; + } + RUN_INFINI(infinirtMemcpyAsync( + logits_in->data(i * d), + rsrc.w_in_embd->data(tokens[i] * d), + unit * d, + INFINIRT_MEMCPY_D2D, + stream)); } // Attention @@ -173,7 +209,7 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, size_t max_seq_len = 0; for (uint32_t req = 0; req < nreq; req++) { - auto past_len = req_pos[req]; + auto past_len = kv_pos[req]; auto seq_len = req_lens[req]; auto total_len = past_len + seq_len; @@ -198,13 +234,19 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rmsnorm(logits_out, logits_in, rsrc.w_attn_norm[layer], meta.epsilon); // qkv_proj linear(qkv_buf, logits_out, rsrc.w_attn_qkv[layer], 1.0, 0.0, nullptr, has_qkv_bias ? rsrc.b_attn_qkv[layer] : nullptr); + + if (has_qk_norm) { + rmsnorm(q_buf, q_buf, rsrc.w_attn_q_norm[layer], meta.epsilon); + rmsnorm(k_buf, k_buf, rsrc.w_attn_k_norm[layer], meta.epsilon); + } + // rope - rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); - rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); + rope(q_buf, q_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table); + rope(k_buf, k_buf, pos_ids_buf, rsrc.sin_table, rsrc.cos_table); size_t token_offset = 0; for (uint32_t req = 0; req < nreq; req++) { - auto past_len = req_pos[req]; + auto past_len = kv_pos[req]; auto seq_len = req_lens[req]; auto total_len = past_len + seq_len; auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); @@ -297,19 +339,127 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, } } + +// 同步机制: +// 条件变量: +// cv_start:主线程通知工作线程开始推理 +// cv_done:工作线程通知主线程推理完成 +// cv_load:工作线程通知主线程设备已加载完成 +// 互斥锁:保护共享状态(proceed、loaded等标志位) + +// 执行流程: +// 1. 创建模型 → 启动N个工作线程 → 线程等待cv_start信号 +// 2. 调用inferBatchJiuge → 设置参数 → 发cv_start信号 +// 3. 工作线程被唤醒 → 调用inferDeviceBatch → 发cv_done信号 +// 4. 主线程等待所有cv_done → 推理完成 + +// inferBatchJiuge本身不执行具体的矩阵运算,它是一个协调器,负责: +// 准备推理参数 +// 唤醒所有设备的工作线程 +// 等待所有线程完成推理 + + __C void inferBatchJiuge(struct JiugeModel *model, - const uint32_t *tokens, uint32_t ntok, - const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct KVCache **kv_caches, - const float *temperature, const uint32_t *topk, const float *topp, - uint32_t *output) { + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output) { + // 1. 
Set the inference parameters (the shared request struct)
+    model->req.tokens = tokens;
+    model->req.ntok = ntok;
+    model->req.req_lens = req_lens;
+    model->req.nreq = nreq;
+    model->req.req_pos = req_pos;
+    model->req.kv_pos = req_pos;
+    model->req.kv_caches = kv_caches;
+    model->req.n_override = 0;
+    model->req.override_pos = nullptr;
+    model->req.override_embeds = nullptr;
+    model->req.output = output;
+    model->req.logits = nullptr;
+    model->req.temperature = temperature;
+    model->req.topk = topk;
+    model->req.topp = topp;
+
+    // 2. Notify every device thread to start working
+    for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
+        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
+        model->states[idev].proceed = true;        // set the signal
+        lock.unlock();
+        model->states[idev].cv_start.notify_one(); // wake the worker thread
+    }
+
+    // 3. Wait for all device threads to finish
+    for (size_t i = model->dev_ids.size(); i > 0; i--) {
+        auto idev = i - 1;
+        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
+        model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); });
+        lock.unlock();
+    }
+}
+
+__C void
+inferBatchJiugeWithLogits(struct JiugeModel *model,
+                          const uint32_t *tokens, uint32_t ntok,
+                          const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
+                          struct KVCache **kv_caches,
+                          const float *temperature, const uint32_t *topk, const float *topp,
+                          uint32_t *output, void *logits) {
+    // 1. Set the inference parameters (the shared request struct)
+    model->req.tokens = tokens;
+    model->req.ntok = ntok;
+    model->req.req_lens = req_lens;
+    model->req.nreq = nreq;
+    model->req.req_pos = req_pos;
+    model->req.kv_pos = req_pos;
+    model->req.kv_caches = kv_caches;
+    model->req.n_override = 0;
+    model->req.override_pos = nullptr;
+    model->req.override_embeds = nullptr;
+    model->req.output = output;
+    model->req.logits = logits; // key difference: route the logits output
+    model->req.temperature = temperature;
+    model->req.topk = topk;
+    model->req.topp = topp;
+
+    // 2. Notify every device thread to start working
+    for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
+        std::unique_lock<std::mutex> lock(model->states[idev].mtx);
+        model->states[idev].proceed = true;        // set the signal
+        lock.unlock();
+        model->states[idev].cv_start.notify_one(); // wake the worker thread
+    }
+
+    // 3.
等待所有设备线程完成工作 + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +__C void +inferBatchJiugeEx(struct JiugeModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output) { model->req.tokens = tokens; model->req.ntok = ntok; model->req.req_lens = req_lens; model->req.nreq = nreq; model->req.req_pos = req_pos; + model->req.kv_pos = kv_pos; model->req.kv_caches = kv_caches; + model->req.n_override = 0; + model->req.override_pos = nullptr; + model->req.override_embeds = nullptr; model->req.output = output; model->req.logits = nullptr; model->req.temperature = temperature; @@ -330,18 +480,347 @@ inferBatchJiuge(struct JiugeModel *model, } } +__C void +inferBatchJiugeExWithLogits(struct JiugeModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *logits) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_pos = kv_pos; + model->req.kv_caches = kv_caches; + model->req.n_override = 0; + model->req.override_pos = nullptr; + model->req.override_embeds = nullptr; + model->req.output = output; + model->req.logits = logits; + model->req.temperature = temperature; + model->req.topk = topk; + model->req.topp = topp; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + __C void forwardBatchJiuge(struct JiugeModel *model, - const uint32_t *tokens, uint32_t ntok, - const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct KVCache **kv_caches, - void *logits) { + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + void *logits) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.n_override = 0; + model->req.override_pos = nullptr; + model->req.override_embeds = nullptr; + model->req.output = nullptr; + model->req.logits = logits; + model->req.temperature = nullptr; + model->req.topk = nullptr; + model->req.topp = nullptr; + + // 这是主线程(比如调用inferBatchJiuge的地方)执行的代码 + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; // 🚦 设置绿灯信号 // 👉 拍拍工人0的肩膀:"该干活了" + lock.unlock(); + model->states[idev].cv_start.notify_one(); // 📢 喊醒对应的线程 // 📢 "醒醒!" 
+ } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + // ⏳ 老板等待工人完成 + lock.unlock(); + } +} + +__C void +forwardBatchJiugeEx(struct JiugeModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + void *logits) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_pos = kv_pos; + model->req.kv_caches = kv_caches; + model->req.n_override = 0; + model->req.override_pos = nullptr; + model->req.override_embeds = nullptr; + model->req.output = nullptr; + model->req.logits = logits; + model->req.temperature = nullptr; + model->req.topk = nullptr; + model->req.topp = nullptr; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +__C void +inferBatchJiugeWithOverrides(struct JiugeModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.n_override = n_override; + model->req.override_pos = override_pos; + model->req.override_embeds = override_embeds; + model->req.output = output; + model->req.logits = nullptr; + model->req.temperature = temperature; + model->req.topk = topk; + model->req.topp = topp; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +__C void +inferBatchJiugeWithOverridesWithLogits(struct JiugeModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output, void *logits) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.n_override = n_override; + model->req.override_pos = override_pos; + model->req.override_embeds = override_embeds; + model->req.output = 
output; + model->req.logits = logits; + model->req.temperature = temperature; + model->req.topk = topk; + model->req.topp = topp; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +__C void +inferBatchJiugeWithOverridesEx(struct JiugeModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + const float *temperature, const uint32_t *topk, const float *topp, + uint32_t *output) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_pos = kv_pos; + model->req.kv_caches = kv_caches; + model->req.n_override = n_override; + model->req.override_pos = override_pos; + model->req.override_embeds = override_embeds; + model->req.output = output; + model->req.logits = nullptr; + model->req.temperature = temperature; + model->req.topk = topk; + model->req.topp = topp; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +// __C void +// inferBatchJiugeWithOverridesExWithLogits(struct JiugeModel *model, +// const uint32_t *tokens, uint32_t ntok, +// const uint32_t *req_lens, uint32_t nreq, +// const uint32_t *req_pos, +// const uint32_t *kv_pos, +// struct KVCache **kv_caches, +// uint32_t n_override, +// const uint32_t *override_pos, +// const void *override_embeds, +// const float *temperature, const uint32_t *topk, const float *topp, +// uint32_t *output, void *logits) { +// model->req.tokens = tokens; +// model->req.ntok = ntok; +// model->req.req_lens = req_lens; +// model->req.nreq = nreq; +// model->req.req_pos = req_pos; +// model->req.kv_pos = kv_pos; +// model->req.kv_caches = kv_caches; +// model->req.n_override = n_override; +// model->req.override_pos = override_pos; +// model->req.override_embeds = override_embeds; +// model->req.output = output; +// model->req.logits = logits; +// model->req.temperature = temperature; +// model->req.topk = topk; +// model->req.topp = topp; + +// for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { +// std::unique_lock lock(model->states[idev].mtx); +// model->states[idev].proceed = true; +// lock.unlock(); +// model->states[idev].cv_start.notify_one(); +// } +// for (size_t i = model->dev_ids.size(); i > 0; i--) { +// auto idev = i - 1; +// std::unique_lock lock(model->states[idev].mtx); +// model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); +// lock.unlock(); +// } +// } + +__C void +forwardBatchJiugeWithOverrides(struct JiugeModel *model, + const uint32_t *tokens, uint32_t ntok, + const 
uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + void *logits) { + model->req.tokens = tokens; + model->req.ntok = ntok; + model->req.req_lens = req_lens; + model->req.nreq = nreq; + model->req.req_pos = req_pos; + model->req.kv_pos = req_pos; + model->req.kv_caches = kv_caches; + model->req.n_override = n_override; + model->req.override_pos = override_pos; + model->req.override_embeds = override_embeds; + model->req.output = nullptr; + model->req.logits = logits; + model->req.temperature = nullptr; + model->req.topk = nullptr; + model->req.topp = nullptr; + + for (size_t idev = 0; idev < model->dev_ids.size(); idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_start.notify_one(); + } + for (size_t i = model->dev_ids.size(); i > 0; i--) { + auto idev = i - 1; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_done.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +__C void +forwardBatchJiugeWithOverridesEx(struct JiugeModel *model, + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, + const uint32_t *req_pos, + const uint32_t *kv_pos, + struct KVCache **kv_caches, + uint32_t n_override, + const uint32_t *override_pos, + const void *override_embeds, + void *logits) { model->req.tokens = tokens; model->req.ntok = ntok; model->req.req_lens = req_lens; model->req.nreq = nreq; model->req.req_pos = req_pos; + model->req.kv_pos = kv_pos; model->req.kv_caches = kv_caches; + model->req.n_override = n_override; + model->req.override_pos = override_pos; + model->req.override_embeds = override_embeds; model->req.output = nullptr; model->req.logits = logits; model->req.temperature = nullptr; @@ -365,6 +844,7 @@ forwardBatchJiuge(struct JiugeModel *model, void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, JiugeDeviceResource *rsrc, InferState &state, InferRequest &req, infiniDevice_t device, int idev, int ndev, int dev_id, infinicclComm_t comm) { // Create Device Resource + // 初始化设备资源 createDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm); CacheManager cache_manager(100); @@ -373,6 +853,7 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, JiugeDevic // Set the inference context for this thread setInferenceContext(&ctx); + // 通知主线程:这个设备已经加载完成 { std::unique_lock lock(state.mtx); state.loaded = true; @@ -381,21 +862,27 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, JiugeDevic } // Infer Loop + // 进入推理循环(这个线程会一直运行) while (true) { std::unique_lock lock(state.mtx); + // 关键点:线程在这里停下来等待! state.cv_start.wait(lock, [&] { return state.proceed || state.exit_flag; }); // quit if exit_flag is set if (state.exit_flag) { - break; + break; // 退出线程 } - inferDeviceBatch(meta, *rsrc, idev, ndev, req.tokens, req.ntok, - req.req_lens, req.nreq, req.req_pos, req.kv_caches, - req.temperature, req.topk, req.topp, req.output, req.logits); + // 这里是关键:真正执行推理的地方! + // 只有收到信号才会执行到这里! 
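// ----------------------------------------------------------------------------
// Editor's sketch (not part of this diff): how a caller is expected to drive the
// override-aware prefill entry point defined above. Only the argument order is
// taken from the forwardBatchJiugeWithOverrides signature in this file; the
// KVCache handles, the logits buffer size, and the number/position of image
// placeholder tokens are illustrative assumptions.
static void examplePrefillWithImageOverrides(JiugeModel *model, KVCache **kv_caches,
                                             const uint32_t *tokens, uint32_t ntok,
                                             const void *image_embeds, // fp16 rows, one per overridden token (assumed)
                                             void *logits) {           // host buffer, ntok x vocab elements (assumed)
    uint32_t req_lens[1] = {ntok};          // a single request holding the whole prompt
    uint32_t req_pos[1] = {0};              // prefill starts at KV position 0
    const uint32_t n_override = 576;        // e.g. one 24x24 patch grid (assumed)
    const uint32_t image_start = 5;         // index of the first <image> placeholder token (assumed)
    std::vector<uint32_t> override_pos(n_override);
    for (uint32_t i = 0; i < n_override; ++i) {
        override_pos[i] = image_start + i;  // token slots whose input embeddings get replaced
    }
    forwardBatchJiugeWithOverrides(model, tokens, ntok, req_lens, /*nreq=*/1, req_pos,
                                   kv_caches, n_override, override_pos.data(),
                                   image_embeds, logits);
}
// ----------------------------------------------------------------------------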
+ inferDeviceBatchEx(meta, *rsrc, idev, ndev, req.tokens, req.ntok, + req.req_lens, req.nreq, req.req_pos, req.kv_pos, req.kv_caches, + req.n_override, req.override_pos, req.override_embeds, + req.temperature, req.topk, req.topp, req.output, req.logits); - state.proceed = false; + state.proceed = false; // 重置信号 lock.unlock(); - state.cv_done.notify_one(); + // 通知主线程:这个设备完成了推理 + state.cv_done.notify_one(); // 通知主线程:我做完了 } // Clean-Up @@ -407,17 +894,21 @@ JiugeModel::JiugeModel(const JiugeMeta *_meta, const JiugeWeights *weights, infi int ndev = int(device_ids.size()); device = device_; dev_ids = device_ids; - dev_resources = std::vector(ndev); - states = std::vector(ndev); - threads.resize(ndev); + dev_resources = std::vector(ndev); // 每个设备的资源 + states = std::vector(ndev); // 每个设备的状态 + threads.resize(ndev); // 每个设备的线程 RUN_INFINI(infinirtInit()); auto comms = std::vector(ndev, nullptr); if (ndev > 1) { RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data())); } + // 一个卡一个线程 for (int i = 0; i < ndev; i++) { + // 🧵🧵🧵 这里创建线程! threads[i] = std::thread(launchDevice, std::cref(meta), weights, &dev_resources[i], std::ref(states[i]), std::ref(req), device, i, ndev, dev_ids[i], comms[i]); + // ⏳ 线程立即启动,进入launchDevice函数 + // 😴 在cv_start.wait()处开始休眠等待 } for (int i = 0; i < ndev; i++) { std::unique_lock lock(states[i].mtx); diff --git a/src/models/jiuge/jiuge_impl.hpp b/src/models/jiuge/jiuge_impl.hpp index 55800a37..41fba23e 100644 --- a/src/models/jiuge/jiuge_impl.hpp +++ b/src/models/jiuge/jiuge_impl.hpp @@ -20,7 +20,7 @@ struct JiugeDeviceResource { // Weights std::shared_ptr w_in_embd, w_out_norm, w_out_embd, sin_table, cos_table; - std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_out, + std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm,w_attn_out, w_ffn_norm, w_ffn_gate_up, w_ffn_down; // Streams infinirtStream_t stream; @@ -38,13 +38,57 @@ struct InferState { bool exit_flag = false; }; +// 1. mtx (互斥锁) +// 作用: 保护每个线程自己的状态变量访问 +// 特点: 每个线程有独立的mutex,避免线程间竞争 +// 使用场景: 在修改loaded、proceed、exit_flag时加锁 + +// 2. cv_load (加载完成条件变量) +// 作用: 设备线程等待主线程的加载信号 +// 流程: +// 设备线程:cv_load.wait() - 等待任务 +// 主线程:cv_load.notify_one() - 分配任务 +// 语义: "数据已准备好,可以开始执行" + +// 3. cv_start (开始执行条件变量) +// 作用: 控制设备线程开始执行推理 +// 流程: +// 主线程:cv_start.notify_one() - 允许开始 +// 设备线程:cv_start.wait() - 等待开始许可 +// 语义: "可以开始执行推理" + +// 4. cv_done (执行完成条件变量) +// 作用: 设备线程通知主线程任务完成 +// 流程: +// 设备线程:cv_done.notify_one() - 报告完成 +// 主线程:cv_done.wait() - 等待完成 +// 语义: "推理已完成" + +// 5. loaded (加载状态标志) +// 作用: 标识任务数据是否已加载完成 +// 值: false → true (主线程设置),true → false (设备线程重置) + +// 6. proceed (执行状态标志) +// 作用: 控制是否允许继续执行下一步 +// 值: false → true (主线程授权执行) + +// 7. 
exit_flag (退出标志) +// 作用: 通知线程优雅退出 +// 值: false → true (程序结束时设置) + + + struct InferRequest { const uint32_t *tokens; uint32_t ntok; const uint32_t *req_lens; uint32_t nreq; const uint32_t *req_pos; + const uint32_t *kv_pos; struct KVCache **kv_caches; + uint32_t n_override; + const uint32_t *override_pos; + const void *override_embeds; const float *temperature; const uint32_t *topk; const float *topp; diff --git a/src/models/jiuge/jiuge_weight.hpp b/src/models/jiuge/jiuge_weight.hpp index 6e8bc33e..7ee10155 100644 --- a/src/models/jiuge/jiuge_weight.hpp +++ b/src/models/jiuge/jiuge_weight.hpp @@ -70,6 +70,22 @@ inline std::shared_ptr getAttnQKVBias( return Tensor::weight((char *)(w->attn_qkv_b[layer]) + offset, w->dt_mat, shape); } +inline std::shared_ptr getAttnQNorm( + JiugeMeta const *meta, + JiugeWeights const *w, + size_t layer) { + auto shape = std::vector({meta->dh}); + return Tensor::weight((char *)(w->attn_q_norm[layer]), w->dt_norm, shape); +} + +inline std::shared_ptr getAttnKNorm( + JiugeMeta const *meta, + JiugeWeights const *w, + size_t layer) { + auto shape = std::vector({meta->dh}); + return Tensor::weight((char *)(w->attn_k_norm[layer]), w->dt_norm, shape); +} + inline std::shared_ptr getAttnO(JiugeMeta const *meta, JiugeWeights const *w, size_t layer, size_t idev, size_t ndev) { diff --git a/src/models/llava/llava.cpp b/src/models/llava/llava.cpp new file mode 100644 index 00000000..365e0961 --- /dev/null +++ b/src/models/llava/llava.cpp @@ -0,0 +1,1099 @@ +#include "llava_impl.hpp" +#include "llava_weight.hpp" + +#include "../../tensor.hpp" +#include "../../utils.hpp" +#include "../inference_context.hpp" +#include "infinicore_infer/models/llava.h" + +#include +#include +#include +#include +#include +#include +#include +#include + +static bool llava_debug_enabled() { + static int cached = -1; + if (cached == -1) { + const char *env = std::getenv("LLAVA_DEBUG"); + cached = (env != nullptr && std::strcmp(env, "0") != 0) ? 
1 : 0; + } + return cached != 0; +} + +// LLaVA设备资源创建函数,模仿jiuge.cpp的createDeviceResource +void createLlavaDeviceResource(LlavaDeviceResource *rsrc, const LlavaMeta *meta, + const LlavaWeights *weights, + infiniDevice_t device, int idev, int ndev, int dev_id, + infinicclComm_t comm) { + RUN_INFINI(infinirtSetDevice(device, dev_id)); + infiniopHandle_t handle; + infiniopCreateHandle(&handle); + infinirtStream_t stream; + infinirtStreamCreate(&stream); + + // 初始化memory_pool + auto memory_pool = std::make_shared(128 * 1024 * 1024); + + // 初始化Language Model权重(暂时为空,复用jiuge结构) + std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out, + w_ffn_norm, w_ffn_gate_up, w_ffn_down; + + // 初始化Vision Encoder权重 + auto vision_patch_embed_weight = getPatchEmbedWeight(meta, weights); + auto vision_position_embedding = createPositionEmbedding(meta, weights); // 从meta中获取形状 + auto vision_class_token = getClassToken(meta, weights); // 从meta中获取形状 + auto vision_pre_layernorm_weight = getVisionPreLNWeight(meta, weights); + auto vision_pre_layernorm_bias = getVisionPreLNBias(meta, weights); + + auto vision_post_layernorm_weight = getVisionPostLNWeight(meta, weights); + auto vision_post_layernorm_bias = getVisionPostLNBias(meta, weights); + + // 初始化Projector权重 + auto projector_weight_1 = getProjectorWeight1(meta, weights); + auto projector_bias_1 = getProjectorBias1(meta, weights); + auto projector_weight_2 = getProjectorWeight2(meta, weights); + auto projector_bias_2 = getProjectorBias2(meta, weights); + + std::vector> vision_q_weights, vision_q_biases, + vision_k_weights, vision_k_biases, + vision_v_weights, vision_v_biases, + vision_in_layer_pre_norm_weights, vision_in_layer_pre_norm_biases, + vision_proj_weight, vision_proj_bias, + vision_in_layer_post_norm_weight, vision_post_norm_bias, + vision_mlp_fc1_weight, vision_mlp_fc1_bias, + vision_mlp_fc2_weight, vision_mlp_fc2_bias; + + + for (size_t layer = 0; layer < meta->vision_meta.vision_num_layers; layer++) { + vision_q_weights.push_back( + getVisionQWeight(meta, weights, layer)); + vision_q_biases.push_back( + getVisionQBias(meta, weights, layer)); + vision_k_weights.push_back( + getVisionKWeight(meta, weights, layer)); + vision_k_biases.push_back( + getVisionKBias(meta, weights, layer)); + vision_v_weights.push_back( + getVisionVWeight(meta, weights, layer)); + vision_v_biases.push_back( + getVisionVBias(meta, weights, layer)); + // in-layer pre norm + vision_in_layer_pre_norm_weights.push_back( + getVisionInLayerPreNormWeight(meta, weights, layer)); + vision_in_layer_pre_norm_biases.push_back( + getVisionInLayerPreNormBias(meta, weights, layer)); + + // proj + vision_proj_weight.push_back( + getVisionProjWeight(meta, weights, layer)); + vision_proj_bias.push_back( + getVisionProjBias(meta, weights, layer)); + + // post norm + vision_in_layer_post_norm_weight.push_back( + getVisionInLayerPostNormWeight(meta, weights, layer)); + vision_post_norm_bias.push_back( + getVisionInLayerPostNormBias(meta, weights, layer)); + + // MLP fc1 + vision_mlp_fc1_weight.push_back( + getVisionMLPFC1Weight(meta, weights, layer)); + vision_mlp_fc1_bias.push_back( + getVisionMLPFC1Bias(meta, weights, layer)); + + // MLP fc2 + vision_mlp_fc2_weight.push_back( + getVisionMLPFC2Weight(meta, weights, layer)); + vision_mlp_fc2_bias.push_back( + getVisionMLPFC2Bias(meta, weights, layer)); + + } + + + // auto vision_class_embedding = getClassToken(meta); + + // 临时创建language model权重(将来应该从weights中加载) + std::shared_ptr w_in_embd = nullptr; + std::shared_ptr 
w_out_norm = nullptr; + std::shared_ptr w_out_embd = nullptr; + std::shared_ptr sin_table = nullptr; + std::shared_ptr cos_table = nullptr; + + *rsrc = LlavaDeviceResource{ + device, + dev_id, + handle, + w_in_embd, w_out_norm, w_out_embd, sin_table, cos_table, + w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm, w_attn_out, + w_ffn_norm, w_ffn_gate_up, w_ffn_down, + vision_patch_embed_weight, + vision_position_embedding, + vision_class_token, + vision_pre_layernorm_weight, vision_pre_layernorm_bias, + vision_post_layernorm_weight, vision_post_layernorm_bias, + vision_q_weights, vision_q_biases, + vision_k_weights, vision_k_biases, + vision_v_weights, vision_v_biases, + vision_in_layer_pre_norm_weights, vision_in_layer_pre_norm_biases, + vision_proj_weight, vision_proj_bias, + vision_in_layer_post_norm_weight, vision_post_norm_bias, + vision_mlp_fc1_weight, vision_mlp_fc1_bias, + vision_mlp_fc2_weight, vision_mlp_fc2_bias, + projector_weight_1, projector_bias_1, + projector_weight_2, projector_bias_2, + stream, + comm, + memory_pool, + }; + RUN_INFINI(infinirtDeviceSynchronize()); +} + +void releaseDeviceResource(LlavaDeviceResource &res) { + infinirtDeviceSynchronize(); + // Release individual Tensors + res.w_in_embd.reset(); + res.w_out_norm.reset(); + res.w_out_embd.reset(); + res.sin_table.reset(); + res.cos_table.reset(); + for (auto &t : res.w_attn_norm) { + t.reset(); + } + res.w_attn_norm.clear(); + for (auto &t : res.w_attn_qkv) { + t.reset(); + } + res.w_attn_qkv.clear(); + for (auto &t : res.b_attn_qkv) { + t.reset(); + } + res.b_attn_qkv.clear(); + for (auto &t : res.w_attn_out) { + t.reset(); + } + res.w_attn_out.clear(); + for (auto &t : res.w_ffn_norm) { + t.reset(); + } + res.w_ffn_norm.clear(); + for (auto &t : res.w_ffn_gate_up) { + t.reset(); + } + res.w_ffn_gate_up.clear(); + for (auto &t : res.w_ffn_down) { + t.reset(); + } + res.w_ffn_down.clear(); + res.projector_weight_1.reset(); + res.projector_bias_1.reset(); + res.projector_weight_2.reset(); + res.projector_bias_2.reset(); + infiniopDestroyHandle(res.handle); + res.handle = nullptr; + infinirtStreamDestroy(res.stream); + res.stream = nullptr; + infinicclCommDestroy(res.comm); + res.comm = nullptr; +} + +float fp16_to_fp32(uint16_t h) { + // 完整处理零、非规格化、Inf、NaN 的 FP16 -> FP32 转换 + uint32_t sign = (static_cast(h) & 0x8000u) << 16; + uint32_t exp = (h >> 10) & 0x1Fu; + uint32_t frac = h & 0x03FFu; + + uint32_t f_exp = 0; + uint32_t f_frac = 0; + + if (exp == 0) { + if (frac == 0) { + // zero + f_exp = 0; + f_frac = 0; + } else { + // subnormal: normalize + uint32_t e = 1; + while ((frac & 0x0400u) == 0) { + frac <<= 1; + e--; + } + frac &= 0x03FFu; + f_exp = (e + (127 - 15)) << 23; + f_frac = frac << 13; + } + } else if (exp == 0x1Fu) { + // Inf/NaN + f_exp = 0xFFu << 23; + f_frac = frac << 13; + if (f_frac != 0) { + f_frac |= 0x1u; // preserve a quiet NaN payload bit + } + } else { + // normal + f_exp = (exp + (127 - 15)) << 23; + f_frac = frac << 13; + } + + uint32_t bits = sign | f_exp | f_frac; + float out = 0.0f; + std::memcpy(&out, &bits, sizeof(out)); + return out; +} + +void debug_fp16_data_u16(const void* data, size_t count) { + if (!llava_debug_enabled()) { + return; + } + const uint16_t* ptr = static_cast(data); + + for (size_t i = 0; i < count; ++i) { + for (int j = 15; j >= 0; --j) { + std::cout << ((ptr[i] >> j) & 1); + } + std::cout << " "; + float val = fp16_to_fp32(ptr[i]); + if (std::isnan(val)) { + std::cout << i << ": NaN" << std::endl; + } else if (std::isinf(val)) { + std::cout 
<< i << ": Inf" << std::endl; + } else { + std::cout << val << std::endl; + } + } +} + +// LLaVA视觉编码设备层推理函数(模仿inferDeviceBatch) +void inferDeviceBatchVision(const LlavaMeta &meta, LlavaDeviceResource &rsrc, + uint32_t idev, uint32_t ndev, + const void *image_data, uint32_t stage, void *output) { + + const bool debug = llava_debug_enabled(); + + // debug_fp16_data_u16(image_data, 100); + + // std::cout << "image_data pointer from cpp: " << image_data << std::endl; + // vision_tower部分 + // === 1. 准备参数 === + auto vision_embed_dim = meta.vision_meta.vision_embed_dim; // 1024 + auto vision_nh = meta.vision_meta.vision_num_heads; // 16 + auto image_size = meta.vision_meta.image_size; // 336 + auto patch_size = meta.vision_meta.patch_size; // 14 + auto dt_logits = meta.language_meta.dt_logits; // F16 + auto stream = rsrc.stream; + auto vision_num_layers = meta.vision_meta.vision_num_layers; // 24 + // 计算patches数量 + auto patches_per_dim = image_size / patch_size; // 24 + auto total_patches = patches_per_dim * patches_per_dim; // 576 + auto vision_intermediate_size = meta.vision_meta.vision_intermediate_size; // 4096 + + // 假设你已经得到了 q_buf, k_buf, v_buf shape = [1, seq_len, vision_embed_dim] + // 现在 reshape 成多头格式 + auto vision_dh = vision_embed_dim / vision_nh; + auto vision_seq = 1 + total_patches; // 577 + auto scale = 1.0f / std::sqrt(static_cast(vision_dh)); + + // === 2. 准备buffer === + // auto input_image_tensor_f32 = Tensor::buffer(INFINI_DTYPE_F32, {1, 3, image_size, image_size}, rsrc.memory_pool); + auto input_image_tensor = Tensor::buffer(dt_logits, {1, 3, image_size, image_size}, rsrc.memory_pool); + auto patch_embed_output = Tensor::buffer(dt_logits, {1, vision_embed_dim, patches_per_dim, patches_per_dim}, rsrc.memory_pool); + // embeddings = torch.cat([class_embeds, patch_embeds], dim=1) + auto embeddings = Tensor::buffer(dt_logits, {1, 1 + total_patches, vision_embed_dim}, rsrc.memory_pool); + // [ 1 577 1024 ] + auto pre_layernorm = Tensor::buffer(dt_logits, {1, 1 + total_patches, vision_embed_dim}, rsrc.memory_pool); + auto vision_residual = Tensor::buffer(dt_logits, {1, 1 + total_patches, vision_embed_dim}, rsrc.memory_pool); + auto in_layer_pre_norm = Tensor::buffer(dt_logits, {1, 1 + total_patches, vision_embed_dim}, rsrc.memory_pool); + // [ 1 577 1024 ] + auto q_buf = Tensor::buffer(dt_logits, {1, 1 + total_patches, vision_embed_dim}, rsrc.memory_pool); + auto k_buf = Tensor::buffer(dt_logits, {1, 1 + total_patches, vision_embed_dim}, rsrc.memory_pool); + auto v_buf = Tensor::buffer(dt_logits, {1, 1 + total_patches, vision_embed_dim}, rsrc.memory_pool); + auto input_standardization = Tensor::buffer(dt_logits, {1, 1 + total_patches, vision_embed_dim}, rsrc.memory_pool); + auto input_std_deviation = Tensor::buffer(dt_logits, {1, 1 + total_patches}, rsrc.memory_pool); + + + // 复制输入图像数据 + RUN_INFINI(infinirtMemcpyAsync(input_image_tensor->data(), image_data, + image_size * image_size * 3 * sizeof(uint16_t), + INFINIRT_MEMCPY_H2D, stream)); + + // printf("DEBUG: input_image_tensor after memcpy:\n"); + // input_image_tensor->debug_first_n(10000); + + // === 3. 
CLIPVisionEmbeddings Forward === + // Step 1: Patch Embedding (Conv2d) + + // printf("DEBUG: Running Conv2d: input [1,3,%ld,%ld] -> output [1,%ld,%ld,%ld]\n", + // image_size, image_size, vision_embed_dim, patches_per_dim, patches_per_dim); + + // 准备卷积参数 + std::vector pads = {0, 0}; // 无padding + std::vector strides = {static_cast(patch_size), static_cast(patch_size)}; + std::vector dilations = {1, 1}; + // input_image_tensor->debug_first_n(10000); + // patch_embeds = self.patch_embedding(pixel_values.to(dtype=target_dtype)) # Conv2d + conv2d(patch_embed_output, input_image_tensor, rsrc.vision_patch_embed_weight, + nullptr, pads, strides, dilations); // (1,1024,24,24) + + // flatten 2D patch -> [batch, embed_dim, total_patches] + auto patch_embed_flat = patch_embed_output->view({1, vision_embed_dim, total_patches}); + + // transpose -> [batch, total_patches, embed_dim] + auto patch_embed_transposed = patch_embed_flat->permute({0, 2, 1}); + // 创建 class embedding buffer + // class_embeds = self.class_embedding.expand(batch_size, 1, -1) + // assume batch=1 + + auto class_embed_tensor = Tensor::buffer(dt_logits, {1, 1, vision_embed_dim}, rsrc.memory_pool); + // Tensor: shape[ 1 1 1024 ] + RUN_INFINI(infinirtMemcpyAsync(class_embed_tensor->data(), + rsrc.vision_class_token->data(), + sizeof(uint16_t) * vision_embed_dim, + INFINIRT_MEMCPY_D2D, stream)); + + // printf("DEBUG: class_embed_tensor:\n"); + // class_embed_tensor->debug_first_n(20000); + // 1) 把 class token 放到 embeddings[:, 0:1, :] + rearrange(embeddings->slice(1, 0, 1), class_embed_tensor); // 注意:slice(dim=1, start=0, length=1) + + // printf("DEBUG: patch_embed_transposed:\n"); + // patch_embed_transposed->debug_first_n(20000); + // 2) 把所有 patch token 放到 embeddings[:, 1:1+T, :] + rearrange(embeddings->slice(1, 1, total_patches), patch_embed_transposed); // 注意:slice(dim=1, start=1, length=total_patches) + + // 3) 加 position embedding (pos tensor 必须是 [1, 1+T, C]) + // std::cout << "=== Before add() ===" << std::endl; + // embeddings->debug_first_n(20000); + // rsrc.vision_position_embedding->debug_first_n(20000); + add(embeddings, embeddings, rsrc.vision_position_embedding); + // printf("DEBUG: embeddings after add position embedding:\n"); + // embeddings->debug_first_n(10); + // embeddings->debug(); + + // (pre_layrnorm): LayerNorm((1024,), eps=1e-05, elementwise_affine=True) # 暂未实现 + // printf("meta.vision_meta.vision_epsilon: %e\n", meta.vision_meta.vision_epsilon); + // std::cout << "=== embeddings first 20000 values ===" << std::endl; + // embeddings->debug_first_n(20000); + layernorm(/*out_put*/ pre_layernorm, + /*input_standardization*/ input_standardization, + /*input_std_deviation*/ input_std_deviation, + /*input*/ embeddings, + /*weight*/ rsrc.vision_pre_layernorm_weight, + /*bias*/ rsrc.vision_pre_layernorm_bias, + meta.vision_meta.vision_epsilon); // 1e-5 + + if (stage == LLAVA_VISION_STAGE_PRE_LN) { + ASSERT_VALID_PTR(output); + const size_t out_rows = static_cast(vision_seq); + const size_t out_cols = static_cast(vision_embed_dim); + const size_t out_bytes = out_rows * out_cols * dsize(dt_logits); + RUN_INFINI(infinirtMemcpyAsync(output, + pre_layernorm->data(), + out_bytes, + INFINIRT_MEMCPY_D2H, + stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + return; + } + // printf("DEBUG: pre_layernorm after LayerNorm_1\n"); + // pre_layernorm->debug_first_n(10); + + // printf("DEBUG: pre_layernorm:\n"); + // pre_layernorm->debug_first_n(10); + + auto layer_input = Tensor::buffer(dt_logits, {1, 1 + total_patches, 
vision_embed_dim}, rsrc.memory_pool); + auto layer_output = Tensor::buffer(dt_logits, {1, 1 + total_patches, vision_embed_dim}, rsrc.memory_pool); + + RUN_INFINI(infinirtMemcpyAsync(layer_input->data(), + pre_layernorm->data(), + sizeof(uint16_t) * (1 + total_patches) * vision_embed_dim, + INFINIRT_MEMCPY_D2D, stream)); + + // 用来存每层 hidden_states + std::vector> all_hidden_states; + all_hidden_states.reserve(vision_num_layers + 1); // 多预留一个 + + all_hidden_states.push_back(layer_input); + + for (uint32_t layer = 0; layer < vision_num_layers; layer++) { + // for (uint32_t layer = 0; layer < 1; layer++) { + + // residual = hidden_states + // vision_residual = pre_layernorm; + RUN_INFINI(infinirtMemcpyAsync(vision_residual->data(), + layer_input->data(), + sizeof(dt_logits) * (1 + total_patches) * vision_embed_dim, + INFINIRT_MEMCPY_D2D, stream)); + if (debug) { + printf("DEBUG: pre_layernorm:\n"); + pre_layernorm->debug_first_n(10); + } + // printf("DEBUG: vision_residual:\n"); + // vision_residual->debug_first_n(10); + + // (layer_norm1): LayerNorm((1024,), eps=1e-05, elementwise_affine=True)) + + // std::cout << "q_buf->info()" << q_buf->info() << std::endl; + layernorm(/*out_put*/ in_layer_pre_norm, + /*input_standardization*/ input_standardization, + /*input_std_deviation*/ input_std_deviation, + /*input*/ layer_input, + /*weight*/ rsrc.vision_in_layer_pre_norm_weights[layer], + /*bias*/ rsrc.vision_in_layer_pre_norm_biases[layer], + meta.vision_meta.vision_epsilon); // 1e-5 + if (debug) { + printf("layer_norm1 output:\n"); + in_layer_pre_norm->debug_first_n(10); + } + + // // 线性投影 + linear(q_buf, in_layer_pre_norm, rsrc.vision_q_weights[layer]->permute({1, 0}), 1.0, 0.0, nullptr, rsrc.vision_q_biases[layer]); + // // debug: 不考虑中间两行,这里是对的了。(== queries (first 10 elements): ) + + + + // q_buf->debug(); + linear(k_buf, in_layer_pre_norm, rsrc.vision_k_weights[layer]->permute({1, 0}), 1.0, 0.0, nullptr, rsrc.vision_k_biases[layer]); + linear(v_buf, in_layer_pre_norm, rsrc.vision_v_weights[layer]->permute({1, 0}), 1.0, 0.0, nullptr, rsrc.vision_v_biases[layer]); + + + // 1) rearrange Q/K/V → [vision_nh, vision_seq, vision_dh] + auto q_rearr = Tensor::buffer(dt_logits, {1, vision_nh, vision_seq, vision_dh}, rsrc.memory_pool); + auto k_rearr = Tensor::buffer(dt_logits, {1, vision_nh, vision_seq, vision_dh}, rsrc.memory_pool); + auto v_rearr = Tensor::buffer(dt_logits, {1, vision_nh, vision_seq, vision_dh}, rsrc.memory_pool); + + + rearrange(q_rearr, q_buf->view({1, vision_seq, vision_nh, vision_dh})->permute({0,2,1,3})); + rearrange(k_rearr, k_buf->view({1, vision_seq, vision_nh, vision_dh})->permute({0,2,1,3})); + rearrange(v_rearr, v_buf->view({1, vision_seq, vision_nh, vision_dh})->permute({0,2,1,3})); + + // printf("[DEBUG] Q output:\n"); + // q_rearr->debug_first_n(10); + + // printf("[DEBUG] K output:\n"); + // k_rearr->debug_first_n(10); + + // printf("[DEBUG] V output:\n"); + // v_rearr->debug_first_n(10); + + // 2) 准备 QK = [vision_nh, vision_seq, vision_seq] + auto qk_buf = Tensor::buffer(dt_logits, {vision_nh, vision_seq, vision_seq}, rsrc.memory_pool); + + // 3) Q * K^T + scaling + auto k_T = k_rearr->permute({0,1,3,2}); // [vision_nh, vision_dh, vision_seq] + + linear( + qk_buf, + q_rearr->slice(0, 0, 1)->view({vision_nh, vision_seq, vision_dh}), + k_T->slice(0, 0, 1)->view({vision_nh, vision_dh, vision_seq}), + /*alpha=*/scale, + /*beta=*/0.0, + nullptr, + nullptr + ); + + // printf("[DEBUG] attn_weights before softmax:\n"); + // qk_buf->debug_first_n(10); + + // 4) softmax (你还没实现,用 
causalSoftmax 临时代替) + auto qk_softmax = qk_buf->view({vision_nh, vision_seq, vision_seq}); + softmax(qk_softmax, qk_softmax, -1); // non-causal softmax (vision) + + // printf("[DEBUG] qk_softmax after softmax:\n"); + // qk_softmax->debug_first_n(5); + + // 5) Attn * V + auto attn_val_buf = Tensor::buffer(dt_logits, {vision_nh, vision_seq, vision_dh}, rsrc.memory_pool); + // auto v_gemm = v_rearr->permute({0,1,3,2}); // [vision_nh, vision_dh, vision_seq] + auto v_gemm = v_rearr->permute({0,1,2,3}); // debug + + linear( + attn_val_buf, // debug: shape[ 16 577 64 ] strides[ 36928 64 1 ] + qk_softmax, // debug: shape[ 16 577 577 ] strides[ 332929 577 1 ] + v_gemm->slice(0, 0, 1)->view({vision_nh, vision_seq, vision_dh}), // debug: 注意这里的 view, 可能不对【shape[ 16 64 577 ] strides[ 36928 577 1 ]】 + /*alpha=*/1.0, + /*beta=*/0.0, + nullptr, + nullptr + ); + + // 6) 合头 → o: [1, vision_seq, vision_embed_dim] + auto o_tmp = Tensor::buffer(dt_logits, {1, vision_seq, vision_nh, vision_dh}, rsrc.memory_pool); + rearrange(o_tmp, attn_val_buf->view({1, vision_nh, vision_seq, vision_dh})->permute({0,2,1,3})); + // std::cout << "o_tmp->info()" << o_tmp->info() << std::endl; // Tensor: shape[ 1 577 16 64 ] + auto o = Tensor::buffer(dt_logits, {1, vision_seq, vision_embed_dim}, rsrc.memory_pool); + rearrange(o, o_tmp->view({1, vision_seq, vision_embed_dim})); + // std::cout << "o->info()" << o->info() << std::endl; + + + // === Attention out_proj === + // o -> attn_out + auto attn_out = Tensor::buffer(dt_logits, {1, vision_seq, vision_embed_dim}, rsrc.memory_pool); + linear(attn_out, o, rsrc.vision_proj_weight[layer]->permute({1, 0}), 1.0f, 0.0f, nullptr, rsrc.vision_proj_bias[layer]); + if (debug) { + printf("attn hidden_stats:\n"); + attn_out->debug_first_n(10); + } + // === Attention residual add === // 复用 pre_layernorm 作为输出 buffer + // hidden_states = residual + hidden_states + auto attn_residual_out = Tensor::buffer(dt_logits, {1, vision_seq, vision_embed_dim}, rsrc.memory_pool); + add(attn_residual_out, attn_out, vision_residual); + auto post_attn_norm = Tensor::buffer(dt_logits, {1, vision_seq, vision_embed_dim}, rsrc.memory_pool); + layernorm( + /*out*/ post_attn_norm, + /*input_standardization*/ input_standardization, + /*input_std_deviation*/ input_std_deviation, + /*input*/ attn_residual_out, + /*weight*/ rsrc.vision_in_layer_post_norm_weight[layer], + /*bias*/ rsrc.vision_post_norm_bias[layer], + meta.vision_meta.vision_epsilon + ); + if (debug) { + printf("layer norm2 output:\n"); + post_attn_norm->debug_first_n(10); + } + // === MLP === + // FC1: 1024 -> 4096 + auto mlp_fc1_out = Tensor::buffer(dt_logits, {1, vision_seq, vision_intermediate_size}, rsrc.memory_pool); + linear(mlp_fc1_out, post_attn_norm, rsrc.vision_mlp_fc1_weight[layer]->permute({1, 0}), + 1.0f, 0.0f, nullptr, rsrc.vision_mlp_fc1_bias[layer]); + + // QuickGELU Activation + auto mlp_activated_out = Tensor::buffer(dt_logits, {1, vision_seq, vision_intermediate_size}, rsrc.memory_pool); + quickGelu(mlp_activated_out, mlp_fc1_out); + + // FC2: 4096 -> 1024 + auto mlp_fc2_out = Tensor::buffer(dt_logits, {1, vision_seq, vision_embed_dim}, rsrc.memory_pool); + linear(mlp_fc2_out, mlp_activated_out, rsrc.vision_mlp_fc2_weight[layer]->permute({1, 0}), + 1.0f, 0.0f, nullptr, rsrc.vision_mlp_fc2_bias[layer]); + if (debug) { + printf("mlp output:\n"); + mlp_fc2_out->debug_first_n(10); + } + + // === 第二次残差连接:MLP === + add(layer_output, mlp_fc2_out, attn_residual_out); + + // 为下一层做准备 + std::swap(layer_input, layer_output); + + 
all_hidden_states.push_back(layer_input); + } + + // auto fake_output = Tensor::buffer(dt_logits, {1, vision_seq, vision_embed_dim}, rsrc.memory_pool); + auto post_layernorm_output = Tensor::buffer(dt_logits, {1, vision_seq, vision_embed_dim}, rsrc.memory_pool); + layernorm(post_layernorm_output, + input_standardization, + input_std_deviation, + layer_input, // 所有encoder层的输出 + rsrc.vision_post_layernorm_weight, // 需要在资源中添加这个权重 + rsrc.vision_post_layernorm_bias, // 需要在资源中添加这个偏置 + meta.vision_meta.vision_epsilon + ); + if (debug) { + printf("post_layernorm output:\n"); + post_layernorm_output->debug_first_n(10); + } + + // multi_modal_projector部分 + + int second_last_idx = all_hidden_states.size() - 2; + auto selected_all = Tensor::buffer(dt_logits, {1, vision_seq, vision_embed_dim}, rsrc.memory_pool); + rearrange(selected_all, all_hidden_states[second_last_idx]); + + if (stage == LLAVA_VISION_STAGE_SELECT_ALL) { + ASSERT_VALID_PTR(output); + const size_t out_rows = static_cast(vision_seq); + const size_t out_cols = static_cast(vision_embed_dim); + const size_t out_bytes = out_rows * out_cols * dsize(dt_logits); + RUN_INFINI(infinirtMemcpyAsync(output, + selected_all->data(), + out_bytes, + INFINIRT_MEMCPY_D2H, + stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + return; + } + + auto projector_input = Tensor::buffer(dt_logits, {1, vision_seq - 1, vision_embed_dim}, rsrc.memory_pool); + rearrange(projector_input, selected_all->slice(1, 1, vision_seq - 1)); + + if (stage == LLAVA_VISION_STAGE_SELECT_PATCH) { + ASSERT_VALID_PTR(output); + const size_t out_rows = static_cast(vision_seq - 1); + const size_t out_cols = static_cast(vision_embed_dim); + const size_t out_bytes = out_rows * out_cols * dsize(dt_logits); + RUN_INFINI(infinirtMemcpyAsync(output, + projector_input->data(), + out_bytes, + INFINIRT_MEMCPY_D2H, + stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + return; + } + + // 准备projector的buffer + auto projector_linear1_out = Tensor::buffer(dt_logits, {1, vision_seq - 1, 4096}, rsrc.memory_pool); + auto projector_gelu_out = Tensor::buffer(dt_logits, {1, vision_seq - 1, 4096}, rsrc.memory_pool); + auto projector_final_out = Tensor::buffer(dt_logits, {1, vision_seq - 1, 4096}, rsrc.memory_pool); + + // printf("projector weight 1:\n"); + // rsrc.projector_weight_1->debug_first_n(10); + // printf("projector bias 1:\n"); + // rsrc.projector_bias_1->debug_first_n(10); + if (debug) { + printf("projector_input:\n"); + projector_input->debug_first_n(10); + } + + // Linear 1: 1024 -> 4096 + linear(projector_linear1_out, + projector_input, + rsrc.projector_weight_1->permute({1, 0}), + 1.0f, 0.0f, + nullptr, + rsrc.projector_bias_1); + + if (debug) { + printf("projector linear1 output:\n"); + projector_linear1_out->debug_first_n(10); + } + + // GELU Activation + gelu(projector_gelu_out, projector_linear1_out); + + if (debug) { + printf("projector gelu output:\n"); + projector_gelu_out->debug_first_n(10); + } + + // printf("projector weight 2:\n"); + // rsrc.projector_weight_2->debug_first_n(10); + // printf("projector bias 2:\n"); + // rsrc.projector_bias_2->debug_first_n(10); + + // Linear 2: 4096 -> 4096 + linear(projector_final_out, + projector_gelu_out, + rsrc.projector_weight_2->permute({1, 0}), + 1.0f, 0.0f, + nullptr, + rsrc.projector_bias_2); + + if (stage == LLAVA_VISION_STAGE_PROJECTOR_ALL) { + auto projector_in_all = Tensor::buffer(dt_logits, {1, vision_seq, vision_embed_dim}, rsrc.memory_pool); + rearrange(projector_in_all, selected_all); + auto proj1_all = 
Tensor::buffer(dt_logits, {1, vision_seq, 4096}, rsrc.memory_pool); + auto gelu_all = Tensor::buffer(dt_logits, {1, vision_seq, 4096}, rsrc.memory_pool); + auto proj2_all = Tensor::buffer(dt_logits, {1, vision_seq, 4096}, rsrc.memory_pool); + + linear(proj1_all, + projector_in_all, + rsrc.projector_weight_1->permute({1, 0}), + 1.0f, 0.0f, + nullptr, + rsrc.projector_bias_1); + gelu(gelu_all, proj1_all); + linear(proj2_all, + gelu_all, + rsrc.projector_weight_2->permute({1, 0}), + 1.0f, 0.0f, + nullptr, + rsrc.projector_bias_2); + + ASSERT_VALID_PTR(output); + const size_t out_rows = static_cast(vision_seq); + const size_t out_cols = static_cast(meta.projector_meta.text_embed_dim); + const size_t out_bytes = out_rows * out_cols * dsize(dt_logits); + RUN_INFINI(infinirtMemcpyAsync(output, + proj2_all->data(), + out_bytes, + INFINIRT_MEMCPY_D2H, + stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); + return; + } + + if (debug) { + printf("projector final output:\n"); + projector_final_out->debug_first_n(10); + } + + // Write projector output back to host buffer. + // Output contract: [vision_patches, text_embed_dim] == [576, 4096], dtype = meta.language_meta.dt_logits. + ASSERT_VALID_PTR(output); + const size_t out_rows = static_cast(vision_seq - 1); + const size_t out_cols = static_cast(meta.projector_meta.text_embed_dim); + const size_t out_bytes = out_rows * out_cols * dsize(dt_logits); + RUN_INFINI(infinirtMemcpyAsync(output, + projector_final_out->data(), + out_bytes, + INFINIRT_MEMCPY_D2H, + stream)); + RUN_INFINI(infinirtStreamSynchronize(stream)); +} + + + + +// LLaVA设备工作线程函数,严格按照jiuge.cpp的launchDevice结构 +void launchLlavaDevice(const LlavaMeta &meta, const LlavaWeights *weights, + LlavaDeviceResource *rsrc, LlavaInferState &state, + LlavaRequest &req, + infiniDevice_t device, int idev, int ndev, int dev_id, + infinicclComm_t comm) { + // Create Device Resource + // 初始化设备资源 + createLlavaDeviceResource(rsrc, &meta, weights, device, idev, ndev, dev_id, comm); + + CacheManager cache_manager(100); + InferenceContext ctx(rsrc->handle, rsrc->memory_pool, &cache_manager, rsrc->stream); + setInferenceContext(&ctx); + + // 通知主线程:这个设备已经加载完成 + // TODO: 没有检查现在标志位是否靠谱 + { + std::unique_lock lock(state.mtx); + state.loaded = true; + lock.unlock(); + state.cv_stage.notify_one(); + } + + // Infer Loop + // 进入推理循环(这个线程会一直运行) + while (true) { + std::unique_lock lock(state.mtx); + // 关键点:线程在这里停下来等待! 
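// ----------------------------------------------------------------------------
// Editor's sketch (not part of this diff): host-side buffer sizes implied by the
// stage branches of inferDeviceBatchVision above, assuming dt_logits is fp16
// (2 bytes) and the default CLIP-L/336 geometry used in this file (577 tokens
// including the class token, vision_embed_dim = 1024, text_embed_dim = 4096).
// The numbers are read off the out_rows/out_cols computations above, not from a
// separate specification.
static size_t llavaVisionStageOutputBytes(uint32_t stage) {
    const size_t seq = 577, patches = 576, vdim = 1024, tdim = 4096, elem = 2;
    switch (stage) {
    case LLAVA_VISION_STAGE_PRE_LN:        return seq * vdim * elem;     // embeddings after pre-LN
    case LLAVA_VISION_STAGE_SELECT_ALL:    return seq * vdim * elem;     // second-to-last hidden states
    case LLAVA_VISION_STAGE_SELECT_PATCH:  return patches * vdim * elem; // class token dropped
    case LLAVA_VISION_STAGE_PROJECTOR_ALL: return seq * tdim * elem;     // projector applied to all tokens
    default:                               return patches * tdim * elem; // final projector output [576, 4096]
    }
}
// ----------------------------------------------------------------------------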
+ state.cv_stage.wait(lock, [&] { return state.proceed || state.exit_flag; }); + // quit if exit_flag is set + if (state.exit_flag) { + break; // 退出线程 + } + + // TODO: 执行推理 + // // 占位符:简单返回一个token + // if (req.output && req.batch_size > 0) { + // req.output[0] = 1; + // } + + inferDeviceBatchVision(meta, *rsrc, idev, ndev, + req.image_data, req.vision_stage, req.output); + + // // === LLaVA四阶段推理流程 === + // // 阶段1: Vision Encoder (如果有图像) + // if (req.image_data != nullptr) { + // state.current_stage = 1; + // state.stage_completed = false; + // lock.unlock(); + // state.cv_stage.notify_one(); // 通知主线程进入阶段1 + + // // TODO: 实现vision encoding + // // encodeVisionFeatures(meta, *rsrc, req.image_data, state.vision_features); + + // lock.lock(); + // state.stage_completed = true; + // state.current_stage = 2; + // } + + // // 阶段2: MultiModal Projector (如果有图像特征) + // if (state.vision_features != nullptr) { + // lock.unlock(); + // state.cv_stage.notify_one(); // 通知主线程进入阶段2 + + // // TODO: 实现multimodal projection + // // projectMultiModalFeatures(meta, *rsrc, state.vision_features, state.projected_features); + + // lock.lock(); + // state.stage_completed = true; + // state.current_stage = 3; + // } + + // // 阶段3: Language Model Prefill (包含KV-Cache) + // state.current_stage = 3; + // state.stage_completed = false; + // lock.unlock(); + // state.cv_stage.notify_one(); // 通知主线程进入阶段3 + + // // TODO: 实现language model prefill + // // 这里调用Jiuge的推理逻辑来处理text tokens + projected vision features + // // inferDeviceBatchLanguage(meta, *rsrc, idev, ndev, req.input_tokens, req.ntok, + // // req.req_lens, req.nreq, req.req_pos, req.kv_caches, + // // req.temperature, req.topk, req.topp, req.output, nullptr); + + // lock.lock(); + // state.stage_completed = true; + // state.current_stage = 4; + + // // 阶段4: KV-Cache Compression (可选) + // if (req.kv_caches != nullptr && state.stage_completed) { + // lock.unlock(); + // state.cv_stage.notify_one(); // 通知主线程进入阶段4 + + // // TODO: 实现KV-Cache压缩 (Future: 集成Fastcache) + // // compressKVCaches(meta, *rsrc, req.kv_caches); + + // lock.lock(); + // state.stage_completed = true; + // } + + // // 简单占位符:返回一个token (临时) + // if (req.output && req.batch_size > 0) { + // req.output[0] = 1; // 暂时返回固定token + // } + + + + state.proceed = false; // 重置信号 + lock.unlock(); + // 通知主线程:这个设备完成了推理 + state.cv_stage.notify_one(); + } + // Clean-Up + releaseDeviceResource(*rsrc); + setInferenceContext(nullptr); // Clear the context when done +} + + + +// // LLaVA四阶段统一推理实现 +// void LlavaModel::inferBatchLlava(const uint32_t* input_tokens, const void* image_data, +// void** kv_caches, uint32_t batch_size, +// uint32_t* output) { +// // 1. 设置推理请求参数 +// req.input_tokens = input_tokens; +// req.image_data = image_data; +// req.kv_caches = kv_caches; +// req.batch_size = batch_size; +// req.ntok = batch_size; // 简化:假设每个请求只有一个token +// req.nreq = 1; // 简化:假设只有一个请求 +// req.output = output; + +// // 2. 启动所有设备线程 +// auto ndev = dev_resources.size(); +// for (size_t i = 0; i < ndev; i++) { +// std::unique_lock lock(states[i].mtx); +// states[i].proceed = true; +// lock.unlock(); +// states[i].cv_stage.notify_one(); // 发出推理开始信号 +// } + +// // 3. 等待所有设备完成 +// for (size_t i = 0; i < ndev; i++) { +// std::unique_lock lock(states[i].mtx); +// states[i].cv_stage.wait(lock, [&] { return !(states[i].proceed); }); +// lock.unlock(); +// } + +// // 4. 
清理请求参数 +// req.input_tokens = nullptr; +// req.image_data = nullptr; +// req.kv_caches = nullptr; +// req.output = nullptr; +// } + +// 模仿jiuge.cpp的LlavaModel constructor +LlavaModel::LlavaModel(const LlavaMeta *_meta, const LlavaWeights *weights, + infiniDevice_t device_, std::vector device_ids) : meta(*_meta) { + int ndev = int(device_ids.size()); + device = device_; + dev_ids = device_ids; + dev_resources = std::vector(ndev); // 每个设备的资源 + states = std::vector(ndev); // 每个设备的状态 + threads.resize(ndev); // 每个设备的线程 + + RUN_INFINI(infinirtInit()); + + auto comms = std::vector(ndev, nullptr); + if (ndev > 1) { + RUN_INFINI(infinicclCommInitAll(device, comms.data(), ndev, dev_ids.data())); + } + + // 🧵🧵🧵 这里创建线程! + for (int i = 0; i < ndev; i++) { + threads[i] = std::thread( + launchLlavaDevice, + std::cref(meta), + weights, + &dev_resources[i], + std::ref(states[i]), + std::ref(req), + device, + i, + ndev, + dev_ids[i], + comms[i]); + + // ⏳ 线程立即启动,进入launchLlavaDevice函数 + // 😴 在cv_stage.wait()处开始休眠等待 + } + + // 等待所有设备线程加载完成 - 使用cv_load与jiuge.cpp保持一致 + for (int i = 0; i < ndev; i++) { + std::unique_lock lock(states[i].mtx); + states[i].cv_stage.wait(lock, [&] { return states[i].loaded; }); + lock.unlock(); + } +} + + +// // 最简单的统一推理接口 +// void LlavaModel::inferBatchLlava(const uint32_t* input_tokens, const void* image_data, +// void** kv_caches, const char* mode, uint32_t batch_size, +// uint32_t* output) { +// // 暂时只是占位符实现 +// if (output && batch_size > 0) { +// output[0] = 1; // 返回一个简单的token +// } +// } + +// // 各阶段执行函数的占位符实现 +// void LlavaModel::executeVisionStage() { +// // 占位符 +// } + +// void LlavaModel::executePrefillStage() { +// // 占位符 +// } + +// void LlavaModel::executeCompressStage() { +// // 占位符 +// } + +// void LlavaModel::executeDecodeStage() { +// // 占位符 +// } + +// void LlavaModel::workerLoop() { +// // 占位符 +// } + + + + +// API implementations - 模仿jiuge.cpp的createJiugeModel +__C struct LlavaModel *createLlavaModel(const LlavaMeta *meta, + const LlavaWeights *weights, + infiniDevice_t device, + int ndev, + const int *dev_ids) { + std::vector device_ids_vec(ndev); + std::copy(dev_ids, dev_ids + ndev, device_ids_vec.begin()); + LlavaModel *model = new LlavaModel(meta, weights, device, device_ids_vec); + return model; +} + +__C void destroyLlavaModel(struct LlavaModel *model) { + if (!model) { + return; + } + + auto ndev = model->dev_resources.size(); + + // 通知所有设备线程退出 + for (size_t idev = 0; idev < ndev; idev++) { + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].exit_flag = true; + lock.unlock(); + model->states[idev].cv_stage.notify_one(); + } + + // 等待所有线程结束 + for (size_t idev = 0; idev < ndev; idev++) { + model->threads[idev].join(); + } + + delete model; +} + +// C API: 批量视觉编码(用于Python接口) +__C void inferBatchLlavaVison(struct LlavaModel *model, + const void *image_data, + void *output) { + if (!model || !image_data || !output) { + return; + } + + // 1. 设置推理参数(模仿inferBatchJiuge) + // TODO: 感觉这里的req结构可能要逐渐改的像 struct InferRequest + model->req.input_tokens = nullptr; // vision encoding不需要input_tokens + model->req.image_data = image_data; + model->req.kv_caches = nullptr; // vision encoding不需要kv_caches + model->req.batch_size = 1; // 简化:假设batch_size为1 + model->req.ntok = 0; // vision encoding不需要tokens + model->req.nreq = 1; // 简化:假设一个请求 + model->req.output = output; + model->req.vision_stage = LLAVA_VISION_STAGE_PROJECTOR; + + // Current vision path does not support tensor-parallel; prevent multi-device races on output. 
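// ----------------------------------------------------------------------------
// Editor's sketch (not part of this diff): minimal caller of this single-image
// vision path. Image preprocessing (resize to 336x336, CHW layout, fp16
// normalization) is assumed to happen elsewhere; the output size follows the
// [576, 4096] fp16 contract documented in inferDeviceBatchVision above.
static std::vector<uint16_t> encodeOneImage(LlavaModel *model,
                                            const void *pixels_f16_chw_336x336) {
    std::vector<uint16_t> image_embeds(576 * 4096); // host buffer for the projector output
    inferBatchLlavaVison(model, pixels_f16_chw_336x336, image_embeds.data());
    return image_embeds;
}
// ----------------------------------------------------------------------------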
+ ASSERT_EQ(model->dev_ids.size(), size_t(1)); + + if (llava_debug_enabled()) { + auto vision_embed_dim = model->meta.vision_meta.vision_embed_dim; + auto num_patches = model->meta.vision_meta.num_patches; + auto total_features = vision_embed_dim * num_patches; + printf("inferBatchLlavaVison called: image_data=%p, output=%p\n", image_data, output); + printf("Vision config: embed_dim=%zu, num_patches=%zu, total_features=%zu\n", + vision_embed_dim, num_patches, total_features); + } + + + // 2. 通知所有设备线程开始工作(模仿inferBatchJiuge) + // TODO: 注意,和jiuge不一样的地方在于,我们这里现在只有一个信号量 + { + const size_t idev = 0; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_stage.notify_one(); + } + + // 3. 等待所有设备线程完成工作(模仿inferBatchJiuge) + { + const size_t idev = 0; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_stage.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } + + if (llava_debug_enabled()) { + printf("inferBatchLlavaVison: vision encoding completed\n"); + } +} + +__C void inferBatchLlavaVisionStage(struct LlavaModel *model, + const void *image_data, + uint32_t stage, + void *output) { + if (!model || !image_data || !output) { + return; + } + + model->req.input_tokens = nullptr; + model->req.image_data = image_data; + model->req.kv_caches = nullptr; + model->req.batch_size = 1; + model->req.ntok = 0; + model->req.nreq = 1; + model->req.output = output; + model->req.vision_stage = stage; + + ASSERT_EQ(model->dev_ids.size(), size_t(1)); + { + const size_t idev = 0; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].proceed = true; + lock.unlock(); + model->states[idev].cv_stage.notify_one(); + } + + { + const size_t idev = 0; + std::unique_lock lock(model->states[idev].mtx); + model->states[idev].cv_stage.wait(lock, [&] { return !(model->states[idev].proceed); }); + lock.unlock(); + } +} + +// 暂时注释掉其他复杂的API函数,只保留最基本的 diff --git a/src/models/llava/llava_impl.hpp b/src/models/llava/llava_impl.hpp new file mode 100644 index 00000000..1a5cc0c6 --- /dev/null +++ b/src/models/llava/llava_impl.hpp @@ -0,0 +1,140 @@ +#ifndef LLAVA_IMPL_HPP +#define LLAVA_IMPL_HPP + +#include "infinicore_infer/models/llava.h" +#include "../../allocator.hpp" +#include "../../tensor.hpp" +#include "../../cache.hpp" // 添加KV Cache支持 + +#include +#include +#include +#include + + + +// 设备资源结构 - 统一线程架构只需要一套resource +struct LlavaDeviceResource { + infiniDevice_t device; + int device_id; + infiniopHandle_t handle; + + // Language Model Weights (复用jiuge结构) + std::shared_ptr w_in_embd, w_out_norm, w_out_embd, sin_table, + cos_table; + std::vector> w_attn_norm, w_attn_qkv, b_attn_qkv, w_attn_q_norm, w_attn_k_norm,w_attn_out, + w_ffn_norm, w_ffn_gate_up, w_ffn_down; + + // === Vision Encoder Weights === + // Patch Embedding Conv2d + std::shared_ptr vision_patch_embed_weight; // [1024, 3, 14, 14] + + // Position Embedding + std::shared_ptr vision_position_embedding; // [1, 577, 1024] + + // Class Token + std::shared_ptr vision_class_token; // [1, 1024] + + // pre and post LayerNorm weights and biases + std::shared_ptr vision_pre_layernorm_weight; // [1024] + std::shared_ptr vision_pre_layernorm_bias; // [1024] + std::shared_ptr vision_post_layernorm_weight; // [1024] + std::shared_ptr vision_post_layernorm_bias; // [1024] + + // qkv weights and biases for Vision Transformer Layers + std::vector> vision_q_weights, vision_q_biases, + vision_k_weights, vision_k_biases, + vision_v_weights, 
vision_v_biases, + vision_in_layer_pre_norm_weights, vision_in_layer_pre_norm_biases, + vision_proj_weight, vision_proj_bias, + vision_in_layer_post_norm_weight, vision_post_norm_bias, + vision_mlp_fc1_weight, vision_mlp_fc1_bias, + vision_mlp_fc2_weight, vision_mlp_fc2_bias; + + // MultiModal Projector Weights + std::shared_ptr projector_weight_1; + std::shared_ptr projector_bias_1; + std::shared_ptr projector_weight_2; + std::shared_ptr projector_bias_2; + + // Vision Transformer Layers (复用language结构存储) + // 注意:这里先只实现patch embedding部分 + + // Streams + infinirtStream_t stream; + // Communicator + infinicclComm_t comm; + + std::shared_ptr memory_pool; +}; + +// 最简单的推理状态结构 +struct LlavaInferState { + std::mutex mtx; + std::condition_variable cv_stage; // 使用cv_stage进行同步 + + bool proceed = false; + bool exit_flag = false; + int current_stage = 0; + bool stage_completed = false; + bool error_occurred = false; + std::string error_message; + + // 添加loaded标志,与jiuge.cpp保持一致 + bool loaded = false; + + const uint32_t* input_tokens = nullptr; + const void* image_data = nullptr; + void** kv_caches = nullptr; + uint32_t ntok = 0; + uint32_t nreq = 0; + uint32_t batch_size = 0; + void* output = nullptr; + void* vision_features = nullptr; + void* projected_features = nullptr; +}; + +// // 推理请求结构 +// struct LlavaRequest { +// const uint32_t* input_tokens; +// const void* image_data; +// void** kv_caches; +// uint32_t ntok; +// uint32_t nreq; +// uint32_t batch_size; +// uint32_t* output; +// } req; + +// TODO: 想想这里需要啥 +struct LlavaRequest { + const uint32_t* input_tokens; + const void* image_data; + void** kv_caches; + uint32_t ntok; + uint32_t nreq; + uint32_t batch_size; + void* output; + uint32_t vision_stage = 0; +}; + +struct LlavaModel { + LlavaMeta meta; + infiniDevice_t device; + std::vector dev_ids; + std::vector dev_resources; + std::vector states; + std::vector threads; // 添加线程向量 + LlavaRequest req; + + LlavaModel(const LlavaMeta *, const LlavaWeights *, + infiniDevice_t device, std::vector device_ids); + // ~LlavaModel(); + + // // LLaVA四阶段统一推理接口 + // void inferBatchLlava(const uint32_t* input_tokens, const void* image_data, + // void** kv_caches, uint32_t batch_size, + // uint32_t* output); + +}; + +#endif diff --git a/src/models/llava/llava_weight.hpp b/src/models/llava/llava_weight.hpp new file mode 100644 index 00000000..cea183b4 --- /dev/null +++ b/src/models/llava/llava_weight.hpp @@ -0,0 +1,414 @@ +#ifndef LLAVA_WEIGHT_HPP +#define LLAVA_WEIGHT_HPP + +#include "llava_impl.hpp" +#include "../inference_context.hpp" +#include "infinicore_infer/models/llava.h" + +#include +#include // for memcpy +#include + +inline bool llava_debug_enabled_in_weighthpp() { + static int cached = -1; + if (cached == -1) { + const char *env = std::getenv("LLAVA_DEBUG"); + cached = (env != nullptr && std::strcmp(env, "0") != 0) ? 
1 : 0; + } + return cached != 0; +} + +// Vision weight getters +inline std::shared_ptr getPatchEmbedWeight( + LlavaMeta const *meta, + LlavaWeights const *weights) { + // 从meta中获取vision embedding参数 + auto vision_embed_dim = meta->vision_meta.vision_embed_dim; // 输出通道数 [1024] + auto patch_size = meta->vision_meta.patch_size; // 卷积核大小 [14] + + // 对于RGB图像,输入通道数总是3 + const size_t input_channels = 3; + + // Patch embedding卷积核形状: [vision_embed_dim, input_channels, patch_size, patch_size] + auto shape = std::vector{vision_embed_dim, input_channels, patch_size, patch_size}; + + if (llava_debug_enabled_in_weighthpp()) { + printf("[CPP getPatchEmbedWeight] vision_patch_embed_weight pointer: %p\n", weights->vision_patch_embed_weight); + } + auto vision_patch_embed_device_tensor = + Tensor::weight( + (char *)weights->vision_patch_embed_weight, // 权重数据指针 + meta->language_meta.dt_logits, + shape + ); + + return vision_patch_embed_device_tensor; +} + +// 创建position embedding (从meta中获取形状) +inline std::shared_ptr createPositionEmbedding(LlavaMeta const *meta, + LlavaWeights const *weights) { + // 从meta中获取参数 + auto vision_embed_dim = meta->vision_meta.vision_embed_dim; + auto num_patches = meta->vision_meta.num_patches; + + // CLIP ViT通常还需要class token,所以位置编码长度是 num_patches + 1 + auto pos_embed_length = num_patches + 1; // 576 + 1 = 577 + + if (llava_debug_enabled_in_weighthpp()) { + printf("[CPP createPositionEmbedding] Shape: [1, %zu, %zu]\n", pos_embed_length, vision_embed_dim); + } + + return Tensor::weight((char *)weights->vision_position_embedding, INFINI_DTYPE_F16, {1, pos_embed_length, vision_embed_dim}); +} + +// 创建class token (从meta中获取形状) +inline std::shared_ptr getClassToken(LlavaMeta const *meta, + LlavaWeights const *weights) { + auto vision_embed_dim = meta->vision_meta.vision_embed_dim; + + if (llava_debug_enabled_in_weighthpp()) { + printf("[CPP getClassToken] vision_class_token pointer: %p\n", weights->vision_class_token); + } + auto vision_class_token_device_tensor = + Tensor::weight((char *)weights->vision_class_token, + INFINI_DTYPE_F16, + {vision_embed_dim}); + + // printf("[CPP getClassToken] First 10 values: \n"); + // vision_class_token_device_tensor->debug_first_n(10); + + return vision_class_token_device_tensor; +} + +inline std::shared_ptr getVisionQWeight( + LlavaMeta const *meta, + LlavaWeights const *weights, + size_t layer) { + auto vision_embed_dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_q_weights[layer], + INFINI_DTYPE_F16, + {vision_embed_dim, vision_embed_dim} + ); +} +inline std::shared_ptr getVisionQBias( + LlavaMeta const *meta, + LlavaWeights const *weights, + size_t layer) { + auto vision_embed_dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_q_biases[layer], + INFINI_DTYPE_F16, + {vision_embed_dim} + ); +} + +inline std::shared_ptr getVisionKWeight( + LlavaMeta const *meta, + LlavaWeights const *weights, + size_t layer) { + auto vision_embed_dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_k_weights[layer], + INFINI_DTYPE_F16, + {vision_embed_dim, vision_embed_dim} + ); +} +inline std::shared_ptr getVisionKBias( + LlavaMeta const *meta, + LlavaWeights const *weights, + size_t layer) { + auto vision_embed_dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_k_biases[layer], + INFINI_DTYPE_F16, + {vision_embed_dim} + ); +} +inline std::shared_ptr getVisionVWeight( + LlavaMeta const 
*meta, + LlavaWeights const *weights, + size_t layer) { + auto vision_embed_dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_v_weights[layer], + INFINI_DTYPE_F16, + {vision_embed_dim, vision_embed_dim} + ); +} +inline std::shared_ptr getVisionVBias( + LlavaMeta const *meta, + LlavaWeights const *weights, + size_t layer) { + auto vision_embed_dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_v_biases[layer], + INFINI_DTYPE_F16, + {vision_embed_dim} + ); +} + +inline std::shared_ptr getVisionPreLNWeight( + LlavaMeta const *meta, + LlavaWeights const *weights) { + + if (llava_debug_enabled_in_weighthpp()) { + printf("[CPP getVisionPreLNWeight] vision_pre_layernorm_weight pointer: %p\n", weights->vision_pre_layernorm_weight); + } + auto dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_pre_layernorm_weight, + INFINI_DTYPE_F16, + {dim} + ); +} + +inline std::shared_ptr getVisionPreLNBias( + LlavaMeta const *meta, + LlavaWeights const *weights) { + + auto dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_pre_layernorm_bias, + INFINI_DTYPE_F16, + {dim} + ); +} + +inline std::shared_ptr getVisionPostLNWeight( + LlavaMeta const *meta, + LlavaWeights const *weights) { + + auto dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_post_layernorm_weight, + INFINI_DTYPE_F16, + {dim} + ); +} + +inline std::shared_ptr getVisionPostLNBias( + LlavaMeta const *meta, + LlavaWeights const *weights) { + + auto dim = meta->vision_meta.vision_embed_dim; + + return Tensor::weight( + (char *)weights->vision_post_layernorm_bias, + INFINI_DTYPE_F16, + {dim} + ); +} + +inline std::shared_ptr getVisionInLayerPreNormWeight( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto dim = meta->vision_meta.vision_embed_dim; + return Tensor::weight((char *)weights->vision_in_layer_pre_norm_weights[layer], + INFINI_DTYPE_F16, {dim}); +} + +inline std::shared_ptr getVisionInLayerPreNormBias( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto dim = meta->vision_meta.vision_embed_dim; + return Tensor::weight((char *)weights->vision_in_layer_pre_norm_biases[layer], + INFINI_DTYPE_F16, {dim}); +} + + + +inline std::shared_ptr getVisionProjWeight( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto dim = meta->vision_meta.vision_embed_dim; + return Tensor::weight((char *)weights->vision_proj_weight[layer], + INFINI_DTYPE_F16, {dim, dim}); +} + +inline std::shared_ptr getVisionProjBias( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto dim = meta->vision_meta.vision_embed_dim; + return Tensor::weight((char *)weights->vision_proj_bias[layer], + INFINI_DTYPE_F16, {dim}); +} + + +inline std::shared_ptr getVisionInLayerPostNormWeight( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto dim = meta->vision_meta.vision_embed_dim; + // printf("[CPP vision_in_layer_post_norm_weight] layer: %zu, pointer: %p\n", layer, weights->vision_in_layer_post_norm_weight[layer]); + return Tensor::weight((char *)weights->vision_in_layer_post_norm_weight[layer], + INFINI_DTYPE_F16, {dim}); +} + +inline std::shared_ptr getVisionInLayerPostNormBias( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto dim = meta->vision_meta.vision_embed_dim; + // printf("[CPP vision_post_norm_bias] layer: %zu, 
pointer: %p\n", layer, weights->vision_post_norm_bias[layer]); + return Tensor::weight((char *)weights->vision_post_norm_bias[layer], + INFINI_DTYPE_F16, {dim}); +} + + +inline std::shared_ptr getVisionMLPFC1Weight( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto dim = meta->vision_meta.vision_embed_dim; + auto mlp = meta->vision_meta.vision_intermediate_size; + return Tensor::weight((char *)weights->vision_mlp_fc1_weight[layer], + INFINI_DTYPE_F16, {mlp, dim}); +} + +inline std::shared_ptr getVisionMLPFC1Bias( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto mlp = meta->vision_meta.vision_intermediate_size; + return Tensor::weight((char *)weights->vision_mlp_fc1_bias[layer], + INFINI_DTYPE_F16, {mlp}); +} + + +inline std::shared_ptr getVisionMLPFC2Weight( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto dim = meta->vision_meta.vision_embed_dim; + auto mlp = meta->vision_meta.vision_intermediate_size; + return Tensor::weight((char *)weights->vision_mlp_fc2_weight[layer], + INFINI_DTYPE_F16, {dim, mlp}); +} + +inline std::shared_ptr getVisionMLPFC2Bias( + LlavaMeta const *meta, LlavaWeights const *weights, size_t layer) { + auto dim = meta->vision_meta.vision_embed_dim; + return Tensor::weight((char *)weights->vision_mlp_fc2_bias[layer], + INFINI_DTYPE_F16, {dim}); +} + +// MultiModal Projector (two-layer MLP) +inline std::shared_ptr getProjectorWeight1( + LlavaMeta const *meta, LlavaWeights const *weights) { + auto vision_dim = meta->projector_meta.vision_embed_dim; + auto hidden_dim = meta->projector_meta.projector_hidden_size; + return Tensor::weight((char *)weights->projector_weight_1, + INFINI_DTYPE_F16, {hidden_dim, vision_dim}); +} + +inline std::shared_ptr getProjectorBias1( + LlavaMeta const *meta, LlavaWeights const *weights) { + auto hidden_dim = meta->projector_meta.projector_hidden_size; + return Tensor::weight((char *)weights->projector_bias_1, + INFINI_DTYPE_F16, {hidden_dim}); +} + +inline std::shared_ptr getProjectorWeight2( + LlavaMeta const *meta, LlavaWeights const *weights) { + auto text_dim = meta->projector_meta.text_embed_dim; + auto hidden_dim = meta->projector_meta.projector_hidden_size; + return Tensor::weight((char *)weights->projector_weight_2, + INFINI_DTYPE_F16, {text_dim, hidden_dim}); +} + +inline std::shared_ptr getProjectorBias2( + LlavaMeta const *meta, LlavaWeights const *weights) { + auto text_dim = meta->projector_meta.text_embed_dim; + return Tensor::weight((char *)weights->projector_bias_2, + INFINI_DTYPE_F16, {text_dim}); +} + + +// inline std::shared_ptr createClassEmbedding(LlavaMeta const *meta) { +// auto vision_embed_dim = meta->vision_meta.vision_embed_dim; + +// printf("[CPP createClassEmbedding] Shape: [1, %zu]\n", vision_embed_dim); + +// std::vector class_embedding_data(vision_embed_dim, 0.0f); +// return Tensor::weight(class_embedding_data.data(), INFINI_DTYPE_F16, {1, vision_embed_dim}); +// } + +// inline std::shared_ptr getProjectorWeight(const LlavaWeights *weights, size_t text_dim, size_t vision_dim) { +// return std::make_shared( +// std::vector{text_dim, vision_dim}, +// weights->projector_weight, +// DT_F16, DEVICE_CPU, 0 +// ); +// } + +// // Reuse Jiuge weight getters for language model +// inline std::shared_ptr getAttnNorm(const LlavaLanguageMeta *meta, const LlavaWeights *weights, size_t layer) { +// return std::make_shared( +// std::vector{meta->d}, +// weights->attn_norm[layer], +// meta->dt_norm, DEVICE_CPU, 0 +// ); +// } + +// inline 
std::shared_ptr getAttnQKV(const LlavaLanguageMeta *meta, const LlavaWeights *weights, size_t layer, int idev, int ndev) { +// return std::make_shared( +// std::vector{(meta->nh + 2 * meta->nkvh) / ndev, meta->dh, meta->d}, +// weights->attn_qkv[layer], +// meta->dt_mat, DEVICE_CPU, 0 +// ); +// } + +// inline std::shared_ptr getAttnQKVBias(const LlavaLanguageMeta *meta, const LlavaWeights *weights, size_t layer, int idev, int ndev) { +// return std::make_shared( +// std::vector{(meta->nh + 2 * meta->nkvh) / ndev, meta->dh}, +// weights->attn_qkv_b[layer], +// meta->dt_mat, DEVICE_CPU, 0 +// ); +// } + +// inline std::shared_ptr getAttnQNorm(const LlavaLanguageMeta *meta, const LlavaWeights *weights, size_t layer) { +// return std::make_shared( +// std::vector{meta->dh}, +// weights->attn_q_norm[layer], +// meta->dt_norm, DEVICE_CPU, 0 +// ); +// } + +// inline std::shared_ptr getAttnKNorm(const LlavaLanguageMeta *meta, const LlavaWeights *weights, size_t layer) { +// return std::make_shared( +// std::vector{meta->dh}, +// weights->attn_k_norm[layer], +// meta->dt_norm, DEVICE_CPU, 0 +// ); +// } + +// inline std::shared_ptr getAttnO(const LlavaLanguageMeta *meta, const LlavaWeights *weights, size_t layer, int idev, int ndev) { +// return std::make_shared( +// std::vector{meta->d, meta->nkvh / ndev * meta->dh}, +// weights->attn_o[layer], +// meta->dt_mat, DEVICE_CPU, 0 +// ); +// } + +// inline std::shared_ptr getFFNNorm(const LlavaLanguageMeta *meta, const LlavaWeights *weights, size_t layer) { +// return std::make_shared( +// std::vector{meta->d}, +// weights->ffn_norm[layer], +// meta->dt_norm, DEVICE_CPU, 0 +// ); +// } + +// inline std::shared_ptr getFFNGateUp(const LlavaLanguageMeta *meta, const LlavaWeights *weights, size_t layer, int idev, int ndev) { +// return std::make_shared( +// std::vector{2 * meta->di / ndev, meta->d}, +// weights->ffn_gate_up[layer], +// meta->dt_mat, DEVICE_CPU, 0 +// ); +// } + +// inline std::shared_ptr getFFNDown(const LlavaLanguageMeta *meta, const LlavaWeights *weights, size_t layer, int idev, int ndev) { +// return std::make_shared( +// std::vector{meta->d, meta->di / ndev}, +// weights->ffn_down[layer], +// meta->dt_mat, DEVICE_CPU, 0 +// ); +// } + +#endif diff --git a/src/models/minicpmv/minicpmv.cpp b/src/models/minicpmv/minicpmv.cpp new file mode 100644 index 00000000..0afb88ba --- /dev/null +++ b/src/models/minicpmv/minicpmv.cpp @@ -0,0 +1,520 @@ +#include "minicpmv_impl.hpp" + +#include "infinicore_infer.h" +#include "ref_ops.hpp" +#include "ref_pos_embed.hpp" + +#include "../inference_context.hpp" +#include "../../cache.hpp" + +#include +#include +#include +#include +#include +#include + +__C struct MiniCPMVModel * +createMiniCPMVModel(const MiniCPMVMeta *meta, + const MiniCPMVWeights *weights, + infiniDevice_t device, + int ndev, + const int *dev_ids) { + std::vector device_ids(ndev); + std::copy(dev_ids, dev_ids + ndev, device_ids.begin()); + auto *model = new MiniCPMVModel(meta, weights, device, std::move(device_ids)); + ASSERT_EQ(model->dev_ids.size(), size_t(1)); + + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[0])); + RUN_INFINI(infiniopCreateHandle(&model->op_handle)); + RUN_INFINI(infinirtStreamCreate(&model->stream)); + return model; +} + +__C void destroyMiniCPMVModel(struct MiniCPMVModel *model) { + if (!model) { + return; + } + if (model->stream) { + infinirtStreamDestroy(model->stream); + model->stream = nullptr; + } + if (model->op_handle) { + infiniopDestroyHandle(model->op_handle); + model->op_handle = nullptr; 
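+        // Only runtime resources (stream, op handle) are released here.
+        // model->weights is a borrowed pointer (see MiniCPMVModel in
+        // minicpmv_impl.hpp) and intentionally stays owned by the caller.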
+ } + delete model; +} + +static std::shared_ptr make_host_pos_embed(infiniDtype_t dtype, + size_t embed_dim, + uint32_t tgt_h, + uint32_t tgt_w) { + auto pos_f32 = minicpmv::ref_pos_embed::make_2d_sincos_pos_embed(embed_dim, tgt_h, tgt_w); + const size_t n = static_cast(tgt_h) * static_cast(tgt_w); + + if (dtype == INFINI_DTYPE_F32) { + return Tensor::weight(pos_f32.data(), INFINI_DTYPE_F32, {n, embed_dim}); + } + + std::vector packed(n * embed_dim); + for (size_t i = 0; i < n * embed_dim; ++i) { + packed[i] = (dtype == INFINI_DTYPE_BF16) ? f32_to_bf16(pos_f32[i]) : f32_to_f16(pos_f32[i]); + } + return Tensor::weight(packed.data(), dtype, {n, embed_dim}); +} + +static uint32_t bucketize_pos(uint32_t idx, uint32_t nb_patches, uint32_t num_patches_per_side) { + // Equivalent to torch.bucketize(i/nb, boundaries=arange(1/N..), right=True) + // boundaries are uniform, so this becomes floor(i * N / nb). + return (idx * num_patches_per_side) / nb_patches; +} + +static std::vector build_siglip_pos_embed(const void *pos_table, + infiniDtype_t dt, + size_t embed_dim, + uint32_t tgt_h, + uint32_t tgt_w, + uint32_t num_patches_per_side) { + ASSERT_VALID_PTR(pos_table); + const size_t seq_len = static_cast(tgt_h) * static_cast(tgt_w); + const size_t unit = dsize(dt); + std::vector out(seq_len * embed_dim * unit); + + const auto *src = reinterpret_cast(pos_table); + auto *dst = out.data(); + + for (uint32_t ih = 0; ih < tgt_h; ++ih) { + const uint32_t bh = bucketize_pos(ih, tgt_h, num_patches_per_side); + for (uint32_t iw = 0; iw < tgt_w; ++iw) { + const uint32_t bw = bucketize_pos(iw, tgt_w, num_patches_per_side); + const uint32_t pos_id = bh * num_patches_per_side + bw; // [0, 4899] + const size_t row = static_cast(ih) * static_cast(tgt_w) + iw; + std::memcpy(dst + (row * embed_dim) * unit, + src + (static_cast(pos_id) * embed_dim) * unit, + embed_dim * unit); + } + } + return out; +} + +__C void inferMiniCPMVSiglipEmbeddings(struct MiniCPMVModel *model, + const void *pixel_values, + size_t seq_len, + uint32_t tgt_h, + uint32_t tgt_w, + void *output) { + ASSERT_VALID_PTR(model); + ASSERT_VALID_PTR(pixel_values); + ASSERT_VALID_PTR(output); + ASSERT_EQ(model->dev_ids.size(), size_t(1)); + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[0])); + + const auto &vm = model->meta.vision_meta; + const auto dt = model->meta.language_meta.dt_logits; + ASSERT_EQ(seq_len, static_cast(tgt_h) * static_cast(tgt_w)); + ASSERT_EQ(vm.patch_size, size_t(14)); + ASSERT_EQ(vm.vision_num_positions, size_t(4900)); + + ASSERT_VALID_PTR(model->weights); + ASSERT_VALID_PTR(model->weights->vpm_patch_embedding_weight); + ASSERT_VALID_PTR(model->weights->vpm_patch_embedding_bias); + ASSERT_VALID_PTR(model->weights->vpm_position_embedding); + + CacheManager cache_manager(100); + auto memory_pool = std::make_shared(256 * 1024 * 1024); + InferenceContext ctx(model->op_handle, memory_pool, &cache_manager, model->stream); + setInferenceContext(&ctx); + + // Input: [1, 3, patch, seq_len*patch] + auto x = Tensor::weight(const_cast(pixel_values), dt, + {1, 3, vm.patch_size, seq_len * vm.patch_size}); + auto w = Tensor::weight(const_cast(model->weights->vpm_patch_embedding_weight), + dt, {vm.vision_embed_dim, 3, vm.patch_size, vm.patch_size}); + auto b = Tensor::weight(const_cast(model->weights->vpm_patch_embedding_bias), + dt, {vm.vision_embed_dim}); + + auto y = Tensor::buffer(dt, {1, vm.vision_embed_dim, 1, seq_len}, memory_pool); + std::vector pads{0, 0}; + std::vector strides{vm.patch_size, vm.patch_size}; + std::vector 
dilations{1, 1}; + conv2d(y, x, w, b, pads, strides, dilations); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // Create output [seq_len, embed_dim] and copy from NCHW conv output. + auto out = Tensor::buffer(dt, {seq_len, vm.vision_embed_dim}, memory_pool); + // y is [1, C, 1, W] contiguous. View it as [C, W] and rearrange into out's [C, W] view. + auto out_cw = out->permute({1, 0}); // [C, W] (strided view) + auto y_cw = y->view_as({vm.vision_embed_dim, seq_len}, + {static_cast(seq_len), 1}); + rearrange(out_cw, y_cw); + + // Add position embedding (bucketized lookup from the 70x70 table, built on host and uploaded). + const uint32_t N = static_cast(vm.vision_image_size / vm.patch_size); // 70 + auto pos_host = build_siglip_pos_embed(model->weights->vpm_position_embedding, + dt, + vm.vision_embed_dim, + tgt_h, + tgt_w, + N); + auto pos = Tensor::weight(pos_host.data(), dt, {seq_len, vm.vision_embed_dim}); + add(out, out, pos); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + RUN_INFINI(infinirtMemcpy(output, out->data(), out->numel() * dsize(dt), INFINIRT_MEMCPY_D2H)); + + setInferenceContext(nullptr); +} + +__C void inferMiniCPMVSiglipLayer(struct MiniCPMVModel *model, + uint32_t layer_idx, + const void *hidden_states, + size_t seq_len, + void *output) { + ASSERT_VALID_PTR(model); + ASSERT_VALID_PTR(hidden_states); + ASSERT_VALID_PTR(output); + ASSERT_EQ(model->dev_ids.size(), size_t(1)); + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[0])); + + const auto &vm = model->meta.vision_meta; + const auto dt = model->meta.language_meta.dt_logits; + ASSERT_EQ(vm.vision_embed_dim % vm.vision_num_heads, size_t(0)); + ASSERT(layer_idx < vm.vision_num_layers); + const size_t nh = vm.vision_num_heads; + const size_t d = vm.vision_embed_dim; + const size_t dh = d / nh; + const float scale = 1.0f / std::sqrt(static_cast(dh)); + + ASSERT_VALID_PTR(model->weights); + ASSERT_VALID_PTR(model->weights->vpm_layers); + const MiniCPMVSiglipLayerWeights *lw = &model->weights->vpm_layers[layer_idx]; + + CacheManager cache_manager(100); + auto memory_pool = std::make_shared(512 * 1024 * 1024); + InferenceContext ctx(model->op_handle, memory_pool, &cache_manager, model->stream); + setInferenceContext(&ctx); + + auto x = Tensor::weight(const_cast(hidden_states), dt, {seq_len, d}); + auto y = Tensor::buffer(dt, {seq_len, d}, memory_pool); + rearrange(y, x); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // LN1 + auto ln1_w = Tensor::weight(const_cast(lw->layer_norm1_weight), dt, {d}); + auto ln1_b = Tensor::weight(const_cast(lw->layer_norm1_bias), dt, {d}); + auto x_ln1 = Tensor::buffer(dt, {seq_len, d}, memory_pool); + layerNorm(x_ln1, y, ln1_w, ln1_b, model->meta.vision_meta.vision_layer_norm_eps); + + // Q/K/V projections. We expect weights are pre-transposed to [in_dim, out_dim] for GEMM. 
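+    // Layout convention (assumed to match the Python export script): PyTorch's
+    // nn.Linear stores W as [out_features, in_features], so the exporter is
+    // expected to transpose it once to [in_dim, out_dim]. Each projection below
+    // is then a plain row-major GEMM:
+    //   q = x_ln1 @ wq + bq,   x_ln1: [L, d], wq: [d, d], q: [L, d]
+    // (same for k and v).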
+ auto wq = Tensor::weight(const_cast(lw->q_weight), dt, {d, d}); + auto bq = Tensor::weight(const_cast(lw->q_bias), dt, {d}); + auto wk = Tensor::weight(const_cast(lw->k_weight), dt, {d, d}); + auto bk = Tensor::weight(const_cast(lw->k_bias), dt, {d}); + auto wv = Tensor::weight(const_cast(lw->v_weight), dt, {d, d}); + auto bv = Tensor::weight(const_cast(lw->v_bias), dt, {d}); + + auto q = Tensor::buffer(dt, {seq_len, d}, memory_pool); + auto k = Tensor::buffer(dt, {seq_len, d}, memory_pool); + auto v = Tensor::buffer(dt, {seq_len, d}, memory_pool); + linear(q, x_ln1, wq, 1.0f, 0.0f, nullptr, bq); + linear(k, x_ln1, wk, 1.0f, 0.0f, nullptr, bk); + linear(v, x_ln1, wv, 1.0f, 0.0f, nullptr, bv); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // Attention per-head using slices to avoid non-contiguous view issues. + auto attn_out = Tensor::buffer(dt, {seq_len, d}, memory_pool); + auto scores = Tensor::buffer(dt, {seq_len, seq_len}, memory_pool); + auto out_h = Tensor::buffer(dt, {seq_len, dh}, memory_pool); + auto q_h_contig = Tensor::buffer(dt, {seq_len, dh}, memory_pool); + auto k_h_contig = Tensor::buffer(dt, {seq_len, dh}, memory_pool); + auto v_h_contig = Tensor::buffer(dt, {seq_len, dh}, memory_pool); + auto k_t_contig = Tensor::buffer(dt, {dh, seq_len}, memory_pool); + + for (size_t h = 0; h < nh; ++h) { + const size_t col = h * dh; + auto q_h = q->slice(1, col, dh); // [L, dh] (strided view) + auto k_h = k->slice(1, col, dh); // [L, dh] (strided view) + auto v_h = v->slice(1, col, dh); // [L, dh] (strided view) + rearrange(q_h_contig, q_h); + rearrange(k_h_contig, k_h); + rearrange(v_h_contig, v_h); + auto k_t_view = k_h_contig->permute({1, 0}); // [dh, L] (strided view) + rearrange(k_t_contig, k_t_view); // materialize to contiguous for GEMM + + gemm(scores, q_h_contig, k_t_contig, scale, 0.0f); // [L, L] + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + Softmax(scores, scores, 1); + gemm(out_h, scores, v_h_contig, 1.0f, 0.0f); // [L, dh] + rearrange(attn_out->slice(1, col, dh), out_h); + } + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // Out proj + residual + auto wo = Tensor::weight(const_cast(lw->out_weight), dt, {d, d}); + auto bo = Tensor::weight(const_cast(lw->out_bias), dt, {d}); + auto attn_proj = Tensor::buffer(dt, {seq_len, d}, memory_pool); + linear(attn_proj, attn_out, wo, 1.0f, 0.0f, nullptr, bo); + add(attn_proj, attn_proj, y); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // LN2 + auto ln2_w = Tensor::weight(const_cast(lw->layer_norm2_weight), dt, {d}); + auto ln2_b = Tensor::weight(const_cast(lw->layer_norm2_bias), dt, {d}); + auto x_ln2 = Tensor::buffer(dt, {seq_len, d}, memory_pool); + layerNorm(x_ln2, attn_proj, ln2_w, ln2_b, model->meta.vision_meta.vision_layer_norm_eps); + + // MLP: fc1 -> gelu_tanh -> fc2 -> residual + auto w1 = Tensor::weight(const_cast(lw->fc1_weight), dt, {d, vm.vision_intermediate_size}); + auto b1 = Tensor::weight(const_cast(lw->fc1_bias), dt, {vm.vision_intermediate_size}); + auto w2 = Tensor::weight(const_cast(lw->fc2_weight), dt, {vm.vision_intermediate_size, d}); + auto b2 = Tensor::weight(const_cast(lw->fc2_bias), dt, {d}); + + auto fc1 = Tensor::buffer(dt, {seq_len, vm.vision_intermediate_size}, memory_pool); + linear(fc1, x_ln2, w1, 1.0f, 0.0f, nullptr, b1); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + geluTanh(fc1, fc1); + + auto fc2 = Tensor::buffer(dt, {seq_len, d}, memory_pool); + linear(fc2, fc1, w2, 1.0f, 0.0f, nullptr, b2); + add(fc2, fc2, attn_proj); + 
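+    // fc2 now holds the full pre-LN encoder block:
+    //   h   = x + out_proj(attention(ln1(x)))            // attn_proj above
+    //   out = h + fc2_linear(gelu_tanh(fc1_linear(ln2(h))))
+    // and is synchronized and copied back to the caller below.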
RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + RUN_INFINI(infinirtMemcpy(output, fc2->data(), fc2->numel() * dsize(dt), INFINIRT_MEMCPY_D2H)); + + setInferenceContext(nullptr); +} + +__C void inferMiniCPMVSiglipLayer0(struct MiniCPMVModel *model, + const void *hidden_states, + size_t seq_len, + void *output) { + inferMiniCPMVSiglipLayer(model, 0, hidden_states, seq_len, output); +} + +__C void inferMiniCPMVSiglipEncoder(struct MiniCPMVModel *model, + uint32_t num_layers, + const void *hidden_states, + size_t seq_len, + void *output) { + ASSERT_VALID_PTR(model); + ASSERT_VALID_PTR(hidden_states); + ASSERT_VALID_PTR(output); + + ASSERT_EQ(model->dev_ids.size(), size_t(1)); + + const auto &vm = model->meta.vision_meta; + const auto dt = model->meta.language_meta.dt_logits; + ASSERT(num_layers <= vm.vision_num_layers); + + ASSERT_VALID_PTR(model->weights); + ASSERT_VALID_PTR(model->weights->vpm_post_layernorm_weight); + ASSERT_VALID_PTR(model->weights->vpm_post_layernorm_bias); + + const size_t d = vm.vision_embed_dim; + const size_t bytes = seq_len * d * dsize(dt); + + std::vector buf_a(bytes); + std::vector buf_b(bytes); + std::memcpy(buf_a.data(), hidden_states, bytes); + + uint8_t *buf_in = buf_a.data(); + uint8_t *buf_out = buf_b.data(); + for (uint32_t i = 0; i < num_layers; ++i) { + inferMiniCPMVSiglipLayer(model, i, buf_in, seq_len, buf_out); + std::swap(buf_in, buf_out); + } + + // Apply post-layernorm. + if (model->device == INFINI_DEVICE_CPU) { + minicpmv::ref_ops::layer_norm_last_dim_raw( + output, + buf_in, + model->weights->vpm_post_layernorm_weight, + model->weights->vpm_post_layernorm_bias, + dt, + seq_len, + d, + vm.vision_layer_norm_eps); + return; + } + + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[0])); + CacheManager cache_manager(100); + auto memory_pool = std::make_shared(128 * 1024 * 1024); + InferenceContext ctx(model->op_handle, memory_pool, &cache_manager, model->stream); + setInferenceContext(&ctx); + + auto x = Tensor::weight(buf_in, dt, {seq_len, d}); + auto y = Tensor::buffer(dt, {seq_len, d}, memory_pool); + auto ln_w = Tensor::weight(const_cast(model->weights->vpm_post_layernorm_weight), dt, {d}); + auto ln_b = Tensor::weight(const_cast(model->weights->vpm_post_layernorm_bias), dt, {d}); + layerNorm(y, x, ln_w, ln_b, vm.vision_layer_norm_eps); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + RUN_INFINI(infinirtMemcpy(output, y->data(), y->numel() * dsize(dt), INFINIRT_MEMCPY_D2H)); + + setInferenceContext(nullptr); +} + +__C void inferMiniCPMVResampler(struct MiniCPMVModel *model, + const void *x, + size_t seq_len, + uint32_t tgt_h, + uint32_t tgt_w, + void *output) { + ASSERT_VALID_PTR(model); + ASSERT_VALID_PTR(x); + ASSERT_VALID_PTR(output); + + ASSERT_EQ(model->dev_ids.size(), size_t(1)); + RUN_INFINI(infinirtSetDevice(model->device, model->dev_ids[0])); + + const auto &rm = model->meta.resampler_meta; + const auto dt = model->meta.language_meta.dt_logits; + ASSERT_EQ(rm.embed_dim, model->meta.language_meta.d); + ASSERT_EQ(rm.num_heads, model->meta.language_meta.nh); + ASSERT_EQ(seq_len, static_cast(tgt_h) * static_cast(tgt_w)); + + CacheManager cache_manager(100); + auto memory_pool = std::make_shared(256 * 1024 * 1024); + InferenceContext ctx(model->op_handle, memory_pool, &cache_manager, model->stream); + setInferenceContext(&ctx); + + // Load inputs/weights into Tensor objects on CPU "device" memory. 
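+    // Overall flow, mirroring the reference MiniCPM-V resampler:
+    //   value = ln_kv(x @ kv_proj)                      x: [L, kv_dim] -> [L, D]
+    //   key   = value + 2D sin/cos pos embed (tgt_h x tgt_w, built on host)
+    //   query = ln_q(learned query)                     [num_queries, D]
+    //   attn  = MHA(query, key, value) @ out_proj       [num_queries, D]
+    //   out   = ln_post(attn) @ proj                    [num_queries, D]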
+ auto x_in = Tensor::weight(const_cast(x), dt, {seq_len, rm.kv_dim}); + auto kv_w = Tensor::weight(const_cast(model->weights->resampler_kv_proj_weight), dt, {rm.kv_dim, rm.embed_dim}); + auto q_param = Tensor::weight(const_cast(model->weights->resampler_query), dt, {rm.num_queries, rm.embed_dim}); + + auto ln_q_w = Tensor::weight(const_cast(model->weights->resampler_ln_q_weight), dt, {rm.embed_dim}); + auto ln_q_b = Tensor::weight(const_cast(model->weights->resampler_ln_q_bias), dt, {rm.embed_dim}); + auto ln_kv_w = Tensor::weight(const_cast(model->weights->resampler_ln_kv_weight), dt, {rm.embed_dim}); + auto ln_kv_b = Tensor::weight(const_cast(model->weights->resampler_ln_kv_bias), dt, {rm.embed_dim}); + auto ln_post_w = Tensor::weight(const_cast(model->weights->resampler_ln_post_weight), dt, {rm.embed_dim}); + auto ln_post_b = Tensor::weight(const_cast(model->weights->resampler_ln_post_bias), dt, {rm.embed_dim}); + + // x_proj = x_in @ kv_w + auto x_proj = Tensor::buffer(dt, {seq_len, rm.embed_dim}, memory_pool); + gemm(x_proj, x_in, kv_w, 1.0f, 0.0f); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // ln_kv(x_proj) + layerNorm(x_proj, x_proj, ln_kv_w, ln_kv_b, rm.layer_norm_eps); + + // In the reference implementation, pos_embed is added to KEY only: + // attn(q, key=x+pos, value=x) + auto x_val = x_proj; + auto x_key = Tensor::buffer(dt, {seq_len, rm.embed_dim}, memory_pool); + rearrange(x_key, x_val); + auto pos = make_host_pos_embed(dt, rm.embed_dim, tgt_h, tgt_w); + add(x_key, x_key, pos); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // q = ln_q(query) + auto q_ln = Tensor::buffer(dt, {rm.num_queries, rm.embed_dim}, memory_pool); + rearrange(q_ln, q_param); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + //minicpmv::ref_ops::layer_norm_last_dim(q_ln, q_ln, ln_q_w, ln_q_b, rm.layer_norm_eps); + layerNorm(q_ln, q_ln, ln_q_w, ln_q_b, rm.layer_norm_eps); + + // In-proj: use pre-transposed weights [D, 3D], then slice into q/k/v. + auto in_w_full = Tensor::weight(const_cast(model->weights->resampler_attn_in_proj_weight), + dt, {rm.embed_dim, 3 * rm.embed_dim}); + auto in_b_full = Tensor::weight(const_cast(model->weights->resampler_attn_in_proj_bias), + dt, {3 * rm.embed_dim}); + + auto w_q = in_w_full->slice(1, 0, rm.embed_dim); + auto w_k = in_w_full->slice(1, rm.embed_dim, rm.embed_dim); + auto w_v = in_w_full->slice(1, 2 * rm.embed_dim, rm.embed_dim); + auto b_q = in_b_full->slice(0, 0, rm.embed_dim); + auto b_k = in_b_full->slice(0, rm.embed_dim, rm.embed_dim); + auto b_v = in_b_full->slice(0, 2 * rm.embed_dim, rm.embed_dim); + + auto q_full = Tensor::buffer(dt, {rm.num_queries, rm.embed_dim}, memory_pool); + auto k_full = Tensor::buffer(dt, {seq_len, rm.embed_dim}, memory_pool); + auto v_full = Tensor::buffer(dt, {seq_len, rm.embed_dim}, memory_pool); + linear(q_full, q_ln, w_q, 1.0f, 0.0f, nullptr, b_q); + linear(k_full, x_key, w_k, 1.0f, 0.0f, nullptr, b_k); + linear(v_full, x_val, w_v, 1.0f, 0.0f, nullptr, b_v); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + const size_t nh = rm.num_heads; + const size_t dh = rm.embed_dim / nh; + const float scale = 1.0f / std::sqrt(static_cast(dh)); + + // Avoid permute+view on non-contiguous tensors: slice heads from the last dim directly. 
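+    // Per-head attention, using columns [h*dh, (h+1)*dh) of q/k/v:
+    //   qk    = softmax((q_h @ k_h^T) * 1/sqrt(dh))   -> [Q, L]
+    //   out_h = qk @ v_h                              -> [Q, dh]
+    // Each strided slice is materialized into a contiguous scratch buffer so
+    // that gemm/Softmax only ever operate on contiguous 2D tensors.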
+ auto out_merge = Tensor::buffer(dt, {rm.num_queries, rm.embed_dim}, memory_pool); + auto qk = Tensor::buffer(dt, {rm.num_queries, seq_len}, memory_pool); + auto out_h = Tensor::buffer(dt, {rm.num_queries, dh}, memory_pool); + auto q_h_contig = Tensor::buffer(dt, {rm.num_queries, dh}, memory_pool); + auto k_h_contig = Tensor::buffer(dt, {seq_len, dh}, memory_pool); + auto v_h_contig = Tensor::buffer(dt, {seq_len, dh}, memory_pool); + auto k_t_contig = Tensor::buffer(dt, {dh, seq_len}, memory_pool); + + for (size_t h = 0; h < nh; ++h) { + const size_t col = h * dh; + auto q_h = q_full->slice(1, col, dh); // [Q, dh] (strided view) + auto k_h = k_full->slice(1, col, dh); // [L, dh] (strided view) + auto v_h = v_full->slice(1, col, dh); // [L, dh] (strided view) + rearrange(q_h_contig, q_h); + rearrange(k_h_contig, k_h); + rearrange(v_h_contig, v_h); + auto k_h_t_view = k_h_contig->permute({1, 0}); // [dh, L] (strided view) + rearrange(k_t_contig, k_h_t_view); // materialize to contiguous for GEMM + + gemm(qk, q_h_contig, k_t_contig, scale, 0.0f); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + Softmax(qk, qk, 1); + + gemm(out_h, qk, v_h_contig, 1.0f, 0.0f); + rearrange(out_merge->slice(1, col, dh), out_h); + } + //RUN_INFINI(infinirtStreamSynchronize(stream)); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // Out proj: assume already transposed to [embed_dim, embed_dim] + auto out_w = Tensor::weight(const_cast(model->weights->resampler_attn_out_proj_weight), dt, {rm.embed_dim, rm.embed_dim}); + auto out_b = Tensor::weight(const_cast(model->weights->resampler_attn_out_proj_bias), dt, {rm.embed_dim}); + auto out_proj = Tensor::buffer(dt, {rm.num_queries, rm.embed_dim}, memory_pool); + linear(out_proj, out_merge, out_w, 1.0f, 0.0f, nullptr, out_b); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // ln_post and final projection + //minicpmv::ref_ops::layer_norm_last_dim(out_proj, out_proj, ln_post_w, ln_post_b, rm.layer_norm_eps); + layerNorm(out_proj, out_proj, ln_post_w, ln_post_b, rm.layer_norm_eps); + auto proj_w = Tensor::weight(const_cast(model->weights->resampler_proj), dt, {rm.embed_dim, rm.embed_dim}); + auto out_final = Tensor::buffer(dt, {rm.num_queries, rm.embed_dim}, memory_pool); + gemm(out_final, out_proj, proj_w, 1.0f, 0.0f); + RUN_INFINI(infinirtStreamSynchronize(model->stream)); + + // Copy back to caller. 
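+    // The caller is expected to supply an output buffer of at least
+    // num_queries * embed_dim * dsize(dt) bytes; the result is copied D2H.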
+ RUN_INFINI(infinirtMemcpy(output, out_final->data(), + out_final->numel() * dsize(dt), + INFINIRT_MEMCPY_D2H)); + + setInferenceContext(nullptr); +} + +__C void inferMiniCPMVVisionResampler(struct MiniCPMVModel *model, + const void *pixel_values, + size_t seq_len, + uint32_t tgt_h, + uint32_t tgt_w, + void *output) { + ASSERT_VALID_PTR(model); + ASSERT_VALID_PTR(pixel_values); + ASSERT_VALID_PTR(output); + + ASSERT_EQ(model->dev_ids.size(), size_t(1)); + + const auto &vm = model->meta.vision_meta; + const auto dt = model->meta.language_meta.dt_logits; + ASSERT_EQ(seq_len, static_cast(tgt_h) * static_cast(tgt_w)); + + const size_t d = vm.vision_embed_dim; + const size_t bytes = seq_len * d * dsize(dt); + std::vector buf(bytes); + + inferMiniCPMVSiglipEmbeddings(model, pixel_values, seq_len, tgt_h, tgt_w, buf.data()); + inferMiniCPMVSiglipEncoder(model, static_cast(vm.vision_num_layers), buf.data(), seq_len, buf.data()); + inferMiniCPMVResampler(model, buf.data(), seq_len, tgt_h, tgt_w, output); +} diff --git a/src/models/minicpmv/minicpmv_impl.hpp b/src/models/minicpmv/minicpmv_impl.hpp new file mode 100644 index 00000000..f192cf05 --- /dev/null +++ b/src/models/minicpmv/minicpmv_impl.hpp @@ -0,0 +1,23 @@ +#ifndef MINICPMV_IMPL_HPP +#define MINICPMV_IMPL_HPP + +#include "infinicore_infer/models/minicpmv.h" + +#include + +struct MiniCPMVModel { + MiniCPMVMeta meta; + const MiniCPMVWeights *weights; + infiniDevice_t device; + std::vector dev_ids; + infiniopHandle_t op_handle = nullptr; + infinirtStream_t stream = nullptr; + + MiniCPMVModel(const MiniCPMVMeta *meta_, + const MiniCPMVWeights *weights_, + infiniDevice_t device_, + std::vector device_ids) + : meta(*meta_), weights(weights_), device(device_), dev_ids(std::move(device_ids)) {} +}; + +#endif diff --git a/src/models/minicpmv/ref_ops.hpp b/src/models/minicpmv/ref_ops.hpp new file mode 100644 index 00000000..0092b286 --- /dev/null +++ b/src/models/minicpmv/ref_ops.hpp @@ -0,0 +1,207 @@ +#ifndef MINICPMV_REF_OPS_HPP +#define MINICPMV_REF_OPS_HPP + +#include "../../tensor.hpp" +#include "../../utils.hpp" + +#include +#include +#include +#include + +namespace minicpmv::ref_ops { + +inline float read_as_f32(const void *p, infiniDtype_t dtype) { + switch (dtype) { + case INFINI_DTYPE_F32: + return *reinterpret_cast(p); + case INFINI_DTYPE_F16: + return f16_to_f32(*reinterpret_cast(p)); + case INFINI_DTYPE_BF16: + return bf16_to_f32(*reinterpret_cast(p)); + default: + PANIC(unsupported_dtype); + return 0.0f; + } +} + +inline void write_from_f32(void *p, infiniDtype_t dtype, float v) { + switch (dtype) { + case INFINI_DTYPE_F32: + *reinterpret_cast(p) = v; + return; + case INFINI_DTYPE_F16: + *reinterpret_cast(p) = f32_to_f16(v); + return; + case INFINI_DTYPE_BF16: + *reinterpret_cast(p) = f32_to_bf16(v); + return; + default: + PANIC(unsupported_dtype); + return; + } +} + +inline void layer_norm_last_dim(std::shared_ptr y, + std::shared_ptr x, + std::shared_ptr gamma, + std::shared_ptr beta, + float eps) { + ASSERT_EQ(y->deviceType(), INFINI_DEVICE_CPU); + ASSERT_EQ(x->deviceType(), INFINI_DEVICE_CPU); + ASSERT_EQ(gamma->deviceType(), INFINI_DEVICE_CPU); + ASSERT_EQ(beta->deviceType(), INFINI_DEVICE_CPU); + + ASSERT(x->ndim() >= 1); + ASSERT_EQ(gamma->ndim(), 1); + ASSERT_EQ(beta->ndim(), 1); + ASSERT_EQ(gamma->shape()[0], beta->shape()[0]); + + const size_t d = gamma->shape()[0]; + ASSERT_EQ(x->shape().back(), d); + ASSERT_EQ(y->shape(), x->shape()); + + const infiniDtype_t dt_x = x->dtype(); + const infiniDtype_t dt_y = y->dtype(); + 
const infiniDtype_t dt_w = gamma->dtype(); + ASSERT_EQ(beta->dtype(), dt_w); + + const size_t outer = x->numel() / d; + const char *x_ptr = reinterpret_cast(x->data()); + char *y_ptr = reinterpret_cast(y->data()); + const char *g_ptr = reinterpret_cast(gamma->data()); + const char *b_ptr = reinterpret_cast(beta->data()); + + const size_t sx = dsize(dt_x); + const size_t sy = dsize(dt_y); + const size_t sw = dsize(dt_w); + + for (size_t row = 0; row < outer; ++row) { + float mean = 0.0f; + for (size_t i = 0; i < d; ++i) { + mean += read_as_f32(x_ptr + (row * d + i) * sx, dt_x); + } + mean /= static_cast(d); + + float var = 0.0f; + for (size_t i = 0; i < d; ++i) { + float v = read_as_f32(x_ptr + (row * d + i) * sx, dt_x) - mean; + var += v * v; + } + var /= static_cast(d); + + const float inv_std = 1.0f / std::sqrt(var + eps); + for (size_t i = 0; i < d; ++i) { + const float xv = read_as_f32(x_ptr + (row * d + i) * sx, dt_x); + const float gv = read_as_f32(g_ptr + i * sw, dt_w); + const float bv = read_as_f32(b_ptr + i * sw, dt_w); + const float yn = (xv - mean) * inv_std; + write_from_f32(y_ptr + (row * d + i) * sy, dt_y, yn * gv + bv); + } + } +} + +inline void layer_norm_last_dim_raw(void *y, + const void *x, + const void *gamma, + const void *beta, + infiniDtype_t dtype, + size_t outer, + size_t d, + float eps) { + ASSERT_VALID_PTR(y); + ASSERT_VALID_PTR(x); + ASSERT_VALID_PTR(gamma); + ASSERT_VALID_PTR(beta); + + const char *x_ptr = reinterpret_cast(x); + char *y_ptr = reinterpret_cast(y); + const char *g_ptr = reinterpret_cast(gamma); + const char *b_ptr = reinterpret_cast(beta); + + const size_t s = dsize(dtype); + + for (size_t row = 0; row < outer; ++row) { + float mean = 0.0f; + for (size_t i = 0; i < d; ++i) { + mean += read_as_f32(x_ptr + (row * d + i) * s, dtype); + } + mean /= static_cast(d); + + float var = 0.0f; + for (size_t i = 0; i < d; ++i) { + float v = read_as_f32(x_ptr + (row * d + i) * s, dtype) - mean; + var += v * v; + } + var /= static_cast(d); + + const float inv_std = 1.0f / std::sqrt(var + eps); + for (size_t i = 0; i < d; ++i) { + const float xv = read_as_f32(x_ptr + (row * d + i) * s, dtype); + const float gv = read_as_f32(g_ptr + i * s, dtype); + const float bv = read_as_f32(b_ptr + i * s, dtype); + const float yn = (xv - mean) * inv_std; + write_from_f32(y_ptr + (row * d + i) * s, dtype, yn * gv + bv); + } + } +} + +inline void softmax_last_dim_inplace(std::shared_ptr x) { + ASSERT_EQ(x->deviceType(), INFINI_DEVICE_CPU); + ASSERT(x->ndim() >= 1); + const size_t d = x->shape().back(); + const size_t outer = x->numel() / d; + const infiniDtype_t dt = x->dtype(); + const size_t s = dsize(dt); + char *ptr = reinterpret_cast(x->data()); + + std::vector tmp(d); + for (size_t row = 0; row < outer; ++row) { + float max_v = -INFINITY; + for (size_t i = 0; i < d; ++i) { + float v = read_as_f32(ptr + (row * d + i) * s, dt); + tmp[i] = v; + if (v > max_v) { + max_v = v; + } + } + + float sum = 0.0f; + for (size_t i = 0; i < d; ++i) { + float e = std::exp(tmp[i] - max_v); + tmp[i] = e; + sum += e; + } + const float inv = 1.0f / sum; + + for (size_t i = 0; i < d; ++i) { + write_from_f32(ptr + (row * d + i) * s, dt, tmp[i] * inv); + } + } +} + +inline float gelu_tanh_f32(float x) { + // gelu(x) ≈ 0.5x(1+tanh(sqrt(2/pi)(x+0.044715x^3))) + constexpr float kAlpha = 0.7978845608028654f; // sqrt(2/pi) + constexpr float kBeta = 0.044715f; + const float x3 = x * x * x; + const float u = kAlpha * (x + kBeta * x3); + return 0.5f * x * (1.0f + std::tanh(u)); +} + +inline void 
gelu_tanh_inplace(std::shared_ptr x) { + ASSERT_EQ(x->deviceType(), INFINI_DEVICE_CPU); + const infiniDtype_t dt = x->dtype(); + const size_t n = x->numel(); + const size_t s = dsize(dt); + char *ptr = reinterpret_cast(x->data()); + for (size_t i = 0; i < n; ++i) { + float v = read_as_f32(ptr + i * s, dt); + v = gelu_tanh_f32(v); + write_from_f32(ptr + i * s, dt, v); + } +} + +} // namespace minicpmv::ref_ops + +#endif diff --git a/src/models/minicpmv/ref_pos_embed.hpp b/src/models/minicpmv/ref_pos_embed.hpp new file mode 100644 index 00000000..92b23db0 --- /dev/null +++ b/src/models/minicpmv/ref_pos_embed.hpp @@ -0,0 +1,42 @@ +#ifndef MINICPMV_REF_POS_EMBED_HPP +#define MINICPMV_REF_POS_EMBED_HPP + +#include +#include +#include + +namespace minicpmv::ref_pos_embed { + +// Matches minicpmv_config/resampler.py: +// - omega = 1 / 10000 ** (i / (D/2)) +// - for 2D: meshgrid(w, h) then emb_h uses grid[0] (w), emb_w uses grid[1] (h) +inline void compute_2d_sincos_pos_embed(float *out, size_t embed_dim, size_t h, size_t w) { + const size_t half = embed_dim / 2; + const size_t quarter = half / 2; + + for (size_t y = 0; y < h; ++y) { + for (size_t x = 0; x < w; ++x) { + float *dst = out + (y * w + x) * embed_dim; + for (size_t i = 0; i < quarter; ++i) { + const float omega = std::pow(10000.0f, -static_cast(i) / static_cast(quarter)); + const float a = static_cast(x) * omega; // w first + const float b = static_cast(y) * omega; // then h + dst[i] = std::sin(a); + dst[i + quarter] = std::cos(a); + dst[i + half] = std::sin(b); + dst[i + half + quarter] = std::cos(b); + } + } + } +} + +inline std::vector make_2d_sincos_pos_embed(size_t embed_dim, size_t h, size_t w) { + std::vector out(h * w * embed_dim); + compute_2d_sincos_pos_embed(out.data(), embed_dim, h, w); + return out; +} + +} // namespace minicpmv::ref_pos_embed + +#endif + diff --git a/src/tensor.hpp b/src/tensor.hpp index 320d871c..082a2379 100644 --- a/src/tensor.hpp +++ b/src/tensor.hpp @@ -130,6 +130,7 @@ class Tensor : public std::enable_shared_from_this { void debug(const std::string &filename) const; void debug() const; + void debug_first_n(size_t n = 10) const; // 新增方法:只打印前n个数据 std::string info() const; size_t seed() const; diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp index edf0faeb..c18e08c3 100644 --- a/src/tensor/tensor.cpp +++ b/src/tensor/tensor.cpp @@ -159,7 +159,7 @@ std::shared_ptr Tensor::weight(void *data, infiniDtype_t dtype, tensor->_storage = Storage::create(size); tensor->_desc = TensorDesc::create(dtype, shape, strides); if (data != nullptr) { - tensor->load(data); + tensor->load(data); // CPU -> 设备内存传输 } tensor->_offset = 0; @@ -272,6 +272,36 @@ void print_data_bf16(uint16_t const *data, const std::vector &shape, } } +// 新增模板函数:只打印前n个数据 +template +void print_data_first_n(T *data, size_t total_elements, size_t n) { + size_t elements_to_print = std::min(total_elements, n); + std::cout << "前" << elements_to_print << "个数据: "; + for (size_t i = 0; i < elements_to_print; i++) { + std::cout << data[i] << " "; + } + std::cout << std::endl; +} + +// F16特化版本,显示hex值和近似float值 +void print_data_first_n_f16(uint16_t const *data, size_t total_elements, size_t n) { + // size_t elements_to_print = std::min(total_elements, n); + // std::cout << "前" << elements_to_print << "个F16数据: " << std::endl; + // for (size_t i = 0; i < elements_to_print; i++) { + // std::cout << f16_to_f32(data[i]) << " "; + // } + // std::cout << std::endl; +} + +// BF16特化版本 - 使用不同的函数名避免与F16冲突 +void print_data_first_n_bf16(uint16_t const *data, 
size_t total_elements, size_t n) { + size_t elements_to_print = std::min(total_elements, n); + std::cout << "前" << elements_to_print << "个BF16数据: " << std::endl; + for (size_t i = 0; i < elements_to_print; i++) { + std::cout << " [" << i << "]: 0x" << std::hex << data[i] << std::dec << std::endl; + } +} + std::string Tensor::info() const { std::stringstream ss; @@ -279,7 +309,8 @@ std::string Tensor::info() const { << this->_desc->info() << " device=" << this->deviceType() << " device_id=" << this->deviceId(); - return this->_desc->info(); + // return this->_desc->info(); + return ss.str(); } size_t Tensor::seed() const { @@ -424,3 +455,67 @@ void Tensor::debug(const std::string &filename) const { } void Tensor::debug() const { this->debug(""); } + +void Tensor::debug_first_n(size_t n) const { + RUN_INFINI(infinirtDeviceSynchronize()); + + // std::cout << "=== Tensor Debug (First " << n << " Elements) ===" << std::endl; + // std::cout << info() << std::endl; + + void const *cpu_data; + void *allocated_memory = nullptr; + + if (this->deviceType() != INFINI_DEVICE_CPU) { + // 从设备内存复制数据到主机内存 + allocated_memory = std::malloc(this->_storage->size()); + RUN_INFINI(infinirtMemcpy(allocated_memory, this->_storage->memory(), + this->_storage->size(), INFINIRT_MEMCPY_D2H)); + cpu_data = allocated_memory; + //std::cout << "数据已从设备内存复制到主机内存" << std::endl; + } else { + cpu_data = this->_storage->memory(); + //std::cout << "数据直接在主机内存中" << std::endl; + } + + // 计算总元素数量 + size_t total_elements = numel(); + // std::cout << "总元素数量: " << total_elements << std::endl; + // std::cout << "数据类型大小: " << dsize(this->dtype()) << " 字节" << std::endl; + // std::cout << "存储总大小: " << this->_storage->size() << " 字节" << std::endl; + + // 根据数据类型打印前n个元素 + const char* data_ptr = static_cast(cpu_data) + dataOffset(); + + switch (this->dtype()) { + case INFINI_DTYPE_F16: + print_data_first_n_f16((uint16_t const *)data_ptr, total_elements, n); + break; + case INFINI_DTYPE_F32: + print_data_first_n((float const *)data_ptr, total_elements, n); + break; + case INFINI_DTYPE_U64: + print_data_first_n((uint64_t const *)data_ptr, total_elements, n); + break; + case INFINI_DTYPE_I64: + print_data_first_n((int64_t const *)data_ptr, total_elements, n); + break; + case INFINI_DTYPE_U32: + print_data_first_n((uint32_t const *)data_ptr, total_elements, n); + break; + case INFINI_DTYPE_I32: + print_data_first_n((int32_t const *)data_ptr, total_elements, n); + break; + case INFINI_DTYPE_BF16: + print_data_first_n_bf16((uint16_t const *)data_ptr, total_elements, n); + break; + default: + std::cout << "不支持的数据类型,无法显示" << std::endl; + PANIC("Unsupported data type"); + } + + // 释放分配的内存 + if (allocated_memory) { + std::free(allocated_memory); + } + +} diff --git a/tests/test_kv_compression_correctness.cpp b/tests/test_kv_compression_correctness.cpp new file mode 100644 index 00000000..801b3795 --- /dev/null +++ b/tests/test_kv_compression_correctness.cpp @@ -0,0 +1,218 @@ +#include "../src/cache_manager/kv_compression.hpp" +#include "../src/cache_manager/opcache_manager.hpp" +#include "../src/models/inference_context.hpp" +#include "../include/infinicore_infer/cache.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace { +struct Meta { + uint32_t layers = 0; + uint32_t heads = 0; + uint32_t head_dim = 0; + uint32_t seq_in = 0; + uint32_t seq_out = 0; + uint32_t compression_factor = 1; + uint32_t min_seq_len = 1; + uint32_t image_kv_len = 0; +}; + +int64_t extract_int(const 
std::string &s, const std::string &key, int64_t def = -1) { + std::regex re(key + R"(\s*:\s*([0-9]+))"); + std::smatch m; + if (std::regex_search(s, m, re)) { + return std::stoll(m[1]); + } + return def; +} + +Meta parse_meta(const std::string &path) { + std::ifstream fin(path); + std::string content((std::istreambuf_iterator(fin)), std::istreambuf_iterator()); + Meta m; + m.layers = extract_int(content, "\"layers\""); + m.heads = extract_int(content, "\"heads\""); + m.head_dim = extract_int(content, "\"head_dim\""); + m.seq_in = extract_int(content, "\"seq_len_in\""); + m.seq_out = extract_int(content, "\"seq_len_out\""); + m.compression_factor = extract_int(content, "\"compression_factor\"", 1); + m.min_seq_len = extract_int(content, "\"min_seq_len\"", 1); + // it_len: [a, b] + std::regex it_re(R"("it_len"\s*:\s*\[\s*([0-9]+)\s*,\s*([0-9]+)\s*\])"); + std::smatch mm; + if (std::regex_search(content, mm, it_re)) { + m.image_kv_len = static_cast(std::stoul(mm[1])); + } + return m; +} + +std::vector load_bin(const std::string &path) { + std::ifstream fin(path, std::ios::binary); + fin.seekg(0, std::ios::end); + size_t bytes = fin.tellg(); + fin.seekg(0, std::ios::beg); + std::vector buf(bytes / sizeof(uint16_t)); + fin.read(reinterpret_cast(buf.data()), bytes); + return buf; +} + +std::vector tensor_to_float(std::shared_ptr t) { + size_t n = t->numel(); + std::vector out(n); + std::vector tmp(n); + RUN_INFINI(infinirtMemcpy(tmp.data(), t->data(), n * sizeof(uint16_t), INFINIRT_MEMCPY_D2H)); + for (size_t i = 0; i < n; ++i) out[i] = f16_to_f32(tmp[i]); + return out; +} + +void copy_from_bin(std::shared_ptr t, const std::vector &buf, size_t offset_elems, uint32_t heads, uint32_t seq, uint32_t dk) { + // Source layout: [B=1, H, S, D] row-major in buf; target Tensor layout: [S, H, D] + std::vector tmp(t->numel()); + for (uint32_t h = 0; h < heads; ++h) { + for (uint32_t s = 0; s < seq; ++s) { + for (uint32_t d = 0; d < dk; ++d) { + size_t src_idx = offset_elems + ((h * seq + s) * dk + d); + size_t dst_idx = (static_cast(s) * heads + h) * dk + d; + tmp[dst_idx] = buf[src_idx]; + } + } + } + RUN_INFINI(infinirtMemcpy(t->data(), tmp.data(), tmp.size() * sizeof(uint16_t), INFINIRT_MEMCPY_H2D)); +} + +std::pair diff_stats(const std::vector &a, const std::vector &b) { + float maxd = 0.f, meand = 0.f; + size_t n = a.size(); + for (size_t i = 0; i < n; ++i) { + float d = std::abs(a[i] - b[i]); + maxd = std::max(maxd, d); + meand += d; + } + meand /= static_cast(n); + return {maxd, meand}; +} + +void print_first(const std::vector &v, size_t n, const std::string &label) { + std::cout << label << " first " << n << ": "; + for (size_t i = 0; i < std::min(n, v.size()); ++i) { + std::cout << v[i] << " "; + } + std::cout << "\n"; +} + + + + +} // namespace + +int main(int argc, char **argv) { + if (argc < 2) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return 1; + } + + Meta meta = parse_meta("dump_kv/meta.json"); + if (meta.layers == 0) { + std::cerr << "Failed to parse dump_kv/meta.json" << std::endl; + return 2; + } + + RUN_INFINI(infinirtInit()); + RUN_INFINI(infinirtSetDevice(INFINI_DEVICE_HYGON, 0)); + + CompressionConfig cfg; + cfg.enable = true; + cfg.weight_path = argv[1]; + cfg.image_kv_len = meta.image_kv_len; + cfg.compression_factor = meta.compression_factor; + cfg.min_seq_len = meta.min_seq_len; + + Compressor compressor(cfg); + if (!compressor.loadWeights()) { + std::cerr << "loadWeights failed" << std::endl; + return 3; + } + + + int dev_id = 0; + auto kv = 
createKVCache(meta.layers, meta.seq_in, meta.heads, meta.head_dim, meta.head_dim, + INFINI_DTYPE_F16, INFINI_DEVICE_HYGON, &dev_id, 1); + + auto input_buf = load_bin("dump_kv/input_kv.bin"); + auto output_buf = load_bin("dump_kv/output_kv.bin"); + + for(int i = 0; i < 10; ++i) { + std::cout << f16_to_f32(input_buf[i]) << std::endl; + } + // Load input K/V per layer using offsets from meta.index (deterministic order K then V). + size_t elems_per = static_cast(meta.heads) * meta.seq_in * meta.head_dim; + for (size_t layer = 0; layer < meta.layers; ++layer) { + size_t k_off = layer * 2 * elems_per; + size_t v_off = k_off + elems_per; + copy_from_bin(kv->k[0][layer], input_buf, k_off, meta.heads, meta.seq_in, meta.head_dim); + copy_from_bin(kv->v[0][layer], input_buf, v_off, meta.heads, meta.seq_in, meta.head_dim); + } + + auto pool = std::make_shared(256 * 1024 * 1024); + CacheManager cache_mgr(32); + infiniopHandle_t handle = nullptr; + infinirtStream_t stream = nullptr; + RUN_INFINI(infiniopCreateHandle(&handle)); + RUN_INFINI(infinirtStreamCreate(&stream)); + InferenceContext ctx(handle, pool, &cache_mgr, stream); + setInferenceContext(&ctx); + + auto compressed = compressor.compress(*kv, meta.seq_in); + if (!compressed) { + std::cerr << "compress returned nullptr" << std::endl; + return 4; + } + + float max_diff = 0.f, mean_accum = 0.f; + size_t total_elems = 0; + size_t elems_out_per = static_cast(meta.heads) * meta.seq_out * meta.head_dim; + + for (size_t layer = 0; layer < meta.layers; ++layer) { + auto &layer_out = compressed->layers[layer]; + auto k_vec = tensor_to_float(layer_out.k_comp); + auto v_vec = tensor_to_float(layer_out.v_comp); + size_t k_off = layer * 2 * elems_out_per; + size_t v_off = k_off + elems_out_per; + + std::vector k_exp(elems_out_per), v_exp(elems_out_per); + for (size_t i = 0; i < elems_out_per; ++i) { + k_exp[i] = f16_to_f32(output_buf[k_off + i]); + v_exp[i] = f16_to_f32(output_buf[v_off + i]); + } + if (layer == 0) { + print_first(k_vec, 8, "k_out"); + print_first(k_exp, 8, "k_exp"); + print_first(v_vec, 8, "v_out"); + print_first(v_exp, 8, "v_exp"); + } + auto kd = diff_stats(k_vec, k_exp); + auto vd = diff_stats(v_vec, v_exp); + max_diff = std::max({max_diff, kd.first, vd.first}); + mean_accum += (kd.second * k_vec.size() + vd.second * v_vec.size()); + total_elems += k_vec.size() + v_vec.size(); + } + + float mean_diff = mean_accum / static_cast(total_elems); + std::cout << "max_diff=" << max_diff << " mean_diff=" << mean_diff << std::endl; + + dropKVCache(kv); + RUN_INFINI(infinirtStreamDestroy(stream)); + RUN_INFINI(infiniopDestroyHandle(handle)); + return (max_diff < 1e-3f) ? 
0 : 5; +} diff --git a/tests/test_kv_compression_correctness_cpu.cpp b/tests/test_kv_compression_correctness_cpu.cpp new file mode 100644 index 00000000..07b287a3 --- /dev/null +++ b/tests/test_kv_compression_correctness_cpu.cpp @@ -0,0 +1,217 @@ +#include "../src/cache_manager/kv_compression.hpp" +#include "../src/cache_manager/opcache_manager.hpp" +#include "../src/models/inference_context.hpp" +#include "../include/infinicore_infer/cache.h" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + + +namespace { +struct Meta { + uint32_t layers = 0; + uint32_t heads = 0; + uint32_t head_dim = 0; + uint32_t seq_in = 0; + uint32_t seq_out = 0; + uint32_t compression_factor = 1; + uint32_t min_seq_len = 1; + uint32_t image_kv_len = 0; +}; + +int64_t extract_int(const std::string &s, const std::string &key, int64_t def = -1) { + std::regex re(key + R"(\s*:\s*([0-9]+))"); + std::smatch m; + if (std::regex_search(s, m, re)) { + return std::stoll(m[1]); + } + return def; +} + +Meta parse_meta(const std::string &path) { + std::ifstream fin(path); + std::string content((std::istreambuf_iterator(fin)), std::istreambuf_iterator()); + Meta m; + m.layers = extract_int(content, "\"layers\""); + m.heads = extract_int(content, "\"heads\""); + m.head_dim = extract_int(content, "\"head_dim\""); + m.seq_in = extract_int(content, "\"seq_len_in\""); + m.seq_out = extract_int(content, "\"seq_len_out\""); + m.compression_factor = extract_int(content, "\"compression_factor\"", 1); + m.min_seq_len = extract_int(content, "\"min_seq_len\"", 1); + // it_len: [a, b] + std::regex it_re(R"("it_len"\s*:\s*\[\s*([0-9]+)\s*,\s*([0-9]+)\s*\])"); + std::smatch mm; + if (std::regex_search(content, mm, it_re)) { + m.image_kv_len = static_cast(std::stoul(mm[1])); + } + return m; +} + +std::vector load_bin(const std::string &path) { + std::ifstream fin(path, std::ios::binary); + fin.seekg(0, std::ios::end); + size_t bytes = fin.tellg(); + fin.seekg(0, std::ios::beg); + std::vector buf(bytes / sizeof(uint16_t)); + fin.read(reinterpret_cast(buf.data()), bytes); + return buf; +} + +std::vector tensor_to_float(std::shared_ptr t) { + size_t n = t->numel(); + std::vector out(n); + std::vector tmp(n); + RUN_INFINI(infinirtMemcpy(tmp.data(), t->data(), n * sizeof(uint16_t), INFINIRT_MEMCPY_D2H)); + for (size_t i = 0; i < n; ++i) out[i] = f16_to_f32(tmp[i]); + return out; +} + +void copy_from_bin(std::shared_ptr t, const std::vector &buf, size_t offset_elems, uint32_t heads, uint32_t seq, uint32_t dk) { + // Source layout: [B=1, H, S, D] row-major in buf; target Tensor layout: [S, H, D] + std::vector tmp(t->numel()); + for (uint32_t h = 0; h < heads; ++h) { + for (uint32_t s = 0; s < seq; ++s) { + for (uint32_t d = 0; d < dk; ++d) { + size_t src_idx = offset_elems + ((h * seq + s) * dk + d); + size_t dst_idx = (static_cast(s) * heads + h) * dk + d; + tmp[dst_idx] = buf[src_idx]; + } + } + } + RUN_INFINI(infinirtMemcpy(t->data(), tmp.data(), tmp.size() * sizeof(uint16_t), INFINIRT_MEMCPY_H2D)); +} + +std::pair diff_stats(const std::vector &a, const std::vector &b) { + float maxd = 0.f, meand = 0.f; + size_t n = a.size(); + for (size_t i = 0; i < n; ++i) { + float d = std::abs(a[i] - b[i]); + maxd = std::max(maxd, d); + meand += d; + } + meand /= static_cast(n); + return {maxd, meand}; +} + +void print_first(const std::vector &v, size_t n, const std::string &label) { + std::cout << label << " first " << n << ": "; + for (size_t i = 0; i < std::min(n, v.size()); ++i) { + std::cout << v[i] 
<< " "; + } + std::cout << "\n"; +} + + + + +} // namespace + +int main(int argc, char **argv) { + if (argc < 2) { + std::cerr << "Usage: " << argv[0] << " " << std::endl; + return 1; + } + + Meta meta = parse_meta("dump_kv/meta.json"); + if (meta.layers == 0) { + std::cerr << "Failed to parse dump_kv/meta.json" << std::endl; + return 2; + } + + CompressionConfig cfg; + cfg.enable = true; + cfg.weight_path = argv[1]; + cfg.image_kv_len = meta.image_kv_len; + cfg.compression_factor = meta.compression_factor; + cfg.min_seq_len = meta.min_seq_len; + + Compressor compressor(cfg); + if (!compressor.loadWeights()) { + std::cerr << "loadWeights failed" << std::endl; + return 3; + } + + RUN_INFINI(infinirtInit()); + RUN_INFINI(infinirtSetDevice(INFINI_DEVICE_CPU, 0)); + + int dev_id = 0; + auto kv = createKVCache(meta.layers, meta.seq_in, meta.heads, meta.head_dim, meta.head_dim, + INFINI_DTYPE_F16, INFINI_DEVICE_CPU, &dev_id, 1); + + auto input_buf = load_bin("dump_kv/input_kv.bin"); + auto output_buf = load_bin("dump_kv/output_kv.bin"); + + for(int i = 0; i < 10; ++i) { + std::cout << f16_to_f32(input_buf[i]) << std::endl; + } + // Load input K/V per layer using offsets from meta.index (deterministic order K then V). + size_t elems_per = static_cast(meta.heads) * meta.seq_in * meta.head_dim; + for (size_t layer = 0; layer < meta.layers; ++layer) { + size_t k_off = layer * 2 * elems_per; + size_t v_off = k_off + elems_per; + copy_from_bin(kv->k[0][layer], input_buf, k_off, meta.heads, meta.seq_in, meta.head_dim); + copy_from_bin(kv->v[0][layer], input_buf, v_off, meta.heads, meta.seq_in, meta.head_dim); + } + + auto pool = std::make_shared(256 * 1024 * 1024); + CacheManager cache_mgr(32); + infiniopHandle_t handle = nullptr; + infinirtStream_t stream = nullptr; + RUN_INFINI(infiniopCreateHandle(&handle)); + RUN_INFINI(infinirtStreamCreate(&stream)); + InferenceContext ctx(handle, pool, &cache_mgr, stream); + setInferenceContext(&ctx); + + auto compressed = compressor.compress(*kv, meta.seq_in); + if (!compressed) { + std::cerr << "compress returned nullptr" << std::endl; + return 4; + } + + float max_diff = 0.f, mean_accum = 0.f; + size_t total_elems = 0; + size_t elems_out_per = static_cast(meta.heads) * meta.seq_out * meta.head_dim; + + for (size_t layer = 0; layer < meta.layers; ++layer) { + auto &layer_out = compressed->layers[layer]; + auto k_vec = tensor_to_float(layer_out.k_comp); + auto v_vec = tensor_to_float(layer_out.v_comp); + size_t k_off = layer * 2 * elems_out_per; + size_t v_off = k_off + elems_out_per; + + std::vector k_exp(elems_out_per), v_exp(elems_out_per); + for (size_t i = 0; i < elems_out_per; ++i) { + k_exp[i] = f16_to_f32(output_buf[k_off + i]); + v_exp[i] = f16_to_f32(output_buf[v_off + i]); + } + if (layer == 0) { + print_first(k_vec, 8, "k_out"); + print_first(k_exp, 8, "k_exp"); + print_first(v_vec, 8, "v_out"); + print_first(v_exp, 8, "v_exp"); + } + auto kd = diff_stats(k_vec, k_exp); + auto vd = diff_stats(v_vec, v_exp); + max_diff = std::max({max_diff, kd.first, vd.first}); + mean_accum += (kd.second * k_vec.size() + vd.second * v_vec.size()); + total_elems += k_vec.size() + v_vec.size(); + } + + float mean_diff = mean_accum / static_cast(total_elems); + std::cout << "max_diff=" << max_diff << " mean_diff=" << mean_diff << std::endl; + + dropKVCache(kv); + RUN_INFINI(infinirtStreamDestroy(stream)); + RUN_INFINI(infiniopDestroyHandle(handle)); + return (max_diff < 1e-3f) ? 
0 : 5; +} diff --git a/tests/test_kv_compression_load.cpp b/tests/test_kv_compression_load.cpp new file mode 100644 index 00000000..b4f40cc3 --- /dev/null +++ b/tests/test_kv_compression_load.cpp @@ -0,0 +1,79 @@ +#include "../src/cache_manager/kv_compression.hpp" +#include "../src/models/inference_context.hpp" // to init context for linear/relu +#include "../src/cache_manager/opcache_manager.hpp" +#include "../include/infinicore_infer/cache.h" +#include +#include + +#include +#include + +int main(int argc, char **argv) { + if (argc < 2) { + std::cerr << "Usage: " << argv[0] << " [cpu|hygon]" << std::endl; + return 1; + } + std::string dev_arg = argc >= 3 ? std::string(argv[2]) : "cpu"; + + CompressionConfig cfg; + cfg.enable = true; + cfg.weight_path = argv[1]; + + // Prepare a fake KVCache (single device, single layer). + size_t nlayers = 1; + size_t max_len = 20; + size_t nkvh = 32; // match LLAVA config (num_key_value_heads) + size_t dk = 128; // head_dim; cols = head_dim * factor (128*5=640) + size_t dv = 128; + // Use Hygon device as requested + infiniDevice_t device = INFINI_DEVICE_CPU; + if (dev_arg == "hygon") { + device = INFINI_DEVICE_HYGON; + } + int dev_id = 0; + + RUN_INFINI(infinirtInit()); + RUN_INFINI(infinirtSetDevice(device, dev_id)); + + Compressor compressor(cfg); + if (!compressor.loadWeights()) { + std::cerr << "loadWeights failed" << std::endl; + return 2; + } + + auto kv = createKVCache(nlayers, max_len, nkvh, dk, dv, INFINI_DTYPE_F16, device, &dev_id, 1); + + // Initialize a minimal inference context (device-specific) for linear/relu ops. + auto pool = std::make_shared(128 * 1024 * 1024); + CacheManager cache_mgr(32); + infiniopHandle_t handle = nullptr; + infinirtStream_t stream = nullptr; + RUN_INFINI(infiniopCreateHandle(&handle)); + RUN_INFINI(infinirtStreamCreate(&stream)); + InferenceContext ctx(handle, pool, &cache_mgr, stream); + setInferenceContext(&ctx); + + auto compressed = compressor.compress(*kv, static_cast(max_len)); + if (!compressed) { + std::cerr << "compress returned nullptr (likely missing op_handle/memory_pool)" << std::endl; + } else { + auto &layer0 = compressed->layers[0]; + std::cout << "Compressed seq_len=" << layer0.comp_seq_len + << " orig=" << layer0.orig_seq_len << std::endl; + if (layer0.k_comp) { + std::cout << "k_comp shape: "; + for (auto d : layer0.k_comp->shape()) std::cout << d << " "; + std::cout << "\n"; + } + if (layer0.v_comp) { + std::cout << "v_comp shape: "; + for (auto d : layer0.v_comp->shape()) std::cout << d << " "; + std::cout << "\n"; + } + } + + dropKVCache(kv); + RUN_INFINI(infinirtStreamDestroy(stream)); + RUN_INFINI(infiniopDestroyHandle(handle)); + return 0; +} diff --git a/xmake.lua b/xmake.lua index 598ac534..8b886f7a 100644 --- a/xmake.lua +++ b/xmake.lua @@ -2,6 +2,7 @@ local INFINI_ROOT = os.getenv("INFINI_ROOT") or (os.getenv(is_host("windows") an target("infinicore_infer") set_kind("shared") + set_symbols("debug") -- add debug symbols add_includedirs("include", { public = false }) add_includedirs(INFINI_ROOT.."/include", { public = true }) @@ -22,5 +23,9 @@ target("infinicore_infer") set_installdir(INFINI_ROOT) add_installfiles("include/infinicore_infer.h", {prefixdir = "include"}) + add_installfiles("include/infinicore_infer/cache.h", {prefixdir = "include/infinicore_infer"}) + add_installfiles("include/infinicore_infer/kv_compression.h", {prefixdir = "include/infinicore_infer"}) + add_installfiles("include/infinicore_infer/weights_loader.h", {prefixdir = "include/infinicore_infer"}) 
add_installfiles("include/infinicore_infer/models/*.h", {prefixdir = "include/infinicore_infer/models"}) target_end() +