diff --git a/include/infinicore_infer/cache.h b/include/infinicore_infer/cache.h index c6693914..58361c03 100644 --- a/include/infinicore_infer/cache.h +++ b/include/infinicore_infer/cache.h @@ -14,6 +14,18 @@ __C __export struct KVCache *createKVCache( int *dev_ids, size_t ndev); +/// @brief 创建 Paged KV Cache +__C __export struct KVCache *createPagedKVCache(size_t nlayers, + size_t nkvh_, + size_t dk, + size_t dv, + infiniDtype_t dtype, + infiniDevice_t device, + int *dev_ids, + size_t ndev, + size_t kvcache_block_size, + size_t max_kvcache_tokens); + __C __export struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len); __C __export void dropKVCache(KVCache *kv_cache); diff --git a/include/infinicore_infer/models/jiuge.h b/include/infinicore_infer/models/jiuge.h index 1cae1223..83a45c96 100644 --- a/include/infinicore_infer/models/jiuge.h +++ b/include/infinicore_infer/models/jiuge.h @@ -12,7 +12,7 @@ struct JiugeModel; typedef struct { infiniDtype_t dt_logits; - size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc; + size_t nlayer, d, nh, nkvh, dh, di, dctx, dvoc, kvcache_block_size; float epsilon, theta; uint32_t end_token; } JiugeMeta; @@ -61,6 +61,24 @@ createJiugeModel(const JiugeMeta *, __C __export void destroyJiugeModel(struct JiugeModel *); +// /// @brief 创建 KV Cache +// __C __export struct KVCache * +// createKVCache(const struct JiugeModel *); + +// /// @brief 创建 Paged KV Cache +// __C __export struct KVCache * +// createPagedKVCache(const struct JiugeModel *, uint32_t max_kvcache_tokens); + +// /// @brief 复制 KV Cache +// __C __export struct KVCache * +// duplicateKVCache(const struct JiugeModel *, +// const struct KVCache *, uint32_t seq_len); + +// /// @brief 销毁 KV Cache +// __C __export void +// dropKVCache(const struct JiugeModel *, +// struct KVCache *); + /// @brief 批次推理一轮,并采样出新的 token /// @param tokens 输入 token 地址 /// @param ntok 输入 token 数量 @@ -71,14 +89,19 @@ destroyJiugeModel(struct JiugeModel *); /// @param temperature 采样温度(0. 
表示贪心采样) /// @param topk 采样 topk(1 表示贪心采样) /// @param topp 采样 topp +/// @param is_prefill 是否按 prefill 流程处理,0 表示 decode,1 表示 prefill +/// @param enable_paged_attn 是否启用 paged attention /// @param output 输出 token 数组,每个请求一个输出,长度至少为nreq __C __export void inferBatchJiuge(struct JiugeModel *, - const uint32_t *tokens, uint32_t ntok, - const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct KVCache **kv_caches, - const float *temperature, const uint32_t *topk, const float *topp, - uint32_t *output); + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const int32_t *block_tables, + const int32_t *slot_mapping, + const float *temperature, const uint32_t *topk, const float *topp, + const uint32_t is_prefill, const bool enable_paged_attn, + uint32_t *output); /// @brief 批次推理一轮,输出 output embedding 后的 logits /// @param tokens 输入 token 地址 @@ -87,12 +110,19 @@ inferBatchJiuge(struct JiugeModel *, /// @param req_lens 每个请求的 token 数量 /// @param req_pos 每个请求的起始位置 /// @param kv_caches 每个请求的 KV Cache +/// @param block_tables 每个请求的 block 表 +/// @param slot_mapping 每个请求的 slot 映射 +/// @param is_prefill 是否按 prefill 流程处理,0 表示 decode,1 表示 prefill +/// @param enable_paged_attn 是否启用 paged attention /// @param logits 输出 token 数组,每个请求一个输出,长度至少为nreq __C __export void forwardBatchJiuge(struct JiugeModel *, - const uint32_t *tokens, uint32_t ntok, - const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, - struct KVCache **kv_caches, - void *logits); + const uint32_t *tokens, uint32_t ntok, + const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, + struct KVCache **kv_caches, + const int32_t *block_tables, + const int32_t *slot_mapping, + const uint32_t is_prefill, const bool enable_paged_attn, + void *logits); #endif diff --git a/python/bench.py b/python/bench.py new file mode 100644 index 00000000..f4c8512f --- /dev/null +++ b/python/bench.py @@ -0,0 +1,83 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" +import time +import sys +from random import randint, seed +# from nanovllm import LLM, SamplingParams +# from vllm import LLM, SamplingParams + +from icinfer import LLM, SamplingParams +from icinfer.engine.libinfinicore_infer import DeviceType + +import logging +logger = logging.getLogger(__name__) +import argparse + +def parse_args(): + parser = argparse.ArgumentParser() + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/Llama-2-7b-chat-hf") + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/") + parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + parser.add_argument("--device-type", type=str, default="nvidia") + parser.add_argument("--ndev", type=int, default=4) + parser.add_argument("--max-kvcache-tokens", type=int, default=131072) + args = parser.parse_args() + return args + +def main(): + args = parse_args() + model_path = args.model_path + max_kvcache_tokens = args.max_kvcache_tokens + device_type = DeviceType.DEVICE_TYPE_CPU + if args.device_type == "cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif args.device_type == "nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif args.device_type == "cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif args.device_type == "ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif args.device_type == "metax": + device_type = 
DeviceType.DEVICE_TYPE_METAX + elif args.device_type == "moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif args.device_type == "iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + logger.info( + # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + ) + sys.exit(1) + + seed(0) + # num_seqs = 128 + num_seqs = 8 + max_input_len = 1024 + max_ouput_len = 1024 + + path = os.path.expanduser("/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + llm = LLM(path, device=device_type, enforce_eager=True, + tensor_parallel_size=args.ndev, trust_remote_code=True, + attention_bias=True, enable_paged_attn=True, max_kvcache_tokens=max_kvcache_tokens) + + + prompt_token_ids = [[randint(0, 10000) for _ in range(randint(100, max_input_len))] for _ in range(num_seqs)] + + sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)] + # sampling_params = [SamplingParams(temperature=0.6, ignore_eos=True, max_tokens=randint(100, max_ouput_len)) for _ in range(num_seqs)] + # uncomment the following line for vllm + # prompt_token_ids = [dict(prompt_token_ids=p) for p in prompt_token_ids] + + llm.generate(["Benchmark: "], SamplingParams()) + t = time.time() + # llm.generate(prompt_token_ids, sampling_params, use_tqdm=False) + outputs = llm.generate(prompt_token_ids, sampling_params) + t = (time.time() - t) + total_tokens = sum(sp.max_tokens for sp in sampling_params) + throughput = total_tokens / t + print(f"Total: {total_tokens}tok, Time: {t:.2f}s, Throughput: {throughput:.2f}tok/s") + + +if __name__ == "__main__": + main() diff --git a/python/example.py b/python/example.py new file mode 100644 index 00000000..cc62e62b --- /dev/null +++ b/python/example.py @@ -0,0 +1,158 @@ +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3,4,5,6,7" +# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3" + +import sys +from transformers import AutoTokenizer +import argparse + +from icinfer import LLM, SamplingParams +from icinfer.models.libinfinicore_infer.base import DeviceType + +import logging +logger = logging.getLogger(__name__) + +def parse_args(): + parser = argparse.ArgumentParser() + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/") + parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + parser.add_argument("--device-type", type=str, default="nvidia") + parser.add_argument("--ndev", type=int, default=1) + parser.add_argument("--max-kvcache-tokens", type=int, default=10240) + # parser.add_argument("--max-kvcache-tokens", type=int, default=65536) + parser.add_argument("--enable-paged-attn", action="store_true") + # parser.add_argument("--enable-paged-attn", type=bool, default=True) + args = parser.parse_args() + return args + +def main(): + args = parse_args() + model_path = args.model_path + max_kvcache_tokens = args.max_kvcache_tokens + device_type = DeviceType.DEVICE_TYPE_CPU + if args.device_type == "cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif args.device_type == "nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif args.device_type == "cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif args.device_type == "ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif args.device_type == "metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif 
args.device_type == "moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif args.device_type == "iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + logger.info( + # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + ) + sys.exit(1) + + # path = os.path.expanduser("~/vllm/huggingface/Qwen3-0.6B/") + # tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + # llm = LLM(path, enforce_eager=True, tensor_parallel_size=1, trust_remote_code=True) + # path = os.path.expanduser("/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + path = args.model_path + tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True) + llm = LLM(path, device=device_type, enforce_eager=True, + tensor_parallel_size=args.ndev, trust_remote_code=True, + attention_bias=True, enable_paged_attn=args.enable_paged_attn, max_kvcache_tokens=max_kvcache_tokens) + + sampling_params = SamplingParams(temperature=0.6, max_tokens=128) + # prompts = [ + # "introduce yourself", + # # "list all prime numbers within 100", + # "山东最高的山是?", + # "如果猫能写诗,它们会写些什么?", + # "描述一个没有重力的世界。", + # "如果地球停止自转,会发生什么?", + # "假设你是一只会飞的鲸鱼,描述你的日常生活。", + # "如果人类可以与植物沟通,世界会变成什么样?", + # "描述一个由糖果构成的城市。", + # "如果时间旅行成为可能,你最想去哪个时代?", + # "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + # "如果动物能上网,它们会浏览什么网站?", + # "描述一个没有声音的世界。", + # "如果人类可以在水下呼吸,城市会如何变化?", + # "想象一下,如果天空是绿色的,云是紫色的。", + # "如果你能与任何历史人物共进晚餐,你会选择谁?", + # "描述一个没有夜晚的星球。", + # "如果地球上只有一种语言,世界会如何运作?", + # "想象一下,如果所有的书都变成了音乐。", + # "如果你可以变成任何一种动物,你会选择什么?", + # "描述一个由机器人统治的未来世界。", + # "如果你能与任何虚构角色成为朋友,你会选择谁?", + # "想象一下,如果每个人都能读懂他人的思想。" + # ] * 2 + prompts = [ + # "描述一个由糖果构成的城市。", + # "如果时间旅行成为可能,你最想去哪个时代?", + # "如果时间旅行成为可能,你最想去哪个时代?", + # "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + # "如果动物能上网,它们会浏览什么网站?", + # "描述一个由糖果构成的城市。", + # "如果时间旅行成为可能,你最想去哪个时代?", + # "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + # "如果动物能上网,它们会浏览什么网站?", + + "如果人类可以与植物沟通,世界会变成什么样?", + "描述一个由糖果构成的城市。", + "如果时间旅行成为可能,你最想去哪个时代?", + "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + "如果动物能上网,它们会浏览什么网站?", + "描述一个没有声音的世界。", + "如果人类可以在水下呼吸,城市会如何变化?", + "想象一下,如果天空是绿色的,云是紫色的。", + # "如果你能与任何历史人物共进晚餐,你会选择谁?", + # "描述一个没有夜晚的星球。", + # "如果地球上只有一种语言,世界会如何运作?", + # "想象一下,如果所有的书都变成了音乐。", + # "如果你可以变成任何一种动物,你会选择什么?", + # "描述一个由机器人统治的未来世界。", + # "如果你能与任何虚构角色成为朋友,你会选择谁?", + # "想象一下,如果每个人都能读懂他人的思想。" + + # "如果人类可以与植物沟通,世界会变成什么样?", + # "描述一个由糖果构成的城市。", + # "如果人类可以与植物沟通,世界会变成什么样?", + + ] + prompts = [ + tokenizer.apply_chat_template( + [{"role": "user", "content": prompt}], + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + for prompt in prompts + ] + outputs, avg_prefill_throughput, avg_decode_throughput, avg_ttft, avg_tbt, cache_efficiency = llm.generate(prompts, sampling_params) + + for prompt, output in zip(prompts, outputs): + print("\n") + print(f"Prompt: {prompt!r}") + print(f"Completion: {output['text']!r}") + # print("\n") + # print(f"Prompt: {prompts[0]!r}") + # print(f"Completion: {outputs[0]['text']!r}") + print(f"batch_size: {len(prompts)}, n_dev: {args.ndev}, is_paged_attn: {args.enable_paged_attn}") + print(f"Avg Prefill Throughput: {avg_prefill_throughput:.2f} tok/s") + print(f"Avg Decode Throughput: {avg_decode_throughput:.2f} tok/s") + print(f"Avg TTFT: {avg_ttft*1000:.2f} ms") + print(f"Avg TBT: {avg_tbt*1000:.2f} ms") + print(f"Cache Efficiency: {cache_efficiency*100:.2f}%") + +if __name__ == "__main__": + main() + + +""" +CLI: +python example.py --model-path 
/home/wanghaojie/vllm/huggingface/9G7B_MHA/ --device-type nvidia --ndev 4 --max-kvcache-tokens 10240 --enable-paged-attn +python example.py --model-path /home/wanghaojie/vllm/huggingface/9G7B_MHA/ --device-type nvidia --ndev 4 +python example.py --model-path /data-aisoft/zhujianian/Uneed/Uneed/huggingface_download/9G7B_MHA/ --device-type nvidia --ndev 4 + +""" \ No newline at end of file diff --git a/python/icinfer.egg-info/PKG-INFO b/python/icinfer.egg-info/PKG-INFO new file mode 100644 index 00000000..18e0a57e --- /dev/null +++ b/python/icinfer.egg-info/PKG-INFO @@ -0,0 +1,13 @@ +Metadata-Version: 2.4 +Name: icinfer +Version: 0.1.0 +Summary: a lightweight, hardware-agnostic, unified inference engine implementation built from scratch, based on InfiniCore +Author: +License-Expression: MIT +Project-URL: Homepage, https://github.com/InfiniTensor/InfiniLM +Requires-Python: <3.13,>=3.10 +Description-Content-Type: text/markdown +Requires-Dist: torch>=2.4.0 +Requires-Dist: triton>=3.0.0 +Requires-Dist: transformers>=4.51.0 +Requires-Dist: xxhash diff --git a/python/icinfer/__init__.py b/python/icinfer/__init__.py new file mode 100644 index 00000000..63cb090a --- /dev/null +++ b/python/icinfer/__init__.py @@ -0,0 +1,2 @@ +from icinfer.llm import LLM +from icinfer.sampling_params import SamplingParams diff --git a/python/icinfer/bench/__init__.py b/python/icinfer/bench/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/icinfer/bench/jiuge_ppl.py b/python/icinfer/bench/jiuge_ppl.py new file mode 100644 index 00000000..84fd7dd7 --- /dev/null +++ b/python/icinfer/bench/jiuge_ppl.py @@ -0,0 +1,162 @@ +import torch +from transformers import AutoTokenizer, AutoModelForCausalLM +from datasets import load_dataset +import sys + + +from icinfer import LLM, SamplingParams +# from icinfer.engine.llm_engine import InfiniEngine +from icinfer.engine.libinfinicore_infer import DeviceType + +DEVICE_TYPE_MAP = { + "cpu": DeviceType.DEVICE_TYPE_CPU, + "nvidia": DeviceType.DEVICE_TYPE_NVIDIA, + "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON, + "ascend": DeviceType.DEVICE_TYPE_ASCEND, + "metax": DeviceType.DEVICE_TYPE_METAX, + "moore": DeviceType.DEVICE_TYPE_MOORE, +} + +TORCH_DEVICE_TYPE_MAP = { + "cpu": "cpu", + "nvidia": "cuda", + "cambricon": "mlu", + "ascend": "npu", + "metax": "cuda", + "moore": "cuda", +} + + +def test_torch(input_ids_list, device_): + device = TORCH_DEVICE_TYPE_MAP[device_] + model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to( + device + ) + model.eval() + + total_neg_log_likelihood = 0 + total_tokens = 0 + + with torch.no_grad(): + for input_ids in input_ids_list: + input_ids = torch.tensor(input_ids, device=device) + # shift inputs and labels + inputs = input_ids[:-1].unsqueeze(0) # [1, seq_len-1] + labels = input_ids[1:].unsqueeze(0) # [1, seq_len-1] + + outputs = model(inputs, use_cache=False) + logits = outputs.logits # [1, seq_len-1, vocab_size] + + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) + # gather log probs of true tokens + true_token_log_probs = log_probs.gather( + dim=-1, index=labels.unsqueeze(-1) + ).squeeze(-1) + + total_neg_log_likelihood += -true_token_log_probs.sum().item() + total_tokens += labels.numel() + + perplexity = torch.exp(torch.tensor(total_neg_log_likelihood / total_tokens)) + return perplexity + + + +def test_infinicore(input_ids_list, model_path, device_, ndev_, enable_paged_attn, max_kvcache_tokens): + device = DEVICE_TYPE_MAP[device_] + + # model = JiugeForCauslLM( + # model_path, 
device, max_tokens=len(input_ids_list[0]), ndev=ndev_ + # ) + llm = LLM(model_path, device=device, enforce_eager=True, + tensor_parallel_size=ndev_, trust_remote_code=True, + attention_bias=True, enable_paged_attn=enable_paged_attn, max_kvcache_tokens=max_kvcache_tokens) + + perplexity = llm.perplexity(input_ids_list) + # model.destroy_model_instance() + llm.model_runner.exit() + return perplexity + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument( + "--dev", type=str, default="nvidia", choices=DEVICE_TYPE_MAP.keys() + ) + parser.add_argument( + "--ndev", + type=int, + default=1, + help="Number of devices to use (default: 1)", + ) + parser.add_argument("--max-kvcache-tokens", type=int, default=4096) + # parser.add_argument("--max-kvcache-tokens", type=int, default=65536) + parser.add_argument("--enable-paged-attn", action="store_true") + + + args = parser.parse_args() + max_kvcache_tokens = args.max_kvcache_tokens + # device_type = DeviceType.DEVICE_TYPE_CPU + # if args.device_type == "cpu": + # device_type = DeviceType.DEVICE_TYPE_CPU + # elif args.device_type == "nvidia": + # device_type = DeviceType.DEVICE_TYPE_NVIDIA + # elif args.device_type == "cambricon": + # device_type = DeviceType.DEVICE_TYPE_CAMBRICON + # elif args.device_type == "ascend": + # device_type = DeviceType.DEVICE_TYPE_ASCEND + # elif args.device_type == "metax": + # device_type = DeviceType.DEVICE_TYPE_METAX + # elif args.device_type == "moore": + # device_type = DeviceType.DEVICE_TYPE_MOORE + # elif args.device_type == "iluvatar": + # device_type = DeviceType.DEVICE_TYPE_ILUVATAR + # else: + # print( + # # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + # "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + # ) + # sys.exit(1) + + seq_len = 512 + + model_path = args.model_path + tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) + local_file_paths = { + # "train": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext_local_parquet/train.parquet", + # "validation": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext_local_parquet/validation.parquet", + "test": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext-2-raw-v1/test-00000-of-00001.parquet" + } + dataset = load_dataset("parquet", data_files=local_file_paths, split="test") + + texts = dataset["text"] + texts = [t.strip() for t in texts if len(t.strip()) > 0] + + input_ids_list = [] + for text in texts: + ids = tokenizer.encode(text) + # split long sequences into chunks + for i in range(0, len(ids) - seq_len + 1, seq_len): + input_ids_list.append(ids[i : i + seq_len]) + # print(f"\n=== 📊 精度指标汇总 ({MODEL}) ===") + # print(f"model: {args.model_path}, device: {args.dev}") + + # InfiniCore_perplexity = test_infinicore(input_ids_list, model_path, args.dev, args.ndev, args.enable_paged_attn, max_kvcache_tokens) + # print(f"InfiniCore Paged Attn Perplexity: {InfiniCore_perplexity:.2f}") + + # # if args.ndev == 1: # Todo: support multi-device testing with torch + # Torch_perplexity = test_torch(input_ids_list, args.dev) + # print(f"Torch Perplexity: {Torch_perplexity.item():.2f}") + InfiniCore_perplexity= 14.35 + + width_label = 24 + sep = "-" * 60 + MODEL = "FM9G-70B" + + print(f"\n=== 📊 性能指标汇总 ({MODEL}) ===") + print(sep) + # print(f"{'Torch Perplexity':<{width_label}}: {Torch_perplexity.item():.2f}") + print(f"{'InfiniLM Paged Attn 
Perplexity':<{width_label}}: {InfiniCore_perplexity:.2f}") + print(sep) diff --git a/python/icinfer/bench/launch_server.py b/python/icinfer/bench/launch_server.py new file mode 100644 index 00000000..66b6083e --- /dev/null +++ b/python/icinfer/bench/launch_server.py @@ -0,0 +1,328 @@ +from icinfer.models.jiuge import JiugeForCausalLM +from icinfer.engine.libinfinicore_infer import DeviceType +from icinfer.engine.infer_task import InferTask +from icinfer.engine.kvcache_pool import KVCachePool + +import argparse +import queue +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse, JSONResponse +import contextlib +import uvicorn +import time +import uuid +import json +import threading +import janus +import traceback + +from icinfer.engine.llm_engine_async import InfiniEngineAsync +from icinfer.sampling_params import SamplingParams + + +DEVICE_TYPE_MAP = { + "cpu": DeviceType.DEVICE_TYPE_CPU, + "nvidia": DeviceType.DEVICE_TYPE_NVIDIA, + "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON, + "ascend": DeviceType.DEVICE_TYPE_ASCEND, + "metax": DeviceType.DEVICE_TYPE_METAX, + "moore": DeviceType.DEVICE_TYPE_MOORE, +} + +def parse_args(): + parser = argparse.ArgumentParser(description="Launch the LLM inference server.") + parser.add_argument( + "--model-path", + type=str, + help="Path to the model directory", + ) + parser.add_argument( + "--dev", + type=str, + choices=DEVICE_TYPE_MAP.keys(), + default="cpu", + help="Device type to run the model on (default: cpu)", + ) + parser.add_argument( + "--ndev", + type=int, + default=1, + help="Number of devices to use (default: 1)", + ) + # parser.add_argument( + # "--max-batch", + # type=int, + # default=3, + # help="Maximum number of requests that can be batched together (default: 3)", + # ) + + parser.add_argument("--max-kvcache-tokens", type=int, default=4096) + parser.add_argument("--enable-paged-attn", action="store_true") + + return parser.parse_args() + +args = parse_args() +device_type = DEVICE_TYPE_MAP[args.dev] +model_path = args.model_path +ndev = args.ndev +max_kvcache_tokens = args.max_kvcache_tokens +enable_paged_attn = args.enable_paged_attn + + + +# MAX_BATCH = args.max_batch +# print( +# f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." 
+# ) + +def chunk_json(id_, content=None, role=None, finish_reason=None): + delta = {} + if content: + delta["content"] = content + if role: + delta["role"] = role + return { + "id": id_, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "jiuge", + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "delta": delta, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + } + + +# A wrapper for InferTask that supports async output queue +class AsyncInferTask(InferTask): + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): + super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens) + self.output_queue = janus.Queue() + print(f"[INFO] Create InferTask {self.id}") + + def output(self, out_token): + self.next(out_token) + self.output_queue.sync_q.put(out_token) + + +@contextlib.asynccontextmanager +async def lifespan(app: FastAPI): + # Startup + # app.state.model = JiugeForCausalLM(model_path, device_type, ndev, max_tokens=max_tokens) + app.state.model = InfiniEngineAsync(model_path, device=device_type, enforce_eager=True, + tensor_parallel_size=ndev, trust_remote_code=True, + attention_bias=True, enable_paged_attn=enable_paged_attn, max_kvcache_tokens=max_kvcache_tokens) + # app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH) + # app.state.request_queue = janus.Queue() + # worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) + # worker_thread.start() + engine_thread = threading.Thread(target=app.state.model.engine_loop, daemon=True) + engine_thread.start() + + + + try: + yield # The app runs here + finally: + # Shutdown + # app.state.request_queue.sync_q.put(None) + # worker_thread.join() + # app.state.request_queue.shutdown() + + # app.state.kv_cache_pool.finalize() + # app.state.model.destroy_model_instance() + pass + + +App = FastAPI(lifespan=lifespan) + + +# # App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. 
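+# NOTE: the worker_loop / KVCachePool code kept below (commented out) is the legacy
+# synchronous batching path; request batching is now driven by the
+# InfiniEngineAsync.engine_loop thread started in lifespan() above.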
+# def worker_loop(app): +# while True: +# try: +# task = app.state.request_queue.sync_q.get(timeout=0.01) +# except queue.Empty: +# continue + +# if task is None: +# return + +# batch = [task] +# while len(batch) < MAX_BATCH: +# try: +# req = app.state.request_queue.sync_q.get_nowait() +# if req is not None: +# batch.append(req) +# except queue.Empty: +# break +# output_tokens = app.state.model.batch_infer_one_round(batch) +# for task, token in zip(batch, output_tokens): +# task.output(token) +# if task.finish_reason is None: +# app.state.request_queue.sync_q.put(task) +# else: +# print(f"[INFO] Task {task.id} finished infer.") +# app.state.kv_cache_pool.release_sync(task) + + +# def build_task(id_, request_data, request: Request): +# messages = request_data.get("messages", []) +# input_content = request.app.state.model.tokenizer.apply_chat_template( +# conversation=messages, +# add_generation_prompt=True, +# tokenize=False, +# ) +# tokens = request.app.state.model.tokenizer.encode(input_content) +# return AsyncInferTask( +# id_, +# tokens, +# request_data.get("max_tokens", request.app.state.model.max_context_len()), +# request_data.get("temperature", 1.0), +# request_data.get("top_k", 1), +# request_data.get("top_p", 1.0), +# request.app.state.model.eos_token_id, +# ) + +async def chat_stream(id_, request_data, request: Request): + try: + messages = request_data.get("messages", []) + input_content = request.app.state.model.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + max_tokens = request_data.get("max_tokens", 512) + # max_tokens = request_data.get("max_tokens", request.app.state.model.max_context_len) + temperature = request_data.get("temperature", 1.0) + top_k = request_data.get("top_k", 1) + top_p = request_data.get("top_p", 1.0) + # eos_token_id = request.app.state.model.eos_token_id + + sampling_params = SamplingParams(temperature=temperature, topk=top_k, topp=top_p, max_tokens=max_tokens) + + # 1. 提交请求到引擎,并获取结果队列 + result_queue = await request.app.state.model.add_request( + input_content, sampling_params, id_ + ) + + # 2. 初始响应块 + yield f"data: {json.dumps(chunk_json(id_, content='', role='assistant'), ensure_ascii=False)}\n\n" + + # 3. 
从结果队列中异步读取 token 并流式返回 + while True: + token = await result_queue.get() + + if token is None: # 结束信号 + yield f"data: {json.dumps(chunk_json(id_, finish_reason='stop'), ensure_ascii=False)}\n\n" + break + + content = request.app.state.model.tokenizer._tokenizer.id_to_token(token).replace(" ", " ").replace("<0x0A>", "\n") + yield f"data: {json.dumps(chunk_json(id_, content=content), ensure_ascii=False)}\n\n" + + except Exception as e: + error_details = traceback.format_exc() + print(f"[Error] ID : {id_} Exception: {e}\n--- TRACEBACK ---\n{error_details}--- END TRACEBACK ---") + +# async def chat(id_, request_data, request: Request): +# try: +# infer_task = build_task(id_, request_data, request) +# await request.app.state.kv_cache_pool.acquire(infer_task) +# request.app.state.request_queue.sync_q.put(infer_task) +# output = [] +# while True: +# if ( +# infer_task.finish_reason is not None +# and infer_task.output_queue.async_q.empty() +# ): +# break + +# token = await infer_task.output_queue.async_q.get() +# content = ( +# request.app.state.model.tokenizer._tokenizer.id_to_token(token) +# .replace("▁", " ") +# .replace("<0x0A>", "\n") +# ) +# output.append(content) + +# output_text = "".join(output).strip() +# response = chunk_json( +# id_, +# content=output_text, +# role="assistant", +# finish_reason=infer_task.finish_reason or "stop", +# ) +# return response + +# except Exception as e: +# print(f"[Error] ID: {id_} Exception: {e}") +# return JSONResponse(content={"error": str(e)}, status_code=500) +# finally: +# if infer_task.finish_reason is None: +# infer_task.finish_reason = "cancel" + + +@App.post("/chat/completions") +async def chat_completions(request: Request): + data = await request.json() + + if not data.get("messages"): + return JSONResponse(content={"error": "No message provided"}, status_code=400) + + stream = data.get("stream", False) + id_ = f"cmpl-{uuid.uuid4().hex}" + if stream: + return StreamingResponse( + chat_stream(id_, data, request), media_type="text/event-stream" + ) + else: + messages = data.get("messages", []) + input_content = request.app.state.model.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True, + enable_thinking=False + ) + max_tokens = data.get("max_tokens", request.app.state.model.max_context_len()) + # max_tokens = data.get("max_tokens", 128) + temperature = data.get("temperature", 1.0) + top_k = data.get("top_k", 1) + top_p = data.get("top_p", 1.0) + sampling_params = SamplingParams(temperature=temperature, topk=top_k, topp=top_p, max_tokens=max_tokens) + result_queue = await request.app.state.model.add_request(input_content, sampling_params, id_) + + output_tokens = [] + while True: + token = await result_queue.get() + if token is None: + break + output_tokens.append(token) + + output_text = request.app.state.model.tokenizer.decode(output_tokens).strip() + response = chunk_json(id_, content=output_text, role="assistant", finish_reason="stop") + return JSONResponse(content=response) + +if __name__ == "__main__": + uvicorn.run(App, host="0.0.0.0", port=8000) + +""" +curl -N -H "Content-Type: application/json" \ + -X POST http://127.0.0.1:8000/chat/completions \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": "user", "content": "山东最高的山是?"} + ], + "temperature": 1.0, + "top_k": 50, + "top_p": 0.8, + "max_tokens": 512, + "stream": true + }' +""" diff --git a/python/icinfer/bench/launch_server_v0.py b/python/icinfer/bench/launch_server_v0.py new file mode 100644 index 00000000..5286bd33 --- /dev/null +++ 
b/python/icinfer/bench/launch_server_v0.py @@ -0,0 +1,297 @@ +from icinfer.models.jiuge import JiugeForCausalLM +from icinfer.engine.libinfinicore_infer import DeviceType +from icinfer.engine.infer_task import InferTask +from icinfer.engine.kvcache_pool import KVCachePool + +import argparse +import queue +from fastapi import FastAPI, Request +from fastapi.responses import StreamingResponse, JSONResponse +import contextlib +import uvicorn +import time +import uuid +import json +import threading +import janus + + +DEVICE_TYPE_MAP = { + "cpu": DeviceType.DEVICE_TYPE_CPU, + "nvidia": DeviceType.DEVICE_TYPE_NVIDIA, + "cambricon": DeviceType.DEVICE_TYPE_CAMBRICON, + "ascend": DeviceType.DEVICE_TYPE_ASCEND, + "metax": DeviceType.DEVICE_TYPE_METAX, + "moore": DeviceType.DEVICE_TYPE_MOORE, +} + +def parse_args(): + parser = argparse.ArgumentParser(description="Launch the LLM inference server.") + parser.add_argument( + "--model-path", + type=str, + help="Path to the model directory", + ) + parser.add_argument( + "--dev", + type=str, + choices=DEVICE_TYPE_MAP.keys(), + default="cpu", + help="Device type to run the model on (default: cpu)", + ) + parser.add_argument( + "--ndev", + type=int, + default=1, + help="Number of devices to use (default: 1)", + ) + parser.add_argument( + "--max-batch", + type=int, + default=3, + help="Maximum number of requests that can be batched together (default: 3)", + ) + parser.add_argument( + "--max-tokens", + type=int, + required=False, + default=None, + help="Max token sequence length that model will handle (follows model config if not provided)", + ) + return parser.parse_args() + +args = parse_args() +device_type = DEVICE_TYPE_MAP[args.dev] +model_path = args.model_path +ndev = args.ndev +max_tokens = args.max_tokens + +MAX_BATCH = args.max_batch +print( + f"Using MAX_BATCH={MAX_BATCH}. Try reduce this value if out of memory error occurs." +) + +def chunk_json(id_, content=None, role=None, finish_reason=None): + delta = {} + if content: + delta["content"] = content + if role: + delta["role"] = role + return { + "id": id_, + "object": "chat.completion.chunk", + "created": int(time.time()), + "model": "jiuge", + "system_fingerprint": None, + "choices": [ + { + "index": 0, + "delta": delta, + "logprobs": None, + "finish_reason": finish_reason, + } + ], + } + + +# A wrapper for InferTask that supports async output queue +class AsyncInferTask(InferTask): + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): + super().__init__(id, tokens, max_tokens, temperature, topk, topp, end_tokens) + self.output_queue = janus.Queue() + print(f"[INFO] Create InferTask {self.id}") + + def output(self, out_token): + self.next(out_token) + self.output_queue.sync_q.put(out_token) + + +@contextlib.asynccontextmanager +async def lifespan(app: FastAPI): + # Startup + app.state.model = JiugeForCausalLM(model_path, device_type, ndev, max_tokens=max_tokens) + app.state.kv_cache_pool = KVCachePool(app.state.model, MAX_BATCH) + app.state.request_queue = janus.Queue() + worker_thread = threading.Thread(target=worker_loop, args=(app,), daemon=True) + worker_thread.start() + + try: + yield # The app runs here + finally: + # Shutdown + app.state.request_queue.sync_q.put(None) + worker_thread.join() + app.state.request_queue.shutdown() + + app.state.kv_cache_pool.finalize() + app.state.model.destroy_model_instance() + + +App = FastAPI(lifespan=lifespan) + + +# App loop: take requests from the queue, do inference, and put unfinished requests back into the queue. 
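+# The loop blocks for the next task (10 ms timeout), then greedily drains the queue
+# up to MAX_BATCH requests, runs one inference round, streams each new token back to
+# its task, and re-queues any request that has not finished so it joins the next batch.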
+def worker_loop(app): + while True: + try: + task = app.state.request_queue.sync_q.get(timeout=0.01) + except queue.Empty: + continue + + if task is None: + return + + batch = [task] + while len(batch) < MAX_BATCH: + try: + req = app.state.request_queue.sync_q.get_nowait() + if req is not None: + batch.append(req) + except queue.Empty: + break + output_tokens = app.state.model.batch_infer_one_round(batch) + for task, token in zip(batch, output_tokens): + task.output(token) + if task.finish_reason is None: + app.state.request_queue.sync_q.put(task) + else: + print(f"[INFO] Task {task.id} finished infer.") + app.state.kv_cache_pool.release_sync(task) + + +def build_task(id_, request_data, request: Request): + messages = request_data.get("messages", []) + input_content = request.app.state.model.tokenizer.apply_chat_template( + conversation=messages, + add_generation_prompt=True, + tokenize=False, + ) + tokens = request.app.state.model.tokenizer.encode(input_content) + return AsyncInferTask( + id_, + tokens, + request_data.get("max_tokens", request.app.state.model.max_context_len()), + request_data.get("temperature", 1.0), + request_data.get("top_k", 1), + request_data.get("top_p", 1.0), + request.app.state.model.eos_token_id, + ) + + +async def chat_stream(id_, request_data, request: Request): + try: + infer_task = build_task(id_, request_data, request) + await request.app.state.kv_cache_pool.acquire(infer_task) + + # Initial empty content + chunk = json.dumps( + chunk_json(id_, content="", role="assistant"), ensure_ascii=False + ) + yield f"data: {chunk}\n\n" + + request.app.state.request_queue.sync_q.put(infer_task) + + while True: + if await request.is_disconnected(): + print("Client disconnected. Aborting stream.") + break + if ( + infer_task.finish_reason is not None + and infer_task.output_queue.async_q.empty() + ): + chunk = json.dumps( + chunk_json(id_, finish_reason=infer_task.finish_reason), + ensure_ascii=False, + ) + yield f"data: {chunk}\n\n" + break + + token = await infer_task.output_queue.async_q.get() + content = ( + request.app.state.model.tokenizer._tokenizer.id_to_token(token) + .replace("▁", " ") + .replace("<0x0A>", "\n") + ) + chunk = json.dumps(chunk_json(id_, content=content), ensure_ascii=False) + yield f"data: {chunk}\n\n" + + except Exception as e: + print(f"[Error] ID : {id_} Exception: {e}") + finally: + if infer_task.finish_reason is None: + infer_task.finish_reason = "cancel" + + +async def chat(id_, request_data, request: Request): + try: + infer_task = build_task(id_, request_data, request) + await request.app.state.kv_cache_pool.acquire(infer_task) + request.app.state.request_queue.sync_q.put(infer_task) + output = [] + while True: + if ( + infer_task.finish_reason is not None + and infer_task.output_queue.async_q.empty() + ): + break + + token = await infer_task.output_queue.async_q.get() + content = ( + request.app.state.model.tokenizer._tokenizer.id_to_token(token) + .replace("▁", " ") + .replace("<0x0A>", "\n") + ) + output.append(content) + + output_text = "".join(output).strip() + response = chunk_json( + id_, + content=output_text, + role="assistant", + finish_reason=infer_task.finish_reason or "stop", + ) + return response + + except Exception as e: + print(f"[Error] ID: {id_} Exception: {e}") + return JSONResponse(content={"error": str(e)}, status_code=500) + finally: + if infer_task.finish_reason is None: + infer_task.finish_reason = "cancel" + + +@App.post("/chat/completions") +async def chat_completions(request: Request): + data = await 
request.json() + + if not data.get("messages"): + return JSONResponse(content={"error": "No message provided"}, status_code=400) + + stream = data.get("stream", False) + id_ = f"cmpl-{uuid.uuid4().hex}" + if stream: + return StreamingResponse( + chat_stream(id_, data, request), media_type="text/event-stream" + ) + else: + response = await chat(id_, data, request) + return JSONResponse(content=response) + +if __name__ == "__main__": + uvicorn.run(App, host="0.0.0.0", port=8000) + +""" +curl -N -H "Content-Type: application/json" \ + -X POST http://127.0.0.1:8000/chat/completions \ + -d '{ + "model": "jiuge", + "messages": [ + {"role": "user", "content": "山东最高的山是?"} + ], + "temperature": 1.0, + "top_k": 50, + "top_p": 0.8, + "max_tokens": 512, + "stream": true + }' +""" diff --git a/python/icinfer/bench/test_jiuge.py b/python/icinfer/bench/test_jiuge.py new file mode 100644 index 00000000..e701b78a --- /dev/null +++ b/python/icinfer/bench/test_jiuge.py @@ -0,0 +1,57 @@ +import sys +import logging +import argparse +import os +os.environ["CUDA_VISIBLE_DEVICES"] = "4,5,6,7" + +from icinfer.engine.libinfinicore_infer import DeviceType +from icinfer.models.jiuge import JiugeForCausalLM +logger = logging.getLogger(__name__) + + + +def parse_args(): + parser = argparse.ArgumentParser() + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/Llama-2-7b-chat-hf") + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/") + parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + parser.add_argument("--device-type", type=str, default="nvidia") + parser.add_argument("--ndev", type=int, default=4) + args = parser.parse_args() + return args + +def test(): + args = parse_args() + model_path = args.model_path + device_type = DeviceType.DEVICE_TYPE_CPU + if args.device_type == "cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif args.device_type == "nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif args.device_type == "cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif args.device_type == "ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif args.device_type == "metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif args.device_type == "moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif args.device_type == "iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + logger.info( + # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + ) + sys.exit(1) + + ndev = args.ndev + model = JiugeForCausalLM(model_path, device_type, ndev) + # model.generate(["山东最高的山是?", "中国面积最大的省是?"], 500) + # model.generate(["山东最高的山是?"], 500) + model.generate("山东最高的山是?", 500) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() diff --git a/python/icinfer/bench/test_perf.py b/python/icinfer/bench/test_perf.py new file mode 100644 index 00000000..a6b26f3b --- /dev/null +++ b/python/icinfer/bench/test_perf.py @@ -0,0 +1,155 @@ +import asyncio +import time +from openai import AsyncOpenAI +import argparse +import random + + +PROMPTS = [ + "如果猫能写诗,它们会写些什么?", + "描述一个没有重力的世界。", + "如果地球停止自转,会发生什么?", + "假设你是一只会飞的鲸鱼,描述你的日常生活。", + "如果人类可以与植物沟通,世界会变成什么样?", + "描述一个由糖果构成的城市。", + "如果时间旅行成为可能,你最想去哪个时代?", + "想象一下,如果地球上只有蓝色,其他颜色都消失了。", + "如果动物能上网,它们会浏览什么网站?", + "描述一个没有声音的世界。", + 
"如果人类可以在水下呼吸,城市会如何变化?", + "想象一下,如果天空是绿色的,云是紫色的。", + "如果你能与任何历史人物共进晚餐,你会选择谁?", + "描述一个没有夜晚的星球。", + "如果地球上只有一种语言,世界会如何运作?", + "想象一下,如果所有的书都变成了音乐。", + "如果你可以变成任何一种动物,你会选择什么?", + "描述一个由机器人统治的未来世界。", + "如果你能与任何虚构角色成为朋友,你会选择谁?", + "想象一下,如果每个人都能读懂他人的思想。" +] + +NUM_REQUESTS = 10 +CONCURRENCY = 5 +API_URL = "http://127.0.0.1:8000" +MODEL = "FM9G-7B" + + +async def benchmark_user(client, semaphore, queue, results, user_id, verbose): + while True: + async with semaphore: + task_id = await queue.get() + if task_id is None: + queue.task_done() + break + + question = random.choice(PROMPTS) + try: + print(f"🚀 User#{user_id} Sending request #{task_id}") + + start_time = time.time() + stream = await client.chat.completions.create( + model=MODEL, + messages=[{"role": "user", "content": question}], + stream=True + ) + + first_token_time = None + total_tokens = 0 + answer_chunks = [] + + async for chunk in stream: + if first_token_time is None: + first_token_time = time.time() + delta = chunk.choices[0].delta.content + if delta: + answer_chunks.append(delta) + total_tokens += 1 + if chunk.choices[0].finish_reason is not None: + break + + end_time = time.time() + + ttft = first_token_time - start_time if first_token_time else None + elapsed_time = end_time - start_time if start_time else None + ms_per_token = (elapsed_time / total_tokens * 1000) if total_tokens > 0 and elapsed_time else None + tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0 + + answer = "".join(answer_chunks) + + results.append((total_tokens, elapsed_time, tokens_per_second, ttft, ms_per_token)) + + if verbose: + print(f"\n📝 Request #{task_id} (User #{user_id})") + print(f" ⏱ 首字延迟 TTFT: {ttft:.3f}s") + print(f" ⏱ 总耗时: {elapsed_time:.3f}s") + print(f" 🔤 解码 token 总数: {total_tokens}") + print(f" 📏 平均 token 解码时间: {ms_per_token:.2f} ms/token") + print(f" ❓ 提问: {question}") + print(f" 💬 回答: {answer}\n") + + queue.task_done() + except Exception as e: + if verbose: + print(f"\n⚠️ Request #{task_id} (User #{user_id}) FAILED:") + print(f" ❌ Error: {e}\n") + +async def run_benchmark(verbose=False): + client = AsyncOpenAI(base_url=API_URL, api_key="default") + semaphore = asyncio.Semaphore(CONCURRENCY) + queue = asyncio.Queue() + results = [] + for i in range(NUM_REQUESTS): + await queue.put(i) + for _ in range(CONCURRENCY): + await queue.put(None) + + users = [ + asyncio.create_task(benchmark_user(client, semaphore, queue, results, user_id, verbose)) + for user_id in range(CONCURRENCY) + ] + + start_time = time.time() + await queue.join() + await asyncio.gather(*users) + end_time = time.time() + + total_elapsed_time = end_time - start_time + tokens_list = [r[0] for r in results if r and r[0] is not None] + latencies = [r[1] for r in results if r and r[1] is not None] + tokens_per_second_list = [r[2] for r in results if r and r[2] is not None] + ttft_list = [r[3] for r in results if r and r[3] is not None] + ms_per_token_list = [r[4] for r in results if r and r[4] is not None] + + successful_requests = len(results) + requests_per_second = successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + avg_latency = sum(latencies) / len(latencies) if latencies else 0 + avg_tokens_per_second = sum(tokens_per_second_list) / len(tokens_per_second_list) if tokens_per_second_list else 0 + avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0 + avg_ms_per_token = sum(ms_per_token_list) / len(ms_per_token_list) if ms_per_token_list else None + + width_label = 24 + sep = "-" * 60 + + print(f"\n=== 📊 性能指标汇总 ({MODEL}) 
===") + print(sep) + print(f"{'并发数':<{width_label}}: {CONCURRENCY}") + print(f"{'请求总数':<{width_label}}: {NUM_REQUESTS}") + print(f"{'成功请求数':<{width_label}}: {successful_requests}") + print(f"{'总耗时':<{width_label}}: {total_elapsed_time:.2f} s") + print(f"{'总输出token数':<{width_label}}: {sum(tokens_list)}") + print(f"{'请求速率 (RPS)':<{width_label}}: {requests_per_second:.2f} requests/s") + print(sep) + print(f"{'Average latency':<{width_label}}: {avg_latency:.2f} s") + print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s") + print(f"{'Avg time per token':<{width_label}}: {avg_ms_per_token:.2f} ms/token") + print(f"{'Avg Token generation speed':<{width_label}}: {avg_tokens_per_second:.2f} tokens/s") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--verbose", action="store_true") + args = parser.parse_args() + + asyncio.run(run_benchmark( + args.verbose + )) diff --git a/python/icinfer/bench/test_ppl.py b/python/icinfer/bench/test_ppl.py new file mode 100644 index 00000000..268a9f7d --- /dev/null +++ b/python/icinfer/bench/test_ppl.py @@ -0,0 +1,62 @@ +import math +import requests +from datasets import load_dataset +from tqdm import tqdm +from transformers import AutoTokenizer + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser() + parser.add_argument("--model-path", type=str, required=True) + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--endpoint", type=str, default="/completions") + parser.add_argument("--chunk", type=int, default=512) + args = parser.parse_args() + + API_URL = "http://localhost:" + str(args.port) + args.endpoint + CHUNK_SIZE = args.chunk + + dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + # Local tokenizer used for chunking + tokenizer = AutoTokenizer.from_pretrained(args.model_path) + + total_neg_log_likelihood = 0.0 + total_tokens = 0 + + for example in tqdm(dataset, desc="Evaluating PPL"): + text = example["text"].strip() + if not text: + continue + + # endcode, chunk and decode + tokens = tokenizer.encode(text, add_special_tokens=False) + for i in range(0, len(tokens), CHUNK_SIZE): + chunk_tokens = tokens[i : min(i + CHUNK_SIZE, len(tokens))] + chunk_text = tokenizer.decode(chunk_tokens) + + resp = requests.post( + API_URL, + headers={"Content-Type": "application/json"}, + json={ + "model": "", + "prompt": chunk_text, + "max_tokens": 0, + "temperature": 1.0, + "echo": True, + "logprobs": 0, + }, + ).json() + + logprobs = resp["choices"][0]["logprobs"]["token_logprobs"] + # skip first token's None + valid_logprobs = [lp for lp in logprobs[1:] if lp is not None] + + total_neg_log_likelihood += -sum(valid_logprobs) + total_tokens += len(valid_logprobs) + + # ==== Compute final PPL ==== + ppl = math.exp(total_neg_log_likelihood / total_tokens) + print(f"Perplexity: {ppl:.4f}") diff --git a/python/icinfer/config.py b/python/icinfer/config.py new file mode 100644 index 00000000..5fe498ab --- /dev/null +++ b/python/icinfer/config.py @@ -0,0 +1,43 @@ +import os +from dataclasses import dataclass +from transformers import AutoConfig + + +@dataclass +class Config: + model: str + max_num_batched_tokens: int = 16384 + max_num_seqs: int = 512 + max_model_len: int = 1024 + gpu_memory_utilization: float = 0.9 + tensor_parallel_size: int = 1 + enforce_eager: bool = False + hf_config: AutoConfig | None = None + eos: int = -1 + kvcache_block_size: int = 16 + max_kvcache_tokens: int = -1 + num_kvcache_blocks: int = -1 + trust_remote_code: bool = False 
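+    # Paged-attention fields below: attention_bias is forwarded into hf_config by
+    # check_hf_config() when the model config does not define it, and num_kvcache_blocks
+    # is derived in __post_init__ as max_kvcache_tokens // kvcache_block_size when unset.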
+ attention_bias: bool = False + enable_paged_attn: bool = False + + def __post_init__(self): + assert os.path.isdir(self.model) + assert self.kvcache_block_size % 4 == 0 + assert 1 <= self.tensor_parallel_size <= 8 + self.model_path = self.model + self.hf_config = AutoConfig.from_pretrained(self.model, trust_remote_code=self.trust_remote_code) + print(self.model_path) + self.check_hf_config() + self.max_model_len = min(self.max_model_len, self.hf_config.max_position_embeddings) + if self.num_kvcache_blocks < 0 and self.max_kvcache_tokens > 0: + self.num_kvcache_blocks = self.max_kvcache_tokens // self.kvcache_block_size + assert self.max_num_batched_tokens >= self.max_model_len + + def check_hf_config(self): + if getattr(self.hf_config, "head_dim", None) is None: + self.hf_config.head_dim = self.hf_config.hidden_size // self.hf_config.num_attention_heads + if getattr(self.hf_config, "attention_bias", None) is None: + self.hf_config.attention_bias = self.attention_bias + if getattr(self.hf_config, "kvcache_block_size", None) is None: + self.hf_config.kvcache_block_size = self.kvcache_block_size diff --git a/python/icinfer/engine/block_manager.py b/python/icinfer/engine/block_manager.py new file mode 100644 index 00000000..e5fda4e0 --- /dev/null +++ b/python/icinfer/engine/block_manager.py @@ -0,0 +1,114 @@ +from collections import deque +import xxhash +import numpy as np + +from icinfer.engine.sequence import Sequence + + +class Block: + + def __init__(self, block_id): + self.block_id = block_id + self.ref_count = 0 + self.hash = -1 + self.token_ids = [] + + def update(self, hash: int, token_ids: list[int]): + self.hash = hash + self.token_ids = token_ids + + def reset(self): + self.ref_count = 1 + self.hash = -1 + self.token_ids = [] + + +class BlockManager: + + def __init__(self, num_blocks: int, block_size: int): + assert num_blocks > 0 + self.block_size = block_size + self.blocks: list[Block] = [Block(i) for i in range(num_blocks)] + self.hash_to_block_id: dict[int, int] = dict() + self.free_block_ids: deque[int] = deque(range(num_blocks)) + self.used_block_ids: set[int] = set() + + @classmethod + def compute_hash(cls, token_ids: list[int], prefix: int = -1): + h = xxhash.xxh64() + if prefix != -1: + h.update(prefix.to_bytes(8, "little")) + h.update(np.array(token_ids).tobytes()) + return h.intdigest() + + def _allocate_block(self, block_id: int) -> Block: + block = self.blocks[block_id] + assert block.ref_count == 0 + block.reset() + self.free_block_ids.remove(block_id) + self.used_block_ids.add(block_id) + return self.blocks[block_id] + + def _deallocate_block(self, block_id: int) -> Block: + assert self.blocks[block_id].ref_count == 0 + self.used_block_ids.remove(block_id) + self.free_block_ids.append(block_id) + + def can_allocate(self, seq: Sequence) -> bool: + return len(self.free_block_ids) >= seq.num_blocks + + def allocate(self, seq: Sequence): + # TODO 对于这个机制还有点疑惑。 for i in range(seq.num_blocks): + assert not seq.block_table + h = -1 + cache_miss = False + for i in range(seq.num_blocks): + token_ids = seq.block(i) + h = self.compute_hash(token_ids, h) if len(token_ids) == self.block_size else -1 + block_id = self.hash_to_block_id.get(h, -1) + if block_id == -1 or self.blocks[block_id].token_ids != token_ids: + cache_miss = True + if cache_miss: + block_id = self.free_block_ids[0] + block = self._allocate_block(block_id) + else: + seq.num_cached_tokens += self.block_size + if block_id in self.used_block_ids: + block = self.blocks[block_id] + block.ref_count += 1 + else: + 
block = self._allocate_block(block_id) + if h != -1: + block.update(h, token_ids) + self.hash_to_block_id[h] = block_id + seq.block_table.append(block_id) + + def deallocate(self, seq: Sequence): + for block_id in reversed(seq.block_table): + block = self.blocks[block_id] + block.ref_count -= 1 + if block.ref_count == 0: + self._deallocate_block(block_id) + seq.num_cached_tokens = 0 + seq.block_table.clear() + + def can_append(self, seq: Sequence) -> bool: + return len(self.free_block_ids) >= (len(seq) % self.block_size == 1) + + def may_append(self, seq: Sequence): + block_table = seq.block_table + last_block = self.blocks[block_table[-1]] + if len(seq) % self.block_size == 1: + assert last_block.hash != -1 + block_id = self.free_block_ids[0] + self._allocate_block(block_id) + block_table.append(block_id) + elif len(seq) % self.block_size == 0: + assert last_block.hash == -1 + token_ids = seq.block(seq.num_blocks-1) + prefix = self.blocks[block_table[-2]].hash if len(block_table) > 1 else -1 + h = self.compute_hash(token_ids, prefix) + last_block.update(h, token_ids) + self.hash_to_block_id[h] = last_block.block_id + else: + assert last_block.hash == -1 diff --git a/python/icinfer/engine/infer_task.py b/python/icinfer/engine/infer_task.py new file mode 100644 index 00000000..f8987352 --- /dev/null +++ b/python/icinfer/engine/infer_task.py @@ -0,0 +1,195 @@ +from typing import List +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +from icinfer.models.libinfinicore_infer.base import ( + KVCacheCStruct, +) + + + + +class InferTask: + def __init__(self, id, tokens, max_tokens, temperature, topk, topp, end_tokens): + self.id = id + self.finish_reason = None + self.tokens = tokens + self.max_tokens = max_tokens + self.temperature = temperature + self.topk = topk + self.topp = topp + self.end_tokens = end_tokens + self._kv_cache = None + self.pos = 0 + + def bind_kvcache(self, kv_cache, pos=0): + self._kv_cache = kv_cache + self.pos = pos + self.tokens = self.tokens[pos:] + + def release_kvcache(self): + cache = self._kv_cache + self._kv_cache = None + return cache + + def kvcache(self): + return self._kv_cache + + def next(self, out_token): + if self._kv_cache is not None: + self._kv_cache.update_tokens(self.tokens, self.pos) + + self.pos += len(self.tokens) + if out_token == None or out_token in self.end_tokens: + self.finish_reason = "stop" + elif self.pos >= self.max_tokens: + self.finish_reason = "length" + else: + self.tokens = [out_token] + + +class InferBatchedTask: + def __init__(self, tasks: List[InferTask], is_prefill: int=1): + self.tasks = tasks + self.nreq = len(tasks) + self.is_prefill = is_prefill + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) + self.block_tables = POINTER(c_int)() + self.slot_mapping = POINTER(c_int)() + 
self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.block_tables, + self.slot_mapping, + self.temperaturas, + self.topks, + self.topps, + self.is_prefill, + ) + + +class InferPagedBatchedTask: + def __init__(self, tasks: List[InferTask], batch_block_tables: list[int]=[], slot_mapping: list[int]=[], paged_kvcache=None, is_prefill: int=1): + self.tasks = tasks + self.nreq = len(tasks) + self.is_prefill = is_prefill + self.batch_block_tables = batch_block_tables + self.slot_mapping = slot_mapping + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [paged_kvcache.data()] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + self.n_blocks = len(batch_block_tables) # self.nreq * max_block_table_lens + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * 1)(*self.kv_cache_ptrs) + self.block_tables = (c_int * self.n_blocks)(*batch_block_tables) + self.slot_mapping = (c_int * self.ntok)(*slot_mapping) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.block_tables, + self.slot_mapping, + self.temperaturas, + self.topks, + self.topps, + self.is_prefill, + ) + + def input_args_for_logits(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.block_tables, + self.slot_mapping, + self.is_prefill, + ) + + + +class KVCache: + def __init__(self, model): + self._kvcache = model.create_kv_cache() + self.tokens = [0 for _ in range(model.max_context_len())] + + def data(self): + return self._kvcache + + def drop(self, model): + model.drop_kv_cache(self._kvcache) + + def update_tokens(self, tokens, pos): + end = pos + len(tokens) + max_len = len(self.tokens) + + # If overflow, truncate tokens to fit + if end > max_len: + tokens = tokens[: max_len - pos] + end = max_len + + self.tokens[pos:end] = tokens + +class PagedKVCache: + def __init__(self, paged_kvcache): + self._kvcache = paged_kvcache + # self.tokens = [0 for _ in range(model.max_context_len())] + + def data(self): + return self._kvcache + + def drop(self, model): + model.drop_kv_cache(self._kvcache) + + def update_tokens(self, tokens, pos): + print("PagedKVCache need not to update tokens.") + pass diff --git a/python/icinfer/engine/kvcache_pool.py b/python/icinfer/engine/kvcache_pool.py new file mode 100644 index 00000000..b48d2695 --- /dev/null +++ b/python/icinfer/engine/kvcache_pool.py @@ -0,0 +1,90 @@ +from icinfer.engine.infer_task import KVCache + +import asyncio +from typing import List +import threading + + +class 
KVCachePool: + def __init__(self, model, max_caches: int = 32): + self.max_caches = max_caches + self.model = model + self._available: List[KVCache] = [] + self.num_caches = len(self._available) + self._lock = threading.Lock() + self._not_empty = threading.Condition(self._lock) + self._shutdown = False + + def acquire_sync(self, infer_task): + with self._not_empty: + while True: + if self._shutdown: + raise RuntimeError( + "KVCachePool is shutting down; cannot acquire new cache." + ) + if len(self._available) == 0: + if self.num_caches < self.max_caches: + self.num_caches += 1 + print( + f"[INFO] Task {infer_task.id} created new KVCachePoolItem" + ) + return infer_task.bind_kvcache(KVCache(self.model), 0) + else: + self._not_empty.wait() + else: + max_match, max_match_index = self.find_most_matching_cache( + infer_task.tokens + ) + kvcache = self._available.pop(max_match_index) + print( + f"[INFO] Task {infer_task.id} reused KVCachePoolItem {max_match_index} with {max_match} matches" + ) + return infer_task.bind_kvcache(kvcache, max_match) + + def release_sync(self, infer_task): + with self._not_empty: + print(f"[INFO] Task {infer_task.id} returned KVCachePoolItem to pool") + self._available.append(infer_task.release_kvcache()) + self._not_empty.notify() + + async def acquire(self, infer_task): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.acquire_sync, infer_task) + + async def release(self, infer_task): + loop = asyncio.get_event_loop() + return await loop.run_in_executor(None, self.release_sync, infer_task) + + def find_most_matching_cache(self, tokens: List[int]): + max_match = 0 + max_match_index = 0 + + def first_different_index(a_, b_): + for i_, (x_, y_) in enumerate(zip(a_, b_)): + if x_ != y_: + return i_ + return min(len(a_), len(b_)) + + for i, kvcache in enumerate(self._available): + common_elements = first_different_index(tokens, kvcache.tokens) + # print(f"{tokens}") + # print(f"{kvcache.tokens[:len(tokens)]}") + if common_elements > max_match: + max_match = common_elements + max_match_index = i + + return (min(max_match, len(tokens) - 1), max_match_index) + + def finalize(self): + with self._not_empty: + self._shutdown = True + while len(self._available) < self.num_caches: + self._not_empty.wait() + + for kvcache in self._available: + if kvcache is not None: + kvcache.drop(self.model) + + self._available.clear() + self.max_caches = 0 + self.num_caches = 0 diff --git a/python/icinfer/engine/llm_engine.py b/python/icinfer/engine/llm_engine.py new file mode 100644 index 00000000..359a5594 --- /dev/null +++ b/python/icinfer/engine/llm_engine.py @@ -0,0 +1,196 @@ +import atexit +from dataclasses import fields +from time import perf_counter +from tqdm.auto import tqdm +from transformers import AutoTokenizer +import torch.multiprocessing as mp +import math +from typing import List +import uuid + +from icinfer.config import Config +from icinfer.sampling_params import SamplingParams +from icinfer.engine.sequence import Sequence +from icinfer.engine.scheduler import Scheduler +from icinfer.engine.model_runner import ModelRunner +from icinfer.engine.infer_task import KVCache, InferTask + +import logging +logger = logging.getLogger(__name__) + + +class InfiniEngine: + + def __init__(self, model, device, **kwargs): + config_fields = {field.name for field in fields(Config)} + config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} + config = Config(model, **config_kwargs) + + self.ps = [] + self.events = [] + # ctx = 
mp.get_context("spawn") + # for i in range(1, config.tensor_parallel_size): + # event = ctx.Event() + # process = ctx.Process(target=ModelRunner, args=(config, i, event)) + # process.start() + # self.ps.append(process) + # self.events.append(event) + self.model_runner = ModelRunner(config, device, 0, self.events) + self.eos_token_id = self.model_runner.eos_token_id + self.max_context_len = self.model_runner.max_context_len() + self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=kwargs["trust_remote_code"]) + config.eos = self.tokenizer.eos_token_id + self.scheduler = Scheduler(config) + atexit.register(self.exit) + + def exit(self): + self.model_runner.call("exit") + del self.model_runner + for p in self.ps: + p.join() + + def add_request(self, prompt: str | list[int], sampling_params: SamplingParams): + if isinstance(prompt, str): + prompt = self.tokenizer.encode(prompt) + seq = Sequence(prompt, sampling_params, block_size=self.scheduler.block_size) + infer_task = InferTask(seq.seq_id, prompt, self.max_context_len, sampling_params.temperature, sampling_params.topk, sampling_params.topp, self.eos_token_id) + if self.model_runner.enable_paged_attn: + pass + else: + infer_task.bind_kvcache(KVCache(self.model_runner)) + seq.bind_infer_task(infer_task) + self.scheduler.add(seq) + return prompt + + def step(self): + seqs, is_prefill = self.scheduler.schedule() + token_ids = self.model_runner.call("run", seqs, is_prefill) + drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids) + if self.model_runner.enable_paged_attn: + pass + else: + for kv_cache in drop_kvcache_list: + kv_cache.drop(self.model_runner) + outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] + num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) + return outputs, num_tokens + + def is_finished(self): + return self.scheduler.is_finished() + + def generate( + self, + prompts: list[str] | list[list[int]], + sampling_params: SamplingParams | list[SamplingParams], + use_tqdm: bool = True, + ) -> list[str]: + if use_tqdm: + pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) + if not isinstance(sampling_params, list): + sampling_params = [sampling_params] * len(prompts) + prompts_list = [] + for prompt, sp in zip(prompts, sampling_params): + prompts_list.append(self.add_request(prompt, sp)) + outputs = {} + prefill_throughput = decode_throughput = 0. 
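# --- Illustrative sketch, not part of the patch: the metric convention used in
# --- the loop below. step() returns num_tokens > 0 for a prefill step (number
# --- of prompt tokens processed) and num_tokens = -len(seqs) for a decode step
# --- (one new token per sequence), so the sign selects which throughput and
# --- latency accumulators to update. `toy_update_metrics` is hypothetical.
def toy_update_metrics(num_tokens, elapsed, state):
    """Update per-phase (tokens, seconds) accumulators; return this step's tok/s."""
    phase = "prefill" if num_tokens > 0 else "decode"
    produced = abs(num_tokens)  # decode steps report -len(seqs)
    toks, secs = state.get(phase, (0, 0.0))
    state[phase] = (toks + produced, secs + elapsed)
    return produced / elapsed

state = {}
print(toy_update_metrics(128, 0.25, state))  # prefill step: 512.0 tok/s
print(toy_update_metrics(-4, 0.02, state))   # decode step: 4 seqs -> 200.0 tok/s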
+ logger.info("start generating") + # perfile + avg_prefill_throughput = 0 + prefill_time = 0 + avg_decode_throughput = 0 + decode_time = 0 + ttft = 0 + ttft_count = 0 + tbt = 0 + tbt_count = 0 + + while not self.is_finished(): + t = perf_counter() + output, num_tokens = self.step() + if use_tqdm: + if num_tokens > 0: + check_time = perf_counter() + prefill_throughput = num_tokens / (check_time - t) + ttft += (check_time - t) + ttft_count += 1 + avg_prefill_throughput = (avg_prefill_throughput * prefill_time + num_tokens)/(prefill_time+(check_time - t)) + prefill_time += (check_time - t) + else: + check_time = perf_counter() + decode_throughput = -num_tokens / (check_time - t) + tbt += (check_time - t) + tbt_count += 1 + avg_decode_throughput = (avg_decode_throughput * decode_time - num_tokens)/(decode_time+(check_time - t)) + decode_time += (check_time - t) + pbar.set_postfix({ + "Prefill": f"{int(prefill_throughput)}tok/s", + "Decode": f"{int(decode_throughput)}tok/s", + }) + for seq_id, token_ids in output: + outputs[seq_id] = token_ids + if use_tqdm: + pbar.update(1) + outputs = [outputs[seq_id] for seq_id in sorted(outputs)] + outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs] + avg_ttft = ttft / ttft_count + avg_tbt = tbt / tbt_count + if not self.model_runner.enable_paged_attn: + max_model_len = self.model_runner.config.max_model_len + num_seqs = len(outputs) + used_tokens = [len(prompts_list[i])+len(outputs[i]) for i in range(num_seqs)] + used_tokens_count = sum(used_tokens) + cache_efficiency = used_tokens_count / (num_seqs * max_model_len) + else: + max_model_len = self.model_runner.config.max_model_len + num_seqs = len(outputs) + used_tokens = [len(prompts_list[i])+len(outputs[i]) for i in range(num_seqs)] + block_size = self.model_runner.config.kvcache_block_size + cache_memory = [(i_tokens + block_size - 1) // block_size * block_size for i_tokens in used_tokens] + cache_efficiency = sum(used_tokens) / sum(cache_memory) + + if use_tqdm: + pbar.close() + self.model_runner.exit() + return outputs, avg_prefill_throughput, avg_decode_throughput, avg_ttft, avg_tbt, cache_efficiency + + def add_perplexity_request(self, prompt: str | list[int], sampling_params: SamplingParams): + if isinstance(prompt, str): + prompt = self.tokenizer.encode(prompt) + input_tokens = prompt[:-1] + true_tokens = prompt[1:] + seq = Sequence(input_tokens, sampling_params, block_size=self.scheduler.block_size) + infer_task = InferTask(seq.seq_id, input_tokens, self.max_context_len, 1.0, 1, 1.0, self.eos_token_id) + seq.true_tokens = true_tokens + if self.model_runner.enable_paged_attn: + pass + else: + infer_task.bind_kvcache(KVCache(self.model_runner)) + seq.bind_infer_task(infer_task) + self.scheduler.add(seq) + + def perplexity_step(self): + seqs, is_prefill = self.scheduler.schedule() + nll, total_len, token_ids_none = self.model_runner.call("run_for_logits", seqs, is_prefill) + drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids_none) + if self.model_runner.enable_paged_attn: + pass + else: + for kv_cache in drop_kvcache_list: + kv_cache.drop(self.model_runner) + return nll, total_len + # outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] + # num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) + # return outputs, num_tokens + + def perplexity(self, test_sequences: List[List[int]]): + nll = 0.0 + total_len = 0 + + for i in range(len(test_sequences)): + 
self.add_perplexity_request(test_sequences[i], SamplingParams(temperature=1.0, topk=1, topp=1.0, max_tokens=1)) + while not self.is_finished(): + nll_i, total_len_i = self.perplexity_step() + nll += nll_i + total_len += total_len_i + + return math.exp(nll / total_len) \ No newline at end of file diff --git a/python/icinfer/engine/llm_engine_async.py b/python/icinfer/engine/llm_engine_async.py new file mode 100644 index 00000000..0e0d6ed3 --- /dev/null +++ b/python/icinfer/engine/llm_engine_async.py @@ -0,0 +1,264 @@ +import atexit +from dataclasses import fields +from time import perf_counter +from tqdm.auto import tqdm +from transformers import AutoTokenizer +import torch.multiprocessing as mp +import math +from typing import List +import uuid +import threading +import queue +import asyncio +from typing import Dict +import time +import collections + +from icinfer.config import Config +from icinfer.sampling_params import SamplingParams +from icinfer.engine.sequence import Sequence +from icinfer.engine.scheduler import Scheduler +from icinfer.engine.model_runner import ModelRunner +from icinfer.engine.infer_task import KVCache, InferTask + +import logging +logger = logging.getLogger(__name__) + + +class InfiniEngineAsync: + + def __init__(self, model, device, **kwargs): + config_fields = {field.name for field in fields(Config)} + config_kwargs = {k: v for k, v in kwargs.items() if k in config_fields} + config = Config(model, **config_kwargs) + + self.ps = [] + self.events = [] + # ctx = mp.get_context("spawn") + # for i in range(1, config.tensor_parallel_size): + # event = ctx.Event() + # process = ctx.Process(target=ModelRunner, args=(config, i, event)) + # process.start() + # self.ps.append(process) + # self.events.append(event) + self.model_runner = ModelRunner(config, device, 0, self.events) + self.eos_token_id = self.model_runner.eos_token_id + self.max_context_len = self.model_runner.max_context_len() + self.request_queue = queue.Queue() + self.result_queues: Dict[str, asyncio.Queue] = {} + self.main_loop = None + + self.tokenizer = AutoTokenizer.from_pretrained(config.model, use_fast=True, trust_remote_code=kwargs["trust_remote_code"]) + config.eos = self.tokenizer.eos_token_id + self.scheduler = Scheduler(config) + atexit.register(self.exit) + + + def exit(self): + self.model_runner.call("exit") + del self.model_runner + for p in self.ps: + p.join() + + async def add_request(self, prompt: str | list[int], sampling_params: SamplingParams, request_id: str): + if self.main_loop is None: + self.main_loop = asyncio.get_running_loop() + + result_queue = asyncio.Queue() + self.result_queues[request_id] = result_queue + self.request_queue.put((prompt, sampling_params, request_id)) + + return result_queue + + def add_request_action(self, prompt: str | list[int], sp, req_id): + if isinstance(prompt, str): + prompt_tokens = self.tokenizer.encode(prompt) + else: + prompt_tokens = prompt + + seq = Sequence(prompt_tokens, sp, block_size=self.scheduler.block_size, req_id=req_id) + infer_task = InferTask(seq.req_id, prompt_tokens, self.max_context_len, sp.temperature, sp.topk, sp.topp, self.eos_token_id) + if self.model_runner.enable_paged_attn: + pass + else: + infer_task.bind_kvcache(KVCache(self.model_runner)) + seq.bind_infer_task(infer_task) + self.scheduler.add(seq) + + def step(self): + seqs, is_prefill = self.scheduler.schedule() + token_ids = self.model_runner.call("run", seqs, is_prefill) + drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids) + if 
self.model_runner.enable_paged_attn: + pass + else: + for kv_cache in drop_kvcache_list: + kv_cache.drop(self.model_runner) + outputs = [(seq.req_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] + num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) + return outputs, num_tokens + + def is_finished(self): + return self.scheduler.is_finished() + + def engine_loop(self): + while True: + # 1. 从队列中获取新请求并添加到调度器 + while not self.request_queue.empty(): + prompt, sp, req_id = self.request_queue.get() + + self.add_request_action(prompt, sp, req_id) + + if self.request_queue.empty(): + time.sleep(0.1) + continue + + + # 2. 执行一步推理 + if not self.scheduler.is_finished(): + seqs, is_prefill = self.scheduler.schedule() + print(f"seqs_len: {len(seqs)}") + + # token_ids 是一个列表,按进入顺序排列的 + token_ids = self.model_runner.call("run", seqs, is_prefill) + + for seq_order_i in range(len(seqs)): + seq = seqs[seq_order_i] + new_token = token_ids[seq_order_i] + result_queue = self.result_queues.get(seq.req_id) + if result_queue: + self.main_loop.call_soon_threadsafe(result_queue.put_nowait, new_token) + + drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids) + if self.model_runner.enable_paged_attn: + pass + else: + for kv_cache in drop_kvcache_list: + kv_cache.drop(self.model_runner) + + # 4. 处理完成的序列 + for seq in seqs: + if seq.is_finished: + result_queue = self.result_queues.get(seq.req_id) + if result_queue: + self.main_loop.call_soon_threadsafe(result_queue.put_nowait, None) + self.result_queues.pop(seq.req_id, None) + else: + time.sleep(0.01) + + # def generate( + # self, + # prompts: list[str] | list[list[int]], + # sampling_params: SamplingParams | list[SamplingParams], + # use_tqdm: bool = True, + # ) -> list[str]: + # if use_tqdm: + # pbar = tqdm(total=len(prompts), desc="Generating", dynamic_ncols=True) + # if not isinstance(sampling_params, list): + # sampling_params = [sampling_params] * len(prompts) + # prompts_list = [] + # for prompt, sp in zip(prompts, sampling_params): + # prompts_list.append(self.add_request(prompt, sp)) + # outputs = {} + # prefill_throughput = decode_throughput = 0. 
+ # logger.info("start generating") + # # perfile + # avg_prefill_throughput = 0 + # prefill_time = 0 + # avg_decode_throughput = 0 + # decode_time = 0 + # ttft = 0 + # ttft_count = 0 + # tbt = 0 + # tbt_count = 0 + + # while not self.is_finished(): + # t = perf_counter() + # output, num_tokens = self.step() + # if use_tqdm: + # if num_tokens > 0: + # check_time = perf_counter() + # prefill_throughput = num_tokens / (check_time - t) + # ttft += (check_time - t) + # ttft_count += 1 + # avg_prefill_throughput = (avg_prefill_throughput * prefill_time + num_tokens)/(prefill_time+(check_time - t)) + # prefill_time += (check_time - t) + # else: + # check_time = perf_counter() + # decode_throughput = -num_tokens / (check_time - t) + # tbt += (check_time - t) + # tbt_count += 1 + # avg_decode_throughput = (avg_decode_throughput * decode_time - num_tokens)/(decode_time+(check_time - t)) + # decode_time += (check_time - t) + # pbar.set_postfix({ + # "Prefill": f"{int(prefill_throughput)}tok/s", + # "Decode": f"{int(decode_throughput)}tok/s", + # }) + # for seq_id, token_ids in output: + # outputs[seq_id] = token_ids + # if use_tqdm: + # pbar.update(1) + # outputs = [outputs[seq_id] for seq_id in sorted(outputs)] + # outputs = [{"text": self.tokenizer.decode(token_ids), "token_ids": token_ids} for token_ids in outputs] + # avg_ttft = ttft / ttft_count + # avg_tbt = tbt / tbt_count + # if not self.model_runner.enable_paged_attn: + # max_model_len = self.model_runner.config.max_model_len + # num_seqs = len(outputs) + # used_tokens = [len(prompts_list[i])+len(outputs[i]) for i in range(num_seqs)] + # used_tokens_count = sum(used_tokens) + # cache_efficiency = used_tokens_count / (num_seqs * max_model_len) + # else: + # max_model_len = self.model_runner.config.max_model_len + # num_seqs = len(outputs) + # used_tokens = [len(prompts_list[i])+len(outputs[i]) for i in range(num_seqs)] + # block_size = self.model_runner.config.kvcache_block_size + # cache_memory = [(i_tokens + block_size - 1) // block_size * block_size for i_tokens in used_tokens] + # cache_efficiency = sum(used_tokens) / sum(cache_memory) + + # if use_tqdm: + # pbar.close() + # self.model_runner.exit() + # return outputs, avg_prefill_throughput, avg_decode_throughput, avg_ttft, avg_tbt, cache_efficiency + + # def add_perplexity_request(self, prompt: str | list[int], sampling_params: SamplingParams): + # if isinstance(prompt, str): + # prompt = self.tokenizer.encode(prompt) + # input_tokens = prompt[:-1] + # true_tokens = prompt[1:] + # seq = Sequence(input_tokens, sampling_params, block_size=self.scheduler.block_size) + # infer_task = InferTask(seq.seq_id, input_tokens, self.max_context_len, 1.0, 1, 1.0, self.eos_token_id) + # seq.true_tokens = true_tokens + # if self.model_runner.enable_paged_attn: + # pass + # else: + # infer_task.bind_kvcache(KVCache(self.model_runner)) + # seq.bind_infer_task(infer_task) + # self.scheduler.add(seq) + + # def perplexity_step(self): + # seqs, is_prefill = self.scheduler.schedule() + # nll, total_len, token_ids_none = self.model_runner.call("run_for_logits", seqs, is_prefill) + # drop_kvcache_list = self.scheduler.postprocess(seqs, token_ids_none) + # if self.model_runner.enable_paged_attn: + # pass + # else: + # for kv_cache in drop_kvcache_list: + # kv_cache.drop(self.model_runner) + # return nll, total_len + # # outputs = [(seq.seq_id, seq.completion_token_ids) for seq in seqs if seq.is_finished] + # # num_tokens = sum(len(seq) for seq in seqs) if is_prefill else -len(seqs) + # # return outputs, 
num_tokens + + # def perplexity(self, test_sequences: List[List[int]]): + # nll = 0.0 + # total_len = 0 + + # for i in range(len(test_sequences)): + # self.add_perplexity_request(test_sequences[i], SamplingParams(temperature=1.0, topk=1, topp=1.0, max_tokens=1)) + # while not self.is_finished(): + # nll_i, total_len_i = self.perplexity_step() + # nll += nll_i + # total_len += total_len_i + + # return math.exp(nll / total_len) \ No newline at end of file diff --git a/python/icinfer/engine/model_runner.py b/python/icinfer/engine/model_runner.py new file mode 100644 index 00000000..01d22670 --- /dev/null +++ b/python/icinfer/engine/model_runner.py @@ -0,0 +1,486 @@ +import pickle +import torch +import torch.distributed as dist +from multiprocessing.synchronize import Event +from multiprocessing.shared_memory import SharedMemory +from ctypes import c_uint +from typing import List +import logging +import itertools + + +from icinfer.config import Config +from icinfer.engine.sequence import Sequence +from icinfer.models.libinfinicore_infer.base import ( + DeviceType, +) +# from icinfer.engine.libinfinicore_infer import ( +# JiugeMetaCStruct, +# JiugeWeightsCStruct, +# KVCacheCStruct, +# DataType, +# DeviceType, +# create_jiuge_model, +# destroy_jiuge_model, +# create_kv_cache, +# create_paged_kv_cache, +# drop_kv_cache, +# infer_batch, +# forward_batch, +# ) + +from icinfer.models.auto_modeling import AutoModelForCausalLM + +from icinfer.layers.sampler import Sampler +from icinfer.utils.context import set_context, get_context, reset_context +# from icinfer.utils.loader import load_model +# from icinfer.utils.jiuge_weights_loader import load_model +from icinfer.engine.infer_task import InferTask, InferBatchedTask, InferPagedBatchedTask, PagedKVCache + + +# infinicore infer +from typing import List, Sequence +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import time +import math +from icinfer.engine.infer_task import InferTask, KVCache + + +logger = logging.getLogger(__name__) + + +class ModelRunner: + + def __init__(self, config: Config, device: DeviceType, rank: int, event: Event | list[Event]): + self.config = config + self.hf_config = config.hf_config + self.device = device + self.block_size = config.kvcache_block_size + self.enforce_eager = config.enforce_eager + self.enable_paged_attn = config.enable_paged_attn + self.world_size = config.tensor_parallel_size + self.meta = None + self.kv_cache = None + self.rank = rank + self.event = event + + # dist.init_process_group("nccl", "tcp://localhost:2333", world_size=self.world_size, rank=rank) + # torch.cuda.set_device(rank) + # default_dtype = torch.get_default_dtype() + # torch.set_default_dtype(hf_config.torch_dtype) + # torch.set_default_device("cuda") + + # model_set = { + # "qwen3": Qwen3ForCausalLM, + # "fm9g7b": FM9GForCausalLM, + # } + # ModelForCausalLm = model_set[hf_config.model_type] + # self.model = ModelForCausalLm(hf_config) + # load_model(self.model, config.model) + + self.model = AutoModelForCausalLM.from_config(self.config, device) + self.meta = self.model.meta + # self.tokenizer = transformers.AutoTokenizer.from_pretrained( + # model_dir_path, trust_remote_code=True + # ) + print("hello") + + eos_token_id = self.hf_config.eos_token_id + self.eos_token_id = ( + [eos_token_id] if type(eos_token_id) == int else eos_token_id + ) + + # self.sampler = Sampler() + # self.warmup_model() + # TODO 暂时先关掉 + if self.enable_paged_attn: + self.allocate_kv_cache() + if not self.enforce_eager: + self.capture_cudagraph() + 
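# --- Illustrative sketch, not part of the patch: sizing the paged cache. The
# --- commented-out _calculate_num_blocks below derives the block count from
# --- free GPU memory; this shows the same arithmetic for a fixed memory
# --- budget. All names and example numbers here are hypothetical.
def toy_num_kvcache_blocks(mem_budget_bytes, nlayers, block_size, nkvh, head_dim, dtype_size=2):
    # K and V caches (factor 2); each block holds block_size tokens per layer.
    block_bytes = 2 * nlayers * block_size * nkvh * head_dim * dtype_size
    return mem_budget_bytes // block_bytes

# e.g. 8 GiB budget, 32 layers, block_size 16, 8 KV heads, head_dim 128, fp16:
blocks = toy_num_kvcache_blocks(8 * 1024**3, 32, 16, 8, 128, 2)   # -> 4096 blocks
max_kvcache_tokens = blocks * 16  # value that would feed create_paged_kv_cache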
# torch.set_default_device("cpu") + # torch.set_default_dtype(default_dtype) + + # if self.world_size > 1: + # if rank == 0: + # self.shm = SharedMemory(name="nanovllm", create=True, size=2**20) + # dist.barrier() + # else: + # dist.barrier() + # self.shm = SharedMemory(name="nanovllm") + # self.loop() + + def exit(self): + # if self.world_size > 1: + # self.shm.close() + # dist.barrier() + # if self.rank == 0: + # self.shm.unlink() + if not self.enforce_eager: + del self.graphs, self.graph_pool + # torch.cuda.synchronize() + self.destroy() + # dist.destroy_process_group() + + def __del__(self): + self.destroy() + + def destroy(self): + """ + 在程序退出时,安全地释放 C++ 侧的资源。 + """ + if hasattr(self, 'kv_cache') and self.kv_cache: + print("drop_kv_cache") + self.model.drop_kv_cache(self.kv_cache.data()) + self.kv_cache = None + if hasattr(self, 'model') and self.model: + self.model.destroy_model_instance() + self.model = None + + logger.info("ModelRunner model resources have been released.") + + # def loop(self): + # while True: + # method_name, args = self.read_shm() + # self.call(method_name, *args) + # if method_name == "exit": + # break + + # def read_shm(self): + # assert self.world_size > 1 and self.rank + # self.event.wait() + # n = int.from_bytes(self.shm.buf[0:4], "little") + # method_name, *args = pickle.loads(self.shm.buf[4:n+4]) + # self.event.clear() + # return method_name, args + + # def write_shm(self, method_name, *args): + # assert self.world_size > 1 and not self.rank + # data = pickle.dumps([method_name, *args]) + # n = len(data) + # self.shm.buf[0:4] = n.to_bytes(4, "little") + # self.shm.buf[4:n+4] = data + # for event in self.event: + # event.set() + + def call(self, method_name, *args): + # if self.world_size > 1 and self.rank == 0: + # self.write_shm(method_name, *args) + method = getattr(self, method_name, None) + return method(*args) + + # def warmup_model(self): + # torch.cuda.empty_cache() + # torch.cuda.reset_peak_memory_stats() + # max_num_batched_tokens, max_model_len = self.config.max_num_batched_tokens, self.config.max_model_len + # num_seqs = min(max_num_batched_tokens // max_model_len, self.config.max_num_seqs) + # seqs = [Sequence([0] * max_model_len) for _ in range(num_seqs)] + # self.run(seqs, True) + # torch.cuda.empty_cache() + + # def _calculate_num_blocks(self, num_kv_heads: int) -> int: + # config = self.config + # hf_config = config.hf_config + # gpu_memory_utilization = config.gpu_memory_utilization + + # free, total = torch.cuda.mem_get_info() + # used = total - free + # # todo torch.cuda需要用一个什么来替代这部分 + # peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"] + # current = torch.cuda.memory_stats()["allocated_bytes.all.current"] + # block_bytes = 2 * hf_config.num_hidden_layers * self.block_size * num_kv_heads * hf_config.head_dim * hf_config.torch_dtype.itemsize + # num_kvcache_blocks = int(total * gpu_memory_utilization - used - peak + current) // block_bytes + # assert num_kvcache_blocks > 0 + # return num_kvcache_blocks + + def allocate_kv_cache(self): + kv_cache = self.model.create_paged_kv_cache(self.config.max_kvcache_tokens) + self.kv_cache = PagedKVCache(kv_cache) + print("kvcache allocated ") + # config = self.config + # hf_config = config.hf_config + # num_kv_heads = hf_config.num_key_value_heads // self.world_size + # config.num_kvcache_blocks = self._calculate_num_blocks(num_kv_heads) + # self.kv_cache = torch.zeros(2, hf_config.num_hidden_layers, config.num_kvcache_blocks, self.block_size, num_kv_heads, hf_config.head_dim) + # layer_id 
= 0 + # for module in self.model.modules(): + # if hasattr(module, "k_cache") and hasattr(module, "v_cache"): + # module.k_cache = self.kv_cache[0, layer_id] + # module.v_cache = self.kv_cache[1, layer_id] + # layer_id += 1 + + def prepare_block_tables(self, seqs: list[Sequence]): + max_len = max(len(seq.block_table) for seq in seqs) + padded_lists_generator = ( + (seq.block_table + [0] * (max_len - len(seq.block_table))) + for seq in seqs + ) + block_tables_flat = list(itertools.chain.from_iterable(padded_lists_generator)) + # block_tables = [seq.block_table + [-1] * (max_len - len(seq.block_table)) for seq in seqs] + # block_tables = torch.tensor(block_tables, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + return block_tables_flat + + def prepare_prefill(self, seqs: list[Sequence]): + input_ids = [] + positions = [] + cu_seqlens_q = [0] + cu_seqlens_k = [0] + max_seqlen_q = 0 + max_seqlen_k = 0 + slot_mapping = [] + block_tables = [] + for seq in seqs: + seqlen = len(seq) + input_ids.extend(seq[seq.num_cached_tokens:]) + positions.extend(list(range(seq.num_cached_tokens, seqlen))) + seqlen_q = seqlen - seq.num_cached_tokens + seqlen_k = seqlen + cu_seqlens_q.append(cu_seqlens_q[-1] + seqlen_q) + cu_seqlens_k.append(cu_seqlens_k[-1] + seqlen_k) + max_seqlen_q = max(seqlen_q, max_seqlen_q) + max_seqlen_k = max(seqlen_k, max_seqlen_k) + if not seq.block_table: + continue + for i in range(seq.num_cached_blocks, seq.num_blocks): + start = seq.block_table[i] * self.block_size + if i != seq.num_blocks - 1: + end = start + self.block_size + else: + end = start + seq.last_block_num_tokens + slot_mapping.extend(list(range(start, end))) + # if cu_seqlens_k[-1] > cu_seqlens_q[-1]: # prefix cache + block_tables = self.prepare_block_tables(seqs) + # input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + # positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + # cu_seqlens_q = torch.tensor(cu_seqlens_q, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + # cu_seqlens_k = torch.tensor(cu_seqlens_k, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + # slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + # set_context(True, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, None, block_tables) + # return input_ids, positions + return block_tables, slot_mapping + + def prepare_decode(self, seqs: list[Sequence]): + input_ids = [] + positions = [] + slot_mapping = [] + context_lens = [] + for seq in seqs: + input_ids.append(seq.last_token) + positions.append(len(seq)) + context_lens.append(len(seq)) + slot_mapping.append(seq.block_table[-1] * self.block_size + seq.last_block_num_tokens - 1) + # input_ids = torch.tensor(input_ids, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + # positions = torch.tensor(positions, dtype=torch.int64, pin_memory=True).cuda(non_blocking=True) + # slot_mapping = torch.tensor(slot_mapping, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + # context_lens = torch.tensor(context_lens, dtype=torch.int32, pin_memory=True).cuda(non_blocking=True) + block_tables = self.prepare_block_tables(seqs) + # set_context(False, slot_mapping=slot_mapping, context_lens=context_lens, block_tables=block_tables) + return block_tables, slot_mapping + + # def prepare_sample(self, seqs: list[Sequence]): + # temperatures = [] + # for seq in seqs: + # temperatures.append(seq.temperature) + # 
temperatures = torch.tensor(temperatures, dtype=torch.float32, pin_memory=True).cuda(non_blocking=True) + # return temperatures + + @torch.inference_mode() + def run_model(self, input_ids: torch.Tensor, positions: torch.Tensor, is_prefill: bool): + if is_prefill or self.enforce_eager or input_ids.size(0) > 512: + return self.model.compute_logits(self.model(input_ids, positions)) + else: + bs = input_ids.size(0) + context = get_context() + graph = self.graphs[next(x for x in self.graph_bs if x >= bs)] + graph_vars = self.graph_vars + for k, v in graph_vars.items(): + if k != "outputs": + v.zero_() + graph_vars["input_ids"][:bs] = input_ids + graph_vars["positions"][:bs] = positions + graph_vars["slot_mapping"][:bs] = context.slot_mapping + graph_vars["context_lens"][:bs] = context.context_lens + graph_vars["block_tables"][:bs, :context.block_tables.size(1)] = context.block_tables + graph.replay() + return self.model.compute_logits(graph_vars["outputs"][:bs]) + + + # def run(self, seqs: list[Sequence], is_prefill: bool) -> list[int]: + # input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + # temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + # logits = self.run_model(input_ids, positions, is_prefill) + # token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + # reset_context() + # return token_ids + + # @torch.inference_mode() + # def capture_cudagraph(self): + # config = self.config + # hf_config = config.hf_config + # max_bs = min(self.config.max_num_seqs, 512) + # max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + # input_ids = torch.zeros(max_bs, dtype=torch.int64) + # positions = torch.zeros(max_bs, dtype=torch.int64) + # slot_mapping = torch.zeros(max_bs, dtype=torch.int32) + # context_lens = torch.zeros(max_bs, dtype=torch.int32) + # block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32) + # outputs = torch.zeros(max_bs, hf_config.hidden_size) + # self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16)) + # self.graphs = {} + # self.graph_pool = None + + # for bs in reversed(self.graph_bs): + # graph = torch.cuda.CUDAGraph() + # set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs]) + # outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup + # with torch.cuda.graph(graph, self.graph_pool): + # outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture + # if self.graph_pool is None: + # self.graph_pool = graph.pool() + # self.graphs[bs] = graph + # torch.cuda.synchronize() + # reset_context() + + # self.graph_vars = dict( + # input_ids=input_ids, + # positions=positions, + # slot_mapping=slot_mapping, + # context_lens=context_lens, + # block_tables=block_tables, + # outputs=outputs, + # ) + + # @torch.inference_mode() + # def capture_cudagraph(self): + # config = self.config + # hf_config = config.hf_config + # max_bs = min(self.config.max_num_seqs, 512) + # max_num_blocks = (config.max_model_len + self.block_size - 1) // self.block_size + # input_ids = torch.zeros(max_bs, dtype=torch.int64) + # positions = torch.zeros(max_bs, dtype=torch.int64) + # slot_mapping = torch.zeros(max_bs, dtype=torch.int32) + # context_lens = torch.zeros(max_bs, dtype=torch.int32) + # block_tables = torch.zeros(max_bs, max_num_blocks, dtype=torch.int32) + # outputs = torch.zeros(max_bs, hf_config.hidden_size) + # self.graph_bs = [1, 2, 4, 8] + list(range(16, max_bs + 1, 16)) + # 
self.graphs = {} + # self.graph_pool = None + + # for bs in reversed(self.graph_bs): + # graph = torch.cuda.CUDAGraph() + # set_context(False, slot_mapping=slot_mapping[:bs], context_lens=context_lens[:bs], block_tables=block_tables[:bs]) + # outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # warmup + # with torch.cuda.graph(graph, self.graph_pool): + # outputs[:bs] = self.model(input_ids[:bs], positions[:bs]) # capture + # if self.graph_pool is None: + # self.graph_pool = graph.pool() + # self.graphs[bs] = graph + # torch.cuda.synchronize() + # reset_context() + + # self.graph_vars = dict( + # input_ids=input_ids, + # positions=positions, + # slot_mapping=slot_mapping, + # context_lens=context_lens, + # block_tables=block_tables, + # outputs=outputs, + # ) + + + # infinifore infer + def max_context_len(self): + return self.meta.dctx + # return self.config.max_model_len + + def create_kv_cache(self): + return self.model.create_kv_cache() + + def drop_kv_cache(self, kv_cache): + # drop_kv_cache(self.model, kv_cache) + self.model.drop_kv_cache(kv_cache) + + def create_paged_kv_cache(self, max_kvcache_tokens): + return self.model.create_paged_kv_cache(max_kvcache_tokens) + + # @torch.inference_mode() + # def batch_infer_one_round(self, tasks: List[InferTask]): + # output = (c_uint * len(tasks))() + # batch_inputs = InferBatchedTask(tasks) + # infer_batch( + # self.model, + # *(batch_inputs.input_args()), + # output, + # ) + # return list(output) + + def batch_infer_one_round(self, tasks: List[InferTask], is_prefill: int, batch_block_tables: list[int], slot_mapping: list[int]): + output = (c_uint * len(tasks))() + batch_inputs = None + if self.enable_paged_attn: + batch_inputs = InferPagedBatchedTask(tasks, batch_block_tables, slot_mapping, self.kv_cache, is_prefill) + else: + batch_inputs = InferBatchedTask(tasks, is_prefill) + self.model.infer_batch( + # self.model, + *(batch_inputs.input_args()), + self.enable_paged_attn, + output, + ) + return list(output) + + def run(self, seqs: list[Sequence], is_prefill: int) -> list[int]: + # input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + # temperatures = self.prepare_sample(seqs) if self.rank == 0 else None + # logits = self.run_model(input_ids, positions, is_prefill) + # token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + # reset_context() + batch_block_tables, slot_mapping = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + tasks = [seq.infer_task for seq in seqs] + token_ids = self.batch_infer_one_round(tasks, is_prefill, batch_block_tables, slot_mapping) + + return token_ids + + + def batch_infer_one_round_for_logits(self, tasks: List[InferTask], is_prefill: int, batch_block_tables: list[int], slot_mapping: list[int]): + batch_inputs = None + if self.enable_paged_attn: + batch_inputs = InferPagedBatchedTask(tasks, batch_block_tables, slot_mapping, self.kv_cache, is_prefill) + else: + batch_inputs = InferBatchedTask(tasks, is_prefill) + logits = torch.zeros((batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits) + self.model.forward_batch( + # self.model, + *(batch_inputs.input_args_for_logits()), + self.enable_paged_attn, + logits.data_ptr(), + ) + return logits, batch_inputs.req_lens_list, batch_inputs.ntok + + def run_for_logits(self, seqs: list[Sequence], is_prefill: int) -> torch.Tensor: + # input_ids, positions = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + # temperatures = 
self.prepare_sample(seqs) if self.rank == 0 else None + # logits = self.run_model(input_ids, positions, is_prefill) + # token_ids = self.sampler(logits, temperatures).tolist() if self.rank == 0 else None + # reset_context() + nll = 0.0 + total_len = 0 + batch_block_tables, slot_mapping = self.prepare_prefill(seqs) if is_prefill else self.prepare_decode(seqs) + tasks = [seq.infer_task for seq in seqs] + true_tokens = [seq.true_tokens for seq in seqs] + logits, req_lens_list, ntok = self.batch_infer_one_round_for_logits(tasks, is_prefill, batch_block_tables, slot_mapping) + token_ids_none = [None] * len(seqs) + + logits = logits.float() + token_ids = torch.tensor(true_tokens, dtype=torch.int64).reshape(-1) # [ntok,] + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) + token_logprobs = log_probs[ + torch.arange(ntok), token_ids + ] # (ntok,) + + start = 0 + for l in req_lens_list: + nll += -token_logprobs[start : start + l].sum().item() + start += l + total_len += token_logprobs.numel() + + return nll, total_len, token_ids_none diff --git a/python/icinfer/engine/scheduler.py b/python/icinfer/engine/scheduler.py new file mode 100644 index 00000000..6782792d --- /dev/null +++ b/python/icinfer/engine/scheduler.py @@ -0,0 +1,91 @@ +from collections import deque + +from icinfer.config import Config +from icinfer.engine.sequence import Sequence, SequenceStatus +from icinfer.engine.block_manager import BlockManager +from icinfer.engine.infer_task import KVCache + + +class Scheduler: + + def __init__(self, config: Config): + self.max_num_seqs = config.max_num_seqs + self.max_num_batched_tokens = config.max_num_batched_tokens + self.eos = config.eos + self.block_manager = BlockManager(config.num_kvcache_blocks, config.kvcache_block_size) + self.waiting: deque[Sequence] = deque() + self.running: deque[Sequence] = deque() + + + def is_finished(self): + return not self.waiting and not self.running + + def add(self, seq: Sequence): + self.waiting.append(seq) + + def schedule(self) -> tuple[list[Sequence], int]: + # prefill + scheduled_seqs = [] + num_seqs = 0 + num_batched_tokens = 0 + is_prefill = 0 + while self.waiting and num_seqs < self.max_num_seqs: + seq = self.waiting[0] + if num_batched_tokens + len(seq) > self.max_num_batched_tokens or not self.block_manager.can_allocate(seq): + break + num_seqs += 1 + self.block_manager.allocate(seq) + num_batched_tokens += len(seq) - seq.num_cached_tokens + seq.status = SequenceStatus.RUNNING + self.waiting.popleft() + self.running.append(seq) + scheduled_seqs.append(seq) + if scheduled_seqs: + is_prefill = 1 + return scheduled_seqs, is_prefill + + # decode + while self.running and num_seqs < self.max_num_seqs: + seq = self.running.popleft() + while not self.block_manager.can_append(seq): + if self.running: + self.preempt(self.running.pop()) + else: + self.preempt(seq) + break + else: + num_seqs += 1 + self.block_manager.may_append(seq) + scheduled_seqs.append(seq) + assert scheduled_seqs + self.running.extendleft(reversed(scheduled_seqs)) + # print(f"is_prefill: {is_prefill}, schedule over.\n") + return scheduled_seqs, is_prefill + + def preempt(self, seq: Sequence): + seq.status = SequenceStatus.WAITING + self.block_manager.deallocate(seq) + self.waiting.appendleft(seq) + + # def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[bool]: + # for seq, token_id in zip(seqs, token_ids): + # seq.append_token(token_id) + # if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens: + 
# seq.status = SequenceStatus.FINISHED + # self.block_manager.deallocate(seq) + # self.running.remove(seq) + + def postprocess(self, seqs: list[Sequence], token_ids: list[int]) -> list[KVCache]: + drop_kvcache_list = [] + for seq, token_id in zip(seqs, token_ids): + seq.append_token(token_id) + if (not seq.ignore_eos and token_id == self.eos) or seq.num_completion_tokens == seq.max_tokens: + seq.status = SequenceStatus.FINISHED + drop_kvcache_list.append(seq.infer_task.release_kvcache()) + self.block_manager.deallocate(seq) + self.running.remove(seq) + return drop_kvcache_list + + @property + def block_size(self): + return self.block_manager.block_size diff --git a/python/icinfer/engine/sequence.py b/python/icinfer/engine/sequence.py new file mode 100644 index 00000000..f27b5469 --- /dev/null +++ b/python/icinfer/engine/sequence.py @@ -0,0 +1,96 @@ +from copy import copy +from enum import Enum, auto +from itertools import count + +from icinfer.sampling_params import SamplingParams +from icinfer.engine.infer_task import InferTask + + +class SequenceStatus(Enum): + WAITING = auto() + RUNNING = auto() + FINISHED = auto() + + +class Sequence: + # block_size = 256 + counter = count() + + def __init__(self, token_ids: list[int], sampling_params = SamplingParams(), block_size = 256, req_id = None): + self.seq_id = next(Sequence.counter) + self.status = SequenceStatus.WAITING + self.token_ids = copy(token_ids) + self.last_token = token_ids[-1] + self.num_tokens = len(self.token_ids) + self.num_prompt_tokens = len(token_ids) + self.num_cached_tokens = 0 + self.block_size = block_size + # self.block_table = None + self.block_table = [] + self.infer_task = None + + # for online serving + self.req_id = req_id + + self.true_tokens = None # for perplexity + self.temperature = sampling_params.temperature + self.max_tokens = sampling_params.max_tokens + self.ignore_eos = sampling_params.ignore_eos + + def __len__(self): + return self.num_tokens + + def __getitem__(self, key): + return self.token_ids[key] + + @property + def is_finished(self): + return self.status == SequenceStatus.FINISHED + + @property + def num_completion_tokens(self): + return self.num_tokens - self.num_prompt_tokens + + @property + def prompt_token_ids(self): + return self.token_ids[:self.num_prompt_tokens] + + @property + def completion_token_ids(self): + return self.token_ids[self.num_prompt_tokens:] + + @property + def num_cached_blocks(self): + return self.num_cached_tokens // self.block_size + + @property + def num_blocks(self): + return (self.num_tokens + self.block_size - 1) // self.block_size + + @property + def last_block_num_tokens(self): + return self.num_tokens - (self.num_blocks - 1) * self.block_size + + def block(self, i): + assert 0 <= i < self.num_blocks + return self.token_ids[i*self.block_size: (i+1)*self.block_size] + + def append_token(self, token_id: int): + self.token_ids.append(token_id) + self.infer_task.next(token_id) + self.last_token = token_id + self.num_tokens += 1 + + def bind_infer_task(self, infer_task: InferTask): + self.infer_task = infer_task + + def __getstate__(self): + return (self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table, + self.token_ids if self.num_completion_tokens == 0 else self.last_token) + + def __setstate__(self, state): + self.num_tokens, self.num_prompt_tokens, self.num_cached_tokens, self.block_table = state[:-1] + if self.num_completion_tokens == 0: + self.token_ids = state[-1] + else: + self.last_token = state[-1] diff --git 
a/python/icinfer/layers/sampler.py b/python/icinfer/layers/sampler.py new file mode 100644 index 00000000..e4b9816e --- /dev/null +++ b/python/icinfer/layers/sampler.py @@ -0,0 +1,18 @@ +import torch +from torch import nn + + +class Sampler(nn.Module): + + def __init__(self): + super().__init__() + + def forward(self, logits: torch.Tensor, temperatures: torch.Tensor): + logits = logits.to(torch.float) + greedy_tokens = logits.argmax(dim=-1) + logits.div_(temperatures.unsqueeze(dim=1)) + probs = torch.softmax(logits, dim=-1, dtype=torch.float) + # logprobs = torch.log_softmax(logits, dim=-1, dtype=torch.float) + epsilon = 1e-10 + sample_tokens = probs.div_(torch.empty_like(probs).exponential_(1) + epsilon).argmax(dim=-1) + return torch.where(temperatures == 0, greedy_tokens, sample_tokens) diff --git a/python/icinfer/llm.py b/python/icinfer/llm.py new file mode 100644 index 00000000..be38ecb0 --- /dev/null +++ b/python/icinfer/llm.py @@ -0,0 +1,5 @@ +from icinfer.engine.llm_engine import InfiniEngine + + +class LLM(InfiniEngine): + pass diff --git a/python/icinfer/models/__init__.py b/python/icinfer/models/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/python/icinfer/models/auto_modeling/__init__.py b/python/icinfer/models/auto_modeling/__init__.py new file mode 100644 index 00000000..17cc7e3d --- /dev/null +++ b/python/icinfer/models/auto_modeling/__init__.py @@ -0,0 +1 @@ +from .factory import AutoModelForCausalLM \ No newline at end of file diff --git a/python/icinfer/models/auto_modeling/factory.py b/python/icinfer/models/auto_modeling/factory.py new file mode 100644 index 00000000..1a6ea097 --- /dev/null +++ b/python/icinfer/models/auto_modeling/factory.py @@ -0,0 +1,84 @@ +import importlib +from collections import OrderedDict +# from icinfer.models.libinfinicore_infer.base import DeviceType +import os +import json + +# 1. Define the mapping for model classes +# Using OrderedDict to maintain order for documentation purposes +MODEL_MAPPING_FILES = OrderedDict([ + ("fm9g7b", "jiuge"), + ("jiuge_awq", "jiuge_awq"), + ("deepseek_v3", "deepseek_v3"), +]) +MODEL_MAPPING_NAMES = OrderedDict([ + ("jiuge", "JiugeForCausalLM"), + ("jiuge_awq", "JiugeAWQForCausalLM"), + ("deepseek_v3", "DeepSeekV3ForCauslLM"), +]) + +# Dynamically import the module to avoid loading all model code at startup (lazy loading) +def get_model_class(config): + """Gets the corresponding model class based on model_type in the config.""" + # model_type = config.get("model_type") + model_type = config.hf_config.model_type + model_file = MODEL_MAPPING_FILES[model_type] + print(f"model_file: {model_file}") + if model_file in MODEL_MAPPING_NAMES: + model_name = MODEL_MAPPING_NAMES[model_file] + + # Dynamically construct the module path and import it. + # For example, if model_type="llama", it will import icinfer.models.llama + module_path = f"icinfer.models.{model_file}" + module = importlib.import_module(module_path) + + # Get the model class from the imported module + model_class = getattr(module, model_name, None) + if model_class is None: + raise AttributeError(f"Module '{module_path}' does not have a class named '{model_name}'") + return model_class + + raise KeyError( + f"Model type '{model_type}' not found in MODEL_MAPPING_NAMES. " + f"Available model types: {list(MODEL_MAPPING_NAMES.keys())}" + ) + + +class AutoModelForCausalLM: + """ + This is a generic model factory class. It will automatically instantiate the + correct model architecture based on the provided config. 
+ """ + def __init__(self): + # Direct instantiation of AutoModel is not allowed + raise EnvironmentError( + "AutoModel is designed to be instantiated using the `AutoModel.from_config(config)` class method." + ) + + @classmethod + def from_config(cls, config, device): + """ + Instantiates one of the base model classes from a configuration. + + Args: + config (dict): A configuration dictionary containing a 'model_type' field. + + Returns: + An instance of a model class (e.g., LlamaModel, BertModel). + """ + model_class = get_model_class(config) + max_tokens = config.max_model_len + model_dir_path = config.model_path + ndev = config.tensor_parallel_size + # hf_config = config.hf_config + hf_config = None + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + hf_config = json.load(f) + + model = model_class(model_dir_path=model_dir_path, + config=hf_config, + device=device, + ndev=ndev, + max_tokens=max_tokens) + # Call the from_config method of the specific model class + return model \ No newline at end of file diff --git a/python/icinfer/models/deepseek.py b/python/icinfer/models/deepseek.py new file mode 100644 index 00000000..bba5a373 --- /dev/null +++ b/python/icinfer/models/deepseek.py @@ -0,0 +1,779 @@ +import ctypes +from typing import List, Sequence + +from tqdm import tqdm + +from libinfinicore_infer import ( + DeepSeekV3Model, + DeepSeekV3MetaCStruct, + DeepSeekV3CacheCStruct, + DataType, + DeviceType, +) +from infer_task import InferTask, KVCache + +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import math +import torch +import transformers + +torch.set_default_device("cpu") + + +class DeepseekR1WeightsNaming: + def __init__(self, dense_replace=3): + self.dense_replace = dense_replace + + def input_embd(self): + return "model.embed_tokens.weight" + + def output_norm(self): + return "model.norm.weight" + + def output_embd(self): + return "lm_head.weight" + + # MLA + def attn_norm(self, i): + return f"model.layers.{i}.input_layernorm.weight" + + def attn_kv_a_layernorm(self, i): + return f"model.layers.{i}.self_attn.kv_a_layernorm.weight" + + def attn_kv_a_proj_with_mqa_weight(self, i): + return f"model.layers.{i}.self_attn.kv_a_proj_with_mqa.qweight" + + def attn_kv_a_proj_with_mqa_scale(self, i): + return f"model.layers.{i}.self_attn.kv_a_proj_with_mqa.scales" + + def attn_kv_a_proj_with_mqa_zero(self, i): + return f"model.layers.{i}.self_attn.kv_a_proj_with_mqa.qzeros" + + def attn_kv_b_proj_weight(self, i): + return f"model.layers.{i}.self_attn.kv_b_proj.qweight" + + def attn_kv_b_proj_scale(self, i): + return f"model.layers.{i}.self_attn.kv_b_proj.scales" + + def attn_kv_b_proj_zero(self, i): + return f"model.layers.{i}.self_attn.kv_b_proj.qzeros" + + def attn_o_proj_weight(self, i): + return f"model.layers.{i}.self_attn.o_proj.qweight" + + def attn_o_proj_scale(self, i): + return f"model.layers.{i}.self_attn.o_proj.scales" + + def attn_o_proj_zero(self, i): + return f"model.layers.{i}.self_attn.o_proj.qzeros" + + def attn_q_a_layernorm(self, i): + return f"model.layers.{i}.self_attn.q_a_layernorm.weight" + + def attn_q_a_proj_weight(self, i): + return f"model.layers.{i}.self_attn.q_a_proj.qweight" + + def attn_q_a_proj_scale(self, i): + return f"model.layers.{i}.self_attn.q_a_proj.scales" + + def attn_q_a_proj_zero(self, i): + return f"model.layers.{i}.self_attn.q_a_proj.qzeros" + + def attn_q_b_proj_weight(self, i): + return 
f"model.layers.{i}.self_attn.q_b_proj.qweight" + + def attn_q_b_proj_scale(self, i): + return f"model.layers.{i}.self_attn.q_b_proj.scales" + + def attn_q_b_proj_zero(self, i): + return f"model.layers.{i}.self_attn.q_b_proj.qzeros" + + # MLP + + def mlp_norm(self, i): + return f"model.layers.{i}.post_attention_layernorm.weight" + + # First self.dense_replace layers are dense + def mlp_down_proj_weight(self, i): + assert i < self.dense_replace + return f"model.layers.{i}.mlp.down_proj.qweight" + + def mlp_down_proj_scale(self, i): + assert i < self.dense_replace + return f"model.layers.{i}.mlp.down_proj.scales" + + def mlp_down_proj_zero(self, i): + assert i < self.dense_replace + return f"model.layers.{i}.mlp.down_proj.qzeros" + + def mlp_up_proj_weight(self, i): + assert i < self.dense_replace + return f"model.layers.{i}.mlp.up_proj.qweight" + + def mlp_up_proj_scale(self, i): + assert i < self.dense_replace + return f"model.layers.{i}.mlp.up_proj.scales" + + def mlp_up_proj_zero(self, i): + assert i < self.dense_replace + return f"model.layers.{i}.mlp.up_proj.qzeros" + + def mlp_gate_proj_weight(self, i): + assert i < self.dense_replace + return f"model.layers.{i}.mlp.gate_proj.qweight" + + def mlp_gate_proj_scale(self, i): + assert i < self.dense_replace + return f"model.layers.{i}.mlp.gate_proj.scales" + + def mlp_gate_proj_zero(self, i): + assert i < self.dense_replace + return f"model.layers.{i}.mlp.gate_proj.qzeros" + + # Latter layers are sparse + # Gating + def mlp_gate_weight(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.gate.weight" + + def mlp_gate_bias(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.gate.e_score_correction_bias" + + # Experts + def mlp_shared_experts_down_proj_weight(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.shared_experts.down_proj.qweight" + + def mlp_shared_experts_down_proj_scale(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.shared_experts.down_proj.scales" + + def mlp_shared_experts_down_proj_zero(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.shared_experts.down_proj.qzeros" + + def mlp_shared_experts_gate_proj_weight(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.shared_experts.gate_proj.qweight" + + def mlp_shared_experts_gate_proj_scale(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.shared_experts.gate_proj.scales" + + def mlp_shared_experts_gate_proj_zero(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.shared_experts.gate_proj.qzeros" + + def mlp_shared_experts_up_proj_weight(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.shared_experts.up_proj.qweight" + + def mlp_shared_experts_up_proj_scale(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.shared_experts.up_proj.scales" + + def mlp_shared_experts_up_proj_zero(self, i): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.shared_experts.up_proj.qzeros" + + # Experts + def mlp_experts_down_proj_weight(self, i, e): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.experts.{e}.down_proj.qweight" + + def mlp_experts_down_proj_scale(self, i, e): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.experts.{e}.down_proj.scales" + + def mlp_experts_down_proj_zero(self, i, e): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.experts.{e}.down_proj.qzeros" + + def 
mlp_experts_gate_proj_weight(self, i, e): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.experts.{e}.gate_proj.qweight" + + def mlp_experts_gate_proj_scale(self, i, e): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.experts.{e}.gate_proj.scales" + + def mlp_experts_gate_proj_zero(self, i, e): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.experts.{e}.gate_proj.qzeros" + + def mlp_experts_up_proj_weight(self, i, e): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.experts.{e}.up_proj.qweight" + + def mlp_experts_up_proj_scale(self, i, e): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.experts.{e}.up_proj.scales" + + def mlp_experts_up_proj_zero(self, i, e): + assert i >= self.dense_replace + return f"model.layers.{i}.mlp.experts.{e}.up_proj.qzeros" + + +class DeepSeekV3Meta(DeepSeekV3MetaCStruct): + def __init__(self, config, dtype=torch.float16, max_tokens=None): + if dtype == torch.float16: + dt_ = DataType.INFINI_DTYPE_F16 + elif dtype == torch.bfloat16: + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + super().__init__( + # dtypes + dt_logits=DataType.INFINI_DTYPE_F16, + dt_norm=DataType.INFINI_DTYPE_BF16, + dt_quant_weight=DataType.INFINI_DTYPE_I32, + dt_quant_scale=DataType.INFINI_DTYPE_F16, + dt_quant_zero=DataType.INFINI_DTYPE_I32, + dt_gate_weight=DataType.INFINI_DTYPE_BF16, + dt_gate_bias=DataType.INFINI_DTYPE_BF16, + # sizes + n_sparse_layer=config["num_hidden_layers"], + n_dense_layer=config.get("first_k_dense_replace", 0), + d=config["hidden_size"], + nh=config["num_attention_heads"], + nkvh=config.get("num_key_value_heads", config["num_attention_heads"]), + d_rope=config["qk_rope_head_dim"], + d_nope=config["qk_nope_head_dim"], + r_q=config["q_lora_rank"], + r_kv=config["kv_lora_rank"], + d_qk=config["qk_nope_head_dim"] + config["qk_rope_head_dim"], + d_v=config["v_head_dim"], + # routing / experts / vocab / ctx + routed_scale=config.get("routed_scaling_factor", 1.0), + nexperts=config["n_routed_experts"], + kexperts=config["num_experts_per_tok"], + di=config["intermediate_size"], + di_moe=config["moe_intermediate_size"], + dctx=( + config["max_position_embeddings"] if max_tokens is None else max_tokens + ), + dvoc=config["vocab_size"], + # misc + epsilon=config.get("rms_norm_eps", 1e-6), + rope_theta=config.get("rope_theta", 10000.0), + end_token=config.get("eos_token_id", 2), + ) + self.torch_dtype_logits = dtype + + +def load_specific_tensor(model_dir, tensor_name): + """ + Load a specific tensor from a sharded safetensors model using its index JSON. 
+ """ + index_file = os.path.join(model_dir, "model.safetensors.index.json") + if not os.path.exists(index_file): + raise FileNotFoundError(f"Index file not found: {index_file}") + + with open(index_file, "r") as f: + index = json.load(f) + + # Get mapping: tensor name -> file name + weight_map = index["weight_map"] + if tensor_name not in weight_map: + raise KeyError(f"{tensor_name} not found in index") + + filename = weight_map[tensor_name] + tensor_file = os.path.join(model_dir, filename) + + # Open only the relevant file and tensor + with safetensors.safe_open(tensor_file, framework="pt", device="cpu") as f: + tensor = f.get_tensor(tensor_name) + return tensor + + +def load_deepseek_weights( + meta: DeepSeekV3Meta, + weights, + model_path: str, + ndev: int, +): + model_instance = DeepSeekV3Model() + weight_loader = model_instance.create_weight_loader() + names = DeepseekR1WeightsNaming() + input_embd = load_specific_tensor(model_path, names.input_embd()).to( + meta.torch_dtype_logits + ) + weight_loader.contents.load_input_embd(weights, input_embd.data_ptr()) + del input_embd + + output_norm = load_specific_tensor(model_path, names.output_norm()) + weight_loader.contents.load_output_norm(weights, output_norm.data_ptr()) + del output_norm + + output_embd = load_specific_tensor(model_path, names.output_embd()) + weight_loader.contents.load_output_embd(weights, output_embd.data_ptr()) + del output_embd + + # ------------------------------- + # Per-layer weights + # ------------------------------- + + def load_quant(w_name, s_name, zero_name, split_dim=0): + weight = load_specific_tensor(model_path, w_name) + scale = load_specific_tensor(model_path, s_name) + zero = load_specific_tensor(model_path, zero_name) + if split_dim == 0 or ndev == 1: + return weight, scale, zero + elif split_dim == 1: + weight = ( + weight.reshape(weight.shape[0], ndev, -1).permute(1, 0, 2).contiguous() + ) + scale = ( + scale.reshape(scale.shape[0], ndev, -1).permute(1, 0, 2).contiguous() + ) + zero = zero.reshape(zero.shape[0], ndev, -1).permute(1, 0, 2).contiguous() + return weight, scale, zero + else: + raise ValueError("split_dim must be 0 or 1") + + for i in tqdm( + range(meta.n_sparse_layer + meta.n_dense_layer), desc="Loading layers" + ): + + # Attention norms + projections + attn_norm = load_specific_tensor(model_path, names.attn_norm(i)) + weight_loader.contents.load_attn_norm(weights, attn_norm.data_ptr(), i) + del attn_norm + + load_attn_q_a_layernorm = load_specific_tensor( + model_path, names.attn_q_a_layernorm(i) + ) + weight_loader.contents.load_attn_q_a_layernorm( + weights, load_attn_q_a_layernorm.data_ptr(), i + ) + del load_attn_q_a_layernorm + + attn_kv_a_layernorm = load_specific_tensor( + model_path, names.attn_kv_a_layernorm(i) + ) + weight_loader.contents.load_attn_kv_a_layernorm( + weights, attn_kv_a_layernorm.data_ptr(), i + ) + del attn_kv_a_layernorm + + w, s, z = load_quant( + names.attn_q_a_proj_weight(i), + names.attn_q_a_proj_scale(i), + names.attn_q_a_proj_zero(i), + ) + weight_loader.contents.load_attn_q_a_proj( + weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i + ) + + w, s, z = load_quant( + names.attn_q_b_proj_weight(i), + names.attn_q_b_proj_scale(i), + names.attn_q_b_proj_zero(i), + ) + weight_loader.contents.load_attn_q_b_proj( + weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i + ) + + w, s, z = load_quant( + names.attn_kv_a_proj_with_mqa_weight(i), + names.attn_kv_a_proj_with_mqa_scale(i), + names.attn_kv_a_proj_with_mqa_zero(i), + ) + 
weight_loader.contents.load_attn_kv_a_proj_with_mqa( + weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i + ) + + w, s, z = load_quant( + names.attn_kv_b_proj_weight(i), + names.attn_kv_b_proj_scale(i), + names.attn_kv_b_proj_zero(i), + ) + + weight_loader.contents.load_attn_kv_b_proj( + weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i + ) + + w, s, z = load_quant( + names.attn_o_proj_weight(i), + names.attn_o_proj_scale(i), + names.attn_o_proj_zero(i), + 1, + ) + + weight_loader.contents.load_attn_o_proj( + weights, w.data_ptr(), s.data_ptr(), z.data_ptr(), i + ) + + # ------------------------------- + # MLP: dense or sparse + # ------------------------------- + mlp_norm = load_specific_tensor(model_path, names.mlp_norm(i)) + weight_loader.contents.load_mlp_norm(weights, mlp_norm.data_ptr(), i) + + if i < meta.n_dense_layer: + # Dense MLP is grouped into one call + w_gate, s_gate, z_gate = load_quant( + names.mlp_gate_proj_weight(i), + names.mlp_gate_proj_scale(i), + names.mlp_gate_proj_zero(i), + ) + w_up, s_up, z_up = load_quant( + names.mlp_up_proj_weight(i), + names.mlp_up_proj_scale(i), + names.mlp_up_proj_zero(i), + ) + w_down, s_down, z_down = load_quant( + names.mlp_down_proj_weight(i), + names.mlp_down_proj_scale(i), + names.mlp_down_proj_zero(i), + 1, + ) + weight_loader.contents.load_mlp_dense( + weights, + w_gate.data_ptr(), + s_gate.data_ptr(), + z_gate.data_ptr(), + w_up.data_ptr(), + s_up.data_ptr(), + z_up.data_ptr(), + w_down.data_ptr(), + s_down.data_ptr(), + z_down.data_ptr(), + i, + ) + + else: + # Sparse MLP gating + mlp_gate_weight = load_specific_tensor(model_path, names.mlp_gate_weight(i)) + weight_loader.contents.load_mlp_gate_weight( + weights, mlp_gate_weight.data_ptr(), i + ) + del mlp_gate_weight + + mlp_gate_bias = load_specific_tensor(model_path, names.mlp_gate_bias(i)) + weight_loader.contents.load_mlp_gate_bias( + weights, mlp_gate_bias.data_ptr(), i + ) + del mlp_gate_bias + + # Shared experts + w_gate, s_gate, z_gate = load_quant( + names.mlp_shared_experts_gate_proj_weight(i), + names.mlp_shared_experts_gate_proj_scale(i), + names.mlp_shared_experts_gate_proj_zero(i), + ) + w_up, s_up, z_up = load_quant( + names.mlp_shared_experts_up_proj_weight(i), + names.mlp_shared_experts_up_proj_scale(i), + names.mlp_shared_experts_up_proj_zero(i), + ) + w_down, s_down, z_down = load_quant( + names.mlp_shared_experts_down_proj_weight(i), + names.mlp_shared_experts_down_proj_scale(i), + names.mlp_shared_experts_down_proj_zero(i), + 1, + ) + weight_loader.contents.load_mlp_shared_experts( + weights, + w_gate.data_ptr(), + s_gate.data_ptr(), + z_gate.data_ptr(), + w_up.data_ptr(), + s_up.data_ptr(), + z_up.data_ptr(), + w_down.data_ptr(), + s_down.data_ptr(), + z_down.data_ptr(), + i, + ) + + # Per-expert MLP + for e in range(meta.nexperts): + w_gate, s_gate, z_gate = load_quant( + names.mlp_experts_gate_proj_weight(i, e), + names.mlp_experts_gate_proj_scale(i, e), + names.mlp_experts_gate_proj_zero(i, e), + ) + w_up, s_up, z_up = load_quant( + names.mlp_experts_up_proj_weight(i, e), + names.mlp_experts_up_proj_scale(i, e), + names.mlp_experts_up_proj_zero(i, e), + ) + w_down, s_down, z_down = load_quant( + names.mlp_experts_down_proj_weight(i, e), + names.mlp_experts_down_proj_scale(i, e), + names.mlp_experts_down_proj_zero(i, e), + 1, + ) + weight_loader.contents.load_mlp_experts( + weights, + w_gate.data_ptr(), + s_gate.data_ptr(), + z_gate.data_ptr(), + w_up.data_ptr(), + s_up.data_ptr(), + z_up.data_ptr(), + w_down.data_ptr(), + s_down.data_ptr(), + 
z_down.data_ptr(), + i, + e, + ) + + +class DeepSeekV3BatchedTask: + def __init__(self, tasks: List[InferTask]): + self.tasks = tasks + self.nreq = len(tasks) + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(DeepSeekV3CacheCStruct) * self.nreq)( + *self.kv_cache_ptrs + ) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.temperaturas, + self.topks, + self.topps, + ) + + +class DeepSeekV3ForCauslLM: + def __init__( + self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None + ): + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + config = json.load(f) + self.config = config + eos_token_id = self.config["eos_token_id"] + self.eos_token_id = ( + [eos_token_id] if type(eos_token_id) == int else eos_token_id + ) + + print(model_dir_path) + + if "deepseek_v3" == config["model_type"]: + self.meta = DeepSeekV3Meta( + config, max_tokens=max_tokens, dtype=torch.float16 + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path) + else: + raise ValueError("Unsupported model architecture") + + print(f"Creating model on {ndev} devices...") + load_start_time = time.time() + dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + + self.model_instance = DeepSeekV3Model() + weights = self.model_instance.create_weights( + byref(self.meta), + device, + ndev, + dev_ids, + ) + # Load weights from host + load_deepseek_weights(self.meta, weights, model_dir_path, ndev) + # Create model instance + self.model_ptr = self.model_instance.create_model( + byref(self.meta), + weights, + ) + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + def max_context_len(self): + return self.meta.dctx + + def create_kv_cache(self): + return self.model_instance.create_cache(self.model_ptr) + + def drop_kv_cache(self, kv_cache): + self.model_instance.drop_cache(self.model_ptr, kv_cache) + + def batch_infer_one_round(self, tasks: List[InferTask]): + output = (c_uint * len(tasks))() + batch_inputs = DeepSeekV3BatchedTask(tasks) + self.model_instance.infer_batch( + self.model_ptr, + *(batch_inputs.input_args()), + output, + ) + return list(output) + + def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): + input_content = self.tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": input_content}], + add_generation_prompt=True, + tokenize=False, + ) + + tokens = self.tokenizer.encode(input_content) + infer_task = InferTask( + 0, + tokens, + self.max_context_len(), + temperature_, + topk_, + topp_, + self.eos_token_id, + ) + 
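The flattened batch layout that DeepSeekV3BatchedTask builds above can be pictured with a small, self-contained sketch (toy token values; the model handle, KV caches and sampling arrays are omitted, so this illustrates only the layout, not a working call):

from ctypes import c_uint

token_lists = [[101, 7592, 2088], [101, 2129]]        # two toy requests
flat = [tok for toks in token_lists for tok in toks]  # ntok == 5

tokens   = (c_uint * len(flat))(*flat)
req_lens = (c_uint * len(token_lists))(*[len(t) for t in token_lists])
req_pos  = (c_uint * len(token_lists))(0, 0)          # both requests start at KV position 0
# The C side re-segments `tokens` using req_lens: tokens[0:3] belong to request 0 and
# tokens[3:5] to request 1, while req_pos gives each request's offset in its KV cache.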
infer_task.bind_kvcache(KVCache(self)) + print(input_content, end="", flush=True) + steps = 0 + total_time = 0 + output_content = "" + + for step_i in range(max_steps): + start_time = time.time() + output_tokens = self.batch_infer_one_round([infer_task]) + end_time = time.time() + steps += 1 + output_str = ( + self.tokenizer._tokenizer.id_to_token(output_tokens[0]) + .replace("▁", " ") + .replace("<0x0A>", "\n") + ) + output_content += output_str + print(output_str, end="", flush=True) + if output_tokens[0] in self.eos_token_id: + break + infer_task.next(output_tokens[0]) + + if step_i > 0: + total_time += end_time - start_time + + print("\n") + avg_time = total_time * 1000 / (steps - 1) + print(f"Time per step: {avg_time:.3f}ms") + + infer_task._kv_cache.drop(self) + return output_content, avg_time + + # def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): + # tasks = [ + # InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) + # for i in range(batch_size) + # ] + # kv_caches = [KVCache(self) for _ in range(batch_size)] + + # nll = 0.0 + # total_len = 0 + + # for i in range(0, len(test_sequences), batch_size): + # batch_id = 0 + # true_tokens = [] + # while batch_id < batch_size and batch_id + i < len(test_sequences): + # input_tokens = test_sequences[i + batch_id][:-1] + # true_tokens.extend(test_sequences[i + batch_id][1:]) + # tasks[batch_id].tokens = input_tokens + # tasks[batch_id].bind_kvcache(kv_caches[batch_id]) + # batch_id += 1 + + # batch_inputs = DeepSeekV3BatchedTask(tasks[:batch_id]) + # logits = torch.zeros( + # (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits + # ) + # forward_batch_deepseek_v3( + # self.model_instance, + # batch_inputs.tokens, + # batch_inputs.ntok, + # batch_inputs.req_lens, + # batch_inputs.nreq, + # batch_inputs.req_pos, + # batch_inputs.kv_caches, + # logits.data_ptr(), + # ) + + # logits = logits.float() + # token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,] + # log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) + # token_logprobs = log_probs[ + # torch.arange(batch_inputs.ntok), token_ids + # ] # (ntok,) + + # start = 0 + # for l in batch_inputs.req_lens_list: + # nll += -token_logprobs[start : start + l].sum().item() + # start += l + # total_len += token_logprobs.numel() + + # for task in tasks: + # task.release_kvcache() + + # return math.exp(nll / total_len) + + def destroy_model_instance(self): + self.model_instance.destroy_model(self.model_ptr) + print("Model destroyed") + + +def test(): + if len(sys.argv) < 3: + print( + "Usage: python deepseek.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + ) + sys.exit(1) + model_path = sys.argv[2] + device_type = DeviceType.DEVICE_TYPE_CPU + if sys.argv[1] == "--cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif sys.argv[1] == "--nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif sys.argv[1] == "--cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif sys.argv[1] == "--ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif sys.argv[1] == "--metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif sys.argv[1] == "--moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif sys.argv[1] == "--iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + print( + "Usage: python deepseek.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + ) + sys.exit(1) + + ndev = int(sys.argv[3]) if len(sys.argv) > 3 
else 1 + model = DeepSeekV3ForCauslLM(model_path, device_type, ndev, max_tokens=1024) + model.generate("山东最高的山是?", 50) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() diff --git a/python/icinfer/models/jiuge.py b/python/icinfer/models/jiuge.py new file mode 100644 index 00000000..b2edbacf --- /dev/null +++ b/python/icinfer/models/jiuge.py @@ -0,0 +1,733 @@ +from typing import List, Sequence +from sympy import true +import math +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import torch +import transformers + +from icinfer.models.libinfinicore_infer.jiuge import ( + JiugeModel, + JiugeMetaCStruct, + JiugeWeightsCStruct, + DataType, + DeviceType, + KVCacheCStruct, +) +from icinfer.engine.infer_task import InferTask, KVCache + +import logging +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + +torch.set_default_device("cpu") + + +class LlamaWeightsNaming: + def input_embd(self): + return "model.embed_tokens.weight" + + def output_norm(self): + return "model.norm.weight" + + def output_embd(self): + return "lm_head.weight" + + def attn_norm(self, i): + return f"model.layers.{i}.input_layernorm.weight" + + def attn_q(self, i): + return f"model.layers.{i}.self_attn.q_proj.weight" + + def attn_k(self, i): + return f"model.layers.{i}.self_attn.k_proj.weight" + + def attn_v(self, i): + return f"model.layers.{i}.self_attn.v_proj.weight" + + def attn_o(self, i): + return f"model.layers.{i}.self_attn.o_proj.weight" + + def attn_q_b(self, i): + return f"model.layers.{i}.self_attn.q_proj.bias" + + def attn_k_b(self, i): + return f"model.layers.{i}.self_attn.k_proj.bias" + + def attn_v_b(self, i): + return f"model.layers.{i}.self_attn.v_proj.bias" + + def ffn_norm(self, i): + return f"model.layers.{i}.post_attention_layernorm.weight" + + def gate(self, i): + return f"model.layers.{i}.mlp.gate_proj.weight" + + def up(self, i): + return f"model.layers.{i}.mlp.up_proj.weight" + + def down(self, i): + return f"model.layers.{i}.mlp.down_proj.weight" + + def match(state_dict): + return ( + "model.norm.weight" in state_dict + and "model.layers.0.self_attn.q_proj.weight" in state_dict + ) + + +class JiugeMetaFromLlama(JiugeMetaCStruct): + def __init__(self, config, dtype=torch.float16, max_tokens=None): + if dtype == torch.float16: + dt_ = DataType.INFINI_DTYPE_F16 + elif dtype == torch.float32: + dt_ = DataType.INFINI_DTYPE_F32 + elif dtype == torch.bfloat16: + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + self.scale_input = 1.0 + self.scale_output = 1.0 + self.scale_o = 1.0 + self.scale_down = 1.0 + if ( + config["model_type"] in ["fm9g", "minicpm"] + and "scale_emb" in config + and "scale_depth" in config + and "dim_model_base" in config + ): + self.scale_input = config["scale_emb"] + self.scale_output = config["hidden_size"] // config["dim_model_base"] + self.scale_o = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + self.scale_down = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + + super().__init__( + dt_logits=dt_, + nlayer=config["num_hidden_layers"], + d=config["hidden_size"], + nh=config["num_attention_heads"], + nkvh=( + config["num_key_value_heads"] + if "num_key_value_heads" in config + else config["num_attention_heads"] + ), + dh=config["hidden_size"] // 
config["num_attention_heads"], + di=config["intermediate_size"], + dctx=( + config["max_position_embeddings"] if max_tokens is None else max_tokens + ), + dvoc=config["vocab_size"], + epsilon=config["rms_norm_eps"], + theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), + end_token=2, + ) + self.torch_dtype_logits = dtype + + +class JiugeWeightsImpl(JiugeWeightsCStruct): + def __init__( + self, + meta, + naming, + state_dict, + torch_dt_mat=torch.float16, + torch_dt_norm=torch.float32, + ndev=1, + transpose_weight=True, + ): + nlayer = meta.nlayer + nh = meta.nh + nkvh = meta.nkvh + dh = meta.dh + d = meta.d + di = meta.di + scale_input = meta.scale_input + scale_output = meta.scale_output + scale_o = meta.scale_o + scale_down = meta.scale_down + assert nh % nkvh == 0 + assert nh % ndev == 0 + assert nkvh % ndev == 0 + assert di % ndev == 0 + torch_dt_logits = meta.torch_dtype_logits + if torch_dt_mat == torch.float16: + self.dt_mat = DataType.INFINI_DTYPE_F16 + elif torch_dt_mat == torch.float32: + self.dt_mat = DataType.INFINI_DTYPE_F32 + elif torch_dt_mat == torch.bfloat16: + self.dt_mat = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported proj weight data type") + if torch_dt_norm == torch.float16: + self.dt_norm = DataType.INFINI_DTYPE_F16 + elif torch_dt_norm == torch.float32: + self.dt_norm = DataType.INFINI_DTYPE_F32 + elif torch_dt_norm == torch.bfloat16: + self.dt_norm = DataType.INFINI_DTYPE_BF16 + else: + raise ValueError("Unsupported norm weight data type") + + input_embd_naming = ( + naming.input_embd() + if naming.input_embd() in state_dict + else naming.output_embd() + ) + output_embd_naming = ( + naming.output_embd() + if naming.output_embd() in state_dict + else naming.input_embd() + ) + self.transpose_linear_weights = 1 if transpose_weight else 0 + self.nlayer = nlayer + self.input_embd_tensor = ( + state_dict[input_embd_naming].to(torch_dt_logits) * scale_input + ) + self.input_embd = self.input_embd_tensor.data_ptr() + self.output_norm_tensor = ( + state_dict[naming.output_norm()].to(torch_dt_norm) * scale_output + ) + self.output_norm = self.output_norm_tensor.data_ptr() + self.output_embd_tensor = state_dict[output_embd_naming].to(torch_dt_mat) + if not transpose_weight: + self.output_embd_tensor = self.output_embd_tensor.transpose( + 0, 1 + ).contiguous() + self.output_embd = self.output_embd_tensor.data_ptr() + + self.attn_norm_tensors = [ + state_dict[naming.attn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + ] + self.attn_norm_ptrs = [ + self.attn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_norm = (c_void_p * nlayer)(*self.attn_norm_ptrs) + + def qkv_slices(_i): + _Q = ( + state_dict[naming.attn_q(_i)] + .reshape([nh, 2, dh // 2, d]) + .transpose(1, 2) + ) + _K = ( + state_dict[naming.attn_k(_i)] + .reshape([nkvh, 2, dh // 2, d]) + .transpose(1, 2) + ) + _V = state_dict[naming.attn_v(_i)].reshape([nkvh, dh // 2, 2, d]) + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_Q[_idev * _nh : (_idev + 1) * _nh, :, :, :]) + _result.append(_K[_idev * _nkvh : (_idev + 1) * _nkvh, :, :, :]) + _result.append(_V[_idev * _nkvh : (_idev + 1) * _nkvh, :, :]) + return _result + + self.qkv_tensor = [ + torch.concat(qkv_slices(i)).to(torch_dt_mat) for i in range(nlayer) + ] + if not transpose_weight: + for i in range(nlayer): + self.qkv_tensor[i] = ( + self.qkv_tensor[i] + .reshape(ndev, (nh + 2 * nkvh) // ndev * dh, d) + .transpose(1, 2) + .contiguous() + ) + 
self.qkv_tensor_ptrs = [self.qkv_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_qkv = (c_void_p * nlayer)(*self.qkv_tensor_ptrs) + + def qkv_b_slices(_i): + _QB = ( + state_dict[naming.attn_q_b(_i)] + .reshape([nh, 2, dh // 2]) + .transpose(1, 2) + ) + _KB = ( + state_dict[naming.attn_k_b(_i)] + .reshape([nkvh, 2, dh // 2]) + .transpose(1, 2) + ) + _VB = state_dict[naming.attn_v_b(_i)].reshape([nkvh, dh // 2, 2]) + _result = [] + _nh = nh // ndev + _nkvh = nkvh // ndev + for _idev in range(ndev): + _result.append(_QB[_idev * _nh : (_idev + 1) * _nh, :, :].flatten()) + _result.append(_KB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) + _result.append(_VB[_idev * _nkvh : (_idev + 1) * _nkvh, :, :].flatten()) + return _result + + if naming.attn_q_b(0) in state_dict: + self.qkv_b_tensors = [ + torch.concat(qkv_b_slices(i)).to(torch_dt_logits) for i in range(nlayer) + ] + self.qkv_b_tensor_ptrs = [ + self.qkv_b_tensors[i].data_ptr() for i in range(nlayer) + ] + self.attn_qkv_b = (c_void_p * nlayer)(*self.qkv_b_tensor_ptrs) + else: + self.attn_qkv_b = None + + self.attn_o_tensor = [ + ( + state_dict[naming.attn_o(i)] + .to(torch_dt_mat) + .reshape([d, ndev, nh // ndev * dh]) + .transpose(0, 1) + .contiguous() + if transpose_weight + else state_dict[naming.attn_o(i)] + .transpose(0, 1) + .to(torch_dt_mat) + .contiguous() + ) + * scale_o + for i in range(nlayer) + ] + self.attn_o_ptrs = [self.attn_o_tensor[i].data_ptr() for i in range(nlayer)] + self.attn_o = (c_void_p * nlayer)(*self.attn_o_ptrs) + + self.ffn_norm_tensors = [ + state_dict[naming.ffn_norm(i)].to(torch_dt_norm) for i in range(nlayer) + ] + self.ffn_norm_ptrs = [ + self.ffn_norm_tensors[i].data_ptr() for i in range(nlayer) + ] + self.ffn_norm = (c_void_p * nlayer)(*self.ffn_norm_ptrs) + + def gate_up_slices(_i): + _result = [] + _di = di // ndev + for _idev in range(ndev): + _start = _idev * _di + _end = (_idev + 1) * _di + _result.append(state_dict[naming.gate(_i)][_start:_end, :]) + _result.append(state_dict[naming.up(_i)][_start:_end, :]) + return _result + + self.gate_up_tensors = [ + torch.concat(gate_up_slices(i)).to(torch_dt_mat) for i in range(nlayer) + ] + if not transpose_weight: + for i in range(nlayer): + self.gate_up_tensors[i] = ( + self.gate_up_tensors[i] + .reshape(ndev, 2 * di // ndev, d) + .transpose(1, 2) + .contiguous() + ) + self.gate_up_ptrs = [self.gate_up_tensors[i].data_ptr() for i in range(nlayer)] + self.ffn_gate_up = (c_void_p * nlayer)(*self.gate_up_ptrs) + + self.ffn_down_tensor = [ + ( + state_dict[naming.down(i)] + .to(torch_dt_mat) + .reshape([d, ndev, di // ndev]) + .transpose(0, 1) + .contiguous() + if transpose_weight + else state_dict[naming.down(i)] + .transpose(0, 1) + .to(torch_dt_mat) + .contiguous() + ) + * scale_down + for i in range(nlayer) + ] + self.ffn_down_ptrs = [self.ffn_down_tensor[i].data_ptr() for i in range(nlayer)] + self.ffn_down = (c_void_p * nlayer)(*self.ffn_down_ptrs) + + +class JiugeBatchedTask: + def __init__(self, tasks: List[InferTask]): + self.tasks = tasks + self.nreq = len(tasks) + self.is_prefill = 1 + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in 
toks] + self.ntok = len(flat_tokens) + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + print(list(self.tokens)) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.temperaturas, + self.topks, + self.topps, + self.is_prefill, + ) + + +class JiugeForCausalLM: + def __init__( + self, model_dir_path, config, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None + ): + def load_all_safetensors_from_dir(dir_path_: str): + tensors_ = {} + dir_path_ = Path(dir_path_) + for file in sorted(dir_path_.glob("*.safetensors")): + data_ = safetensors.safe_open(file, "pt") + for name_ in data_.keys(): + tensors_[name_] = data_.get_tensor(name_) + return tensors_ + + print("Loading model weights to host...") + load_start_time = time.time() + + # with open(os.path.join(model_dir_path, "config.json"), "r") as f: + # config = json.load(f) + # self.config = config + # eos_token_id = self.config["eos_token_id"] + # self.eos_token_id = ( + # [eos_token_id] if type(eos_token_id) == int else eos_token_id + # ) + # model_type = config.model_type + model_type = config["model_type"] + transpose_weight = ( + device != DeviceType.DEVICE_TYPE_ASCEND + ) # y = xW is faster than y=xW^T on Ascend + + self.jiuge_model = JiugeModel() + + if "llama" == model_type: + model = ( + transformers.LlamaForCausalLM.from_pretrained(model_dir_path) + .cpu() + .half() + ) + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_dir_path, trust_remote_code=True) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + model.state_dict(), + ndev=ndev, + transpose_weight=transpose_weight, + ) + elif "fm9g" == model_type or "minicpm" == model_type: + if any( + file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir() + ): + state_dict = load_all_safetensors_from_dir(model_dir_path) + else: + state_dict = torch.load( + os.path.join(model_dir_path, "pytorch_model.bin"), + weights_only=True, + map_location="cpu", + ) + if LlamaWeightsNaming.match(state_dict): + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + state_dict, + ndev=ndev, + transpose_weight=transpose_weight, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) + else: + raise ValueError("Unsupported weight naming") + elif "fm9g7b" == model_type: + if any( + file.suffix == ".safetensors" for file in Path(model_dir_path).iterdir() + ): + state_dict = load_all_safetensors_from_dir(model_dir_path) + else: + state_dict = torch.load( + os.path.join(model_dir_path, "pytorch_model.bin"), + weights_only=True, + map_location="cpu", + ) + if LlamaWeightsNaming.match(state_dict): + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + state_dict, + ndev=ndev, + transpose_weight=transpose_weight, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + 
model_dir_path, trust_remote_code=True + ) + else: + raise ValueError("Unsupported weight naming") + elif "qwen2" == model_type: + state_dict = load_all_safetensors_from_dir(model_dir_path) + if LlamaWeightsNaming.match(state_dict): + self.meta = JiugeMetaFromLlama(config, max_tokens=max_tokens) + self.weights = JiugeWeightsImpl( + self.meta, + LlamaWeightsNaming(), + state_dict, + ndev=ndev, + transpose_weight=transpose_weight, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path + ) + else: + raise ValueError("Unsupported model architecture") + + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + print(f"Creating model on {ndev} devices...") + load_start_time = time.time() + self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + self.ndev = ndev + self.device = device + + self.model_instance = self.jiuge_model.create_model( + byref(self.meta), + byref(self.weights), + device, + ndev, + self.dev_ids, + ) + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + def max_context_len(self): + return self.meta.dctx + + def create_kv_cache(self): + return self.jiuge_model.create_kv_cache( + self.meta.nlayer, + self.meta.dctx, + self.meta.nkvh, + self.meta.dh, + self.meta.dh, + self.meta.dt_logits, + self.device, + self.dev_ids, + self.ndev, + ) + + def create_paged_kv_cache(self, max_kvcache_tokens): + return self.jiuge_model.create_paged_kv_cache( + self.meta.nlayer, + self.meta.nkvh, + self.meta.dh, + self.meta.dh, + self.meta.dt_logits, + self.device, + self.dev_ids, + self.ndev, + self.meta.kvcache_block_size, + max_kvcache_tokens, + ) + + def drop_kv_cache(self, kv_cache): + self.jiuge_model.drop_kv_cache(kv_cache) + + # def batch_infer_one_round(self, tasks: List[InferTask]): + # output = (c_uint * len(tasks))() + # batch_inputs = JiugeBatchedTask(tasks) + # self.jiuge_model.infer_batch( + # self.model_instance, + # *(batch_inputs.input_args()), + # output, + # ) + # return list(output) + + def infer_batch(self, *args, **kwargs): + self.jiuge_model.infer_batch( + self.model_instance, + *args, + **kwargs + ) + + def forward_batch(self, *args, **kwargs): + self.jiuge_model.forward_batch( + self.model_instance, + *args, + **kwargs + ) + + # def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): + # input_content = self.tokenizer.apply_chat_template( + # conversation=[{"role": "user", "content": input_content}], + # add_generation_prompt=True, + # tokenize=False, + # ) + # print(input_content, end="", flush=True) + # tokens = self.tokenizer.encode(input_content) + # infer_task = InferTask( + # 0, + # tokens, + # self.max_context_len(), + # temperature_, + # topk_, + # topp_, + # self.eos_token_id, + # ) + # infer_task.bind_kvcache(KVCache(self)) + + # steps = 0 + # total_time = 0 + # output_content = "" + + # for step_i in range(max_steps): + # start_time = time.time() + # output_tokens = self.batch_infer_one_round([infer_task]) + # end_time = time.time() + # steps += 1 + # output_str = ( + # self.tokenizer._tokenizer.id_to_token(output_tokens[0]) + # .replace("▁", " ") + # .replace("<0x0A>", "\n") + # ) + # output_content += output_str + # print(output_str, end="", flush=True) + # if output_tokens[0] in self.eos_token_id: + # break + # infer_task.next(output_tokens[0]) + + # if step_i > 0: + # total_time += end_time - start_time + + # print("\n") + # avg_time = total_time * 1000 / (steps - 1) + # print(f"Time per step: {avg_time:.3f}ms") + + # 
infer_task._kv_cache.drop(self) + # return output_content, avg_time + + # def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): + # tasks = [ + # InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) + # for i in range(batch_size) + # ] + # kv_caches = [KVCache(self) for _ in range(batch_size)] + + # nll = 0.0 + # total_len = 0 + + # for i in range(0, len(test_sequences), batch_size): + # batch_id = 0 + # true_tokens = [] + # while batch_id < batch_size and batch_id + i < len(test_sequences): + # input_tokens = test_sequences[i + batch_id][:-1] + # true_tokens.extend(test_sequences[i + batch_id][1:]) + # tasks[batch_id].tokens = input_tokens + # tasks[batch_id].bind_kvcache(kv_caches[batch_id]) + # batch_id += 1 + + # batch_inputs = JiugeBatchedTask(tasks[:batch_id]) + # logits = torch.zeros( + # (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits + # ) + # self.jiuge_model.forward_batch( + # self.model_instance, + # batch_inputs.tokens, + # batch_inputs.ntok, + # batch_inputs.req_lens, + # batch_inputs.nreq, + # batch_inputs.req_pos, + # batch_inputs.kv_caches, + # logits.data_ptr(), + # ) + + # logits = logits.float() + # token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,] + # log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) + # token_logprobs = log_probs[ + # torch.arange(batch_inputs.ntok), token_ids + # ] # (ntok,) + + # start = 0 + # for l in batch_inputs.req_lens_list: + # nll += -token_logprobs[start : start + l].sum().item() + # start += l + # total_len += token_logprobs.numel() + + # for task in tasks: + # task.release_kvcache() + + # return math.exp(nll / total_len) + + def destroy_model_instance(self): + self.jiuge_model.destroy_model(self.model_instance) + print("Model destroyed") + + +def test(): + if len(sys.argv) < 3: + print( + "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + ) + sys.exit(1) + model_path = sys.argv[2] + device_type = DeviceType.DEVICE_TYPE_CPU + if sys.argv[1] == "--cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif sys.argv[1] == "--nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif sys.argv[1] == "--cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif sys.argv[1] == "--ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif sys.argv[1] == "--metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif sys.argv[1] == "--moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif sys.argv[1] == "--iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + elif sys.argv[1] == "--kunlun": + device_type = DeviceType.DEVICE_TYPE_KUNLUN + else: + print( + "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + ) + sys.exit(1) + + ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 + model = JiugeForCausalLM(model_path, device_type, ndev) + model.generate("山东最高的山是?", 500) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() \ No newline at end of file diff --git a/python/icinfer/models/jiuge_awq.py b/python/icinfer/models/jiuge_awq.py new file mode 100644 index 00000000..a14836b7 --- /dev/null +++ b/python/icinfer/models/jiuge_awq.py @@ -0,0 +1,374 @@ +from typing import List, Sequence +import math +import os +from pathlib import Path +import safetensors +import sys +import time +import json +import torch +import transformers + +from libinfinicore_infer import ( + JiugeAWQModel, + JiugeAWQMetaCStruct, + 
DataType, + DeviceType, + KVCacheCStruct, +) +from infer_task import InferTask, KVCache + +from ctypes import POINTER, c_float, c_int, c_uint, c_void_p, byref + +torch.set_default_device("cpu") + + +class JiugeAWQMetaFromConfig(JiugeAWQMetaCStruct): + def __init__(self, config, dtype=torch.float16, max_tokens=None): + if config["torch_dtype"] == "float16": + dt_ = DataType.INFINI_DTYPE_F16 + elif config["torch_dtype"] == "float32": + dt_ = DataType.INFINI_DTYPE_F32 + elif config["torch_dtype"] == "bfloat16": + dt_ = DataType.INFINI_DTYPE_BF16 + else: + dt_ = DataType.INFINI_DTYPE_F16 + + self.scale_input = 1.0 + self.scale_output = 1.0 + self.scale_o = 1.0 + self.scale_down = 1.0 + if ( + config["model_type"] in ["fm9g", "minicpm"] + and "scale_emb" in config + and "scale_depth" in config + and "dim_model_base" in config + ): + self.scale_input = config["scale_emb"] + self.scale_output = config["hidden_size"] // config["dim_model_base"] + self.scale_o = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + self.scale_down = config["scale_depth"] / math.sqrt( + config["num_hidden_layers"] + ) + + has_qkv_bias = ( + 1 if "attention_bias" in config and config["attention_bias"] else 0 + ) + if config["model_type"] in ["qwen2", "qwen3"]: + has_qkv_bias = 1 + + eos_token_id = ( + config["eos_token_id"][0] + if type(config["eos_token_id"]) == list + else config["eos_token_id"] + ) + + super().__init__( + dt_logits=dt_, + dt_linear_w=DataType.INFINI_DTYPE_I32, + dt_norm_w=dt_, + nlayer=config["num_hidden_layers"], + d=config["hidden_size"], + nh=config["num_attention_heads"], + nkvh=( + config["num_key_value_heads"] + if "num_key_value_heads" in config + else config["num_attention_heads"] + ), + dh=config["hidden_size"] // config["num_attention_heads"], + di=config["intermediate_size"], + dctx=( + config["max_position_embeddings"] if max_tokens is None else max_tokens + ), + dvoc=config["vocab_size"], + epsilon=config["rms_norm_eps"], + theta=(config["rope_theta"] if "rope_theta" in config else 100000.0), + end_token=eos_token_id, + nbit=config["quantization_config"]["bits"], + quant_group_size=config["quantization_config"]["group_size"], + has_qkv_bias=has_qkv_bias, + ) + self.torch_dtype_logits = dtype + + +class JiugeAWQBatchedTask: + def __init__(self, tasks: List[InferTask]): + self.tasks = tasks + self.nreq = len(tasks) + + # Precompute fields + token_lists = [t.tokens for t in tasks] + self.req_lens_list = [len(toks) for toks in token_lists] + self.req_pos_list = [t.pos for t in tasks] + self.kv_cache_ptrs = [t.kvcache().data() for t in tasks] + self.temperaturas_list = [t.temperature for t in tasks] + self.topks_list = [t.topk for t in tasks] + self.topps_list = [t.topp for t in tasks] + + # Flatten token lists + flat_tokens = [tok for toks in token_lists for tok in toks] + self.ntok = len(flat_tokens) + + # Convert to ctypes arrays in one pass + self.tokens = (c_uint * self.ntok)(*flat_tokens) + self.req_lens = (c_uint * self.nreq)(*self.req_lens_list) + self.req_pos = (c_uint * self.nreq)(*self.req_pos_list) + self.kv_caches = (POINTER(KVCacheCStruct) * self.nreq)(*self.kv_cache_ptrs) + self.temperaturas = (c_float * self.nreq)(*self.temperaturas_list) + self.topks = (c_uint * self.nreq)(*self.topks_list) + self.topps = (c_float * self.nreq)(*self.topps_list) + + def input_args(self): + return ( + self.tokens, + self.ntok, + self.req_lens, + self.nreq, + self.req_pos, + self.kv_caches, + self.temperaturas, + self.topks, + self.topps, + ) + + +class JiugeAWQForCausalLM: + 
def __init__( + self, model_dir_path, device=DeviceType.DEVICE_TYPE_CPU, ndev=1, max_tokens=None + ): + + load_start_time = time.time() + print(f"Creating model on {ndev} devices...") + with open(os.path.join(model_dir_path, "config.json"), "r") as f: + config = json.load(f) + self.config = config + eos_token_id = self.config["eos_token_id"] + self.eos_token_id = ( + [eos_token_id] if type(eos_token_id) == int else eos_token_id + ) + self.dev_ids = (c_int * ndev)(*[i for i in range(ndev)]) + self.ndev = ndev + self.device = device + self.meta = JiugeAWQMetaFromConfig(config, max_tokens=max_tokens) + + self.jiuge_awq_model = JiugeAWQModel() + + self.weights = self.jiuge_awq_model.create_weights( + byref(self.meta), + self.device, + ndev, + self.dev_ids, + ) + self.tokenizer = transformers.AutoTokenizer.from_pretrained( + model_dir_path, trust_remote_code=True + ) + + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + load_start_time = time.time() + print("Loading model weights to host...") + + self.load_all_safetensors_from_dir(os.path.join(model_dir_path)) + + self.model_instance = self.jiuge_awq_model.create_model( + byref(self.meta), + self.weights, + ) + load_end_time = time.time() + print(f"Time used: {load_end_time - load_start_time:.3f}s") + + def load_all_safetensors_from_dir(self, dir_path_: str): + dir_path_ = Path(dir_path_) + for file in sorted(dir_path_.glob("*.safetensors")): + with safetensors.safe_open(file, framework="pt", device="cpu") as f: + for key in f.keys(): + # print(key) + tensor = f.get_tensor(key) + if "o_proj.scales" in key: + tensor = tensor * self.meta.scale_o + elif "down_proj.scales" in key: + tensor = tensor * self.meta.scale_down + elif "embed_tokens.weight" in key: + tensor = tensor * self.meta.scale_input + elif "lm_head.weight" in key: + tensor = tensor * self.meta.scale_output + self.jiuge_awq_model.load_weight( + self.weights, key, tensor.data_ptr() + ) + + def max_context_len(self): + return self.meta.dctx + + def create_kv_cache(self): + return self.jiuge_awq_model.create_kv_cache( + self.meta.nlayer, + self.meta.dctx, + self.meta.nkvh, + self.meta.dh, + self.meta.dh, + self.meta.dt_logits, + self.device, + self.dev_ids, + self.ndev, + ) + + def drop_kv_cache(self, kv_cache): + self.jiuge_awq_model.drop_kv_cache(kv_cache) + + def batch_infer_one_round(self, tasks: List[InferTask]): + output = (c_uint * len(tasks))() + batch_inputs = JiugeAWQBatchedTask(tasks) + self.jiuge_awq_model.infer_batch( + self.model_instance, + *(batch_inputs.input_args()), + output, + ) + return list(output) + + def generate(self, input_content, max_steps, topp_=1.0, topk_=1, temperature_=1.0): + input_content = self.tokenizer.apply_chat_template( + conversation=[{"role": "user", "content": input_content}], + add_generation_prompt=True, + tokenize=False, + ) + print(input_content, end="", flush=True) + tokens = self.tokenizer.encode(input_content) + infer_task = InferTask( + 0, + tokens, + self.max_context_len(), + temperature_, + topk_, + topp_, + self.eos_token_id, + ) + infer_task.bind_kvcache(KVCache(self)) + + steps = 0 + total_time = 0 + output_content = "" + + for step_i in range(max_steps): + start_time = time.time() + output_tokens = self.batch_infer_one_round([infer_task]) + end_time = time.time() + steps += 1 + # output_str = ( + # self.tokenizer._tokenizer.id_to_token(output_tokens[0]) + # .replace("▁", " ") + # .replace("<0x0A>", "\n") + # ) + output_str = self.tokenizer.decode(output_tokens[0]) + output_content += 
output_str + print(output_str, end="", flush=True) + if output_tokens[0] in self.eos_token_id: + break + infer_task.next(output_tokens[0]) + + if step_i > 0: + total_time += end_time - start_time + + print("\n") + avg_time = total_time * 1000 / (steps - 1) + print(f"Time per step: {avg_time:.3f}ms") + + infer_task._kv_cache.drop(self) + return output_content, avg_time + + def perplexity(self, test_sequences: List[Sequence[int]], batch_size=10): + tasks = [ + InferTask(i, [], self.max_context_len(), 1.0, 1, 1.0, self.eos_token_id) + for i in range(batch_size) + ] + kv_caches = [KVCache(self) for _ in range(batch_size)] + + nll = 0.0 + total_len = 0 + + for i in range(0, len(test_sequences), batch_size): + batch_id = 0 + true_tokens = [] + while batch_id < batch_size and batch_id + i < len(test_sequences): + input_tokens = test_sequences[i + batch_id][:-1] + true_tokens.extend(test_sequences[i + batch_id][1:]) + tasks[batch_id].tokens = input_tokens + tasks[batch_id].bind_kvcache(kv_caches[batch_id]) + batch_id += 1 + + batch_inputs = JiugeAWQBatchedTask(tasks[:batch_id]) + logits = torch.zeros( + (batch_inputs.ntok, self.meta.dvoc), dtype=self.meta.torch_dtype_logits + ) + self.jiuge_awq_model.forward_batch( + self.model_instance, + batch_inputs.tokens, + batch_inputs.ntok, + batch_inputs.req_lens, + batch_inputs.nreq, + batch_inputs.req_pos, + batch_inputs.kv_caches, + logits.data_ptr(), + ) + + logits = logits.float() + token_ids = torch.tensor(true_tokens, dtype=torch.int64) # [ntok,] + log_probs = torch.nn.functional.log_softmax(logits, dim=-1) # (ntok, vocab) + token_logprobs = log_probs[ + torch.arange(batch_inputs.ntok), token_ids + ] # (ntok,) + + start = 0 + for l in batch_inputs.req_lens_list: + nll += -token_logprobs[start : start + l].sum().item() + start += l + total_len += token_logprobs.numel() + + for task in tasks: + task.release_kvcache() + + return math.exp(nll / total_len) + + def destroy_model_instance(self): + self.jiuge_awq_model.destroy_model(self.model_instance) + print("Model destroyed") + + +def test(): + if len(sys.argv) < 3: + print( + "Usage: python jiuge_awq.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + ) + sys.exit(1) + model_path = sys.argv[2] + device_type = DeviceType.DEVICE_TYPE_CPU + if sys.argv[1] == "--cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif sys.argv[1] == "--nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif sys.argv[1] == "--cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif sys.argv[1] == "--ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif sys.argv[1] == "--metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif sys.argv[1] == "--moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif sys.argv[1] == "--iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + print( + "Usage: python main_jiuge_awq.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + ) + sys.exit(1) + + ndev = int(sys.argv[3]) if len(sys.argv) > 3 else 1 + model = JiugeAWQForCausalLM(model_path, device_type, ndev) + model.generate("山东最高的山是?", 500) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() diff --git a/python/icinfer/models/libinfinicore_infer/__init__.py b/python/icinfer/models/libinfinicore_infer/__init__.py new file mode 100644 index 00000000..8fc5f4db --- /dev/null +++ b/python/icinfer/models/libinfinicore_infer/__init__.py @@ -0,0 +1,27 @@ +from .base import DataType, DeviceType, KVCacheCStruct +from .jiuge import 
JiugeModel, JiugeMetaCStruct, JiugeWeightsCStruct +from .jiuge_awq import JiugeAWQModel, JiugeAWQMetaCStruct, ModelWeightsCStruct +from .deepseek_v3 import ( + DeepSeekV3Model, + DeepSeekV3MetaCStruct, + DeepSeekV3WeightsCStruct, + DeepSeekV3WeightLoaderCStruct, + DeepSeekV3CacheCStruct, +) + +__all__ = [ + "DataType", + "DeviceType", + "KVCacheCStruct", + "JiugeModel", + "JiugeMetaCStruct", + "JiugeWeightsCStruct", + "JiugeAWQModel", + "JiugeAWQMetaCStruct", + "ModelWeightsCStruct", + "DeepSeekV3Model", + "DeepSeekV3MetaCStruct", + "DeepSeekV3WeightsCStruct", + "DeepSeekV3WeightLoaderCStruct", + "ModelRegister", +] diff --git a/python/icinfer/models/libinfinicore_infer/base.py b/python/icinfer/models/libinfinicore_infer/base.py new file mode 100644 index 00000000..bac58e3c --- /dev/null +++ b/python/icinfer/models/libinfinicore_infer/base.py @@ -0,0 +1,69 @@ +import ctypes +from ctypes import c_char, c_char_p, c_size_t, c_uint, c_int, c_float, c_void_p, POINTER +import os + + +class DataType(ctypes.c_int): + INFINI_DTYPE_INVALID = 0 + INFINI_DTYPE_BYTE = 1 + INFINI_DTYPE_BOOL = 2 + INFINI_DTYPE_I8 = 3 + INFINI_DTYPE_I16 = 4 + INFINI_DTYPE_I32 = 5 + INFINI_DTYPE_I64 = 6 + INFINI_DTYPE_U8 = 7 + INFINI_DTYPE_U16 = 8 + INFINI_DTYPE_U32 = 9 + INFINI_DTYPE_U64 = 10 + INFINI_DTYPE_F8 = 11 + INFINI_DTYPE_F16 = 12 + INFINI_DTYPE_F32 = 13 + INFINI_DTYPE_F64 = 14 + INFINI_DTYPE_C16 = 15 + INFINI_DTYPE_C32 = 16 + INFINI_DTYPE_C64 = 17 + INFINI_DTYPE_C128 = 18 + INFINI_DTYPE_BF16 = 19 + + +class DeviceType(ctypes.c_int): + DEVICE_TYPE_CPU = 0 + DEVICE_TYPE_NVIDIA = 1 + DEVICE_TYPE_CAMBRICON = 2 + DEVICE_TYPE_ASCEND = 3 + DEVICE_TYPE_METAX = 4 + DEVICE_TYPE_MOORE = 5 + DEVICE_TYPE_ILUVATAR = 6 + DEVICE_TYPE_KUNLUN = 7 + + +class KVCacheCStruct(ctypes.Structure): + pass + + +# Model registration system +_model_registry = [] + + +def register_model(model_class): + """Decorator to register a model class""" + _model_registry.append(model_class) + return model_class + + +def register_lib_functions(lib): + """Register all model functions with the library""" + for model_class in _model_registry: + model_class.register_lib(lib) + + +class BaseModel: + def __init__(self): + self.lib = self._load_library() + register_lib_functions(self.lib) + + def _load_library(self): + lib_path = os.path.join( + os.environ.get("INFINI_ROOT"), "lib", "libinfinicore_infer.so" + ) + return ctypes.CDLL(lib_path) diff --git a/python/icinfer/models/libinfinicore_infer/deepseek_v3.py b/python/icinfer/models/libinfinicore_infer/deepseek_v3.py new file mode 100644 index 00000000..b2c380b7 --- /dev/null +++ b/python/icinfer/models/libinfinicore_infer/deepseek_v3.py @@ -0,0 +1,209 @@ +from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model +from ctypes import ( + c_size_t, + c_uint, + c_int, + c_float, + c_void_p, + POINTER, + Structure, + CFUNCTYPE, +) + + +class DeepSeekV3MetaCStruct(Structure): + _fields_ = [ + ("dt_logits", DataType), + ("dt_norm", DataType), + ("dt_quant_weight", DataType), + ("dt_quant_scale", DataType), + ("dt_quant_zero", DataType), + ("dt_gate_weight", DataType), + ("dt_gate_bias", DataType), + ("n_sparse_layer", c_size_t), + ("n_dense_layer", c_size_t), + ("d", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("d_rope", c_size_t), + ("d_nope", c_size_t), + ("r_q", c_size_t), + ("r_kv", c_size_t), + ("d_qk", c_size_t), + ("d_v", c_size_t), + ("routed_scale", c_float), + ("nexperts", c_size_t), + ("kexperts", c_size_t), + ("di", c_size_t), + ("di_moe", c_size_t), + ("dctx", 
c_size_t), + ("dvoc", c_size_t), + ("epsilon", c_float), + ("rope_theta", c_float), + ("end_token", c_uint), + ] + + +class DeepSeekV3WeightsCStruct(Structure): + pass + + +class DeepSeekV3ModelCStruct(Structure): + pass + + +class DeepSeekV3CacheCStruct(Structure): + pass + + +load_global_fn = CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p) +load_layer_fn = CFUNCTYPE(None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_size_t) +load_layer_linear_fn = CFUNCTYPE( + None, POINTER(DeepSeekV3WeightsCStruct), c_void_p, c_void_p, c_void_p, c_size_t +) +load_layer_mlp_fn = CFUNCTYPE( + None, + POINTER(DeepSeekV3WeightsCStruct), + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_size_t, +) +load_layer_expert_mlp_fn = CFUNCTYPE( + None, + POINTER(DeepSeekV3WeightsCStruct), + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_void_p, + c_size_t, + c_size_t, +) + + +class DeepSeekV3WeightLoaderCStruct(Structure): + _fields_ = [ + ("load_input_embd", load_global_fn), + ("load_output_norm", load_global_fn), + ("load_output_embd", load_global_fn), + ("load_attn_norm", load_layer_fn), + ("load_attn_q_a_proj", load_layer_linear_fn), + ("load_attn_q_a_layernorm", load_layer_fn), + ("load_attn_q_b_proj", load_layer_linear_fn), + ("load_attn_kv_a_proj_with_mqa", load_layer_linear_fn), + ("load_attn_kv_a_layernorm", load_layer_fn), + ("load_attn_kv_b_proj", load_layer_linear_fn), + ("load_attn_o_proj", load_layer_linear_fn), + ("load_mlp_norm", load_layer_fn), + ("load_mlp_dense", load_layer_mlp_fn), + ("load_mlp_gate_weight", load_layer_fn), + ("load_mlp_gate_bias", load_layer_fn), + ("load_mlp_shared_experts", load_layer_mlp_fn), + ("load_mlp_experts", load_layer_expert_mlp_fn), + ] + + +@register_model +class DeepSeekV3Model(BaseModel): + @classmethod + def register_lib(cls, lib): + """Register DeepSeekV3 model functions with the library""" + lib.createDeepSeekV3WeightLoader.argtypes = [] + lib.createDeepSeekV3WeightLoader.restype = POINTER( + DeepSeekV3WeightLoaderCStruct + ) + + lib.createDeepSeekV3Weights.argtypes = [ + POINTER(DeepSeekV3MetaCStruct), + DeviceType, + c_int, + POINTER(c_int), + ] + lib.createDeepSeekV3Weights.restype = POINTER(DeepSeekV3WeightsCStruct) + + lib.createDeepSeekV3Model.argtypes = [ + POINTER(DeepSeekV3MetaCStruct), + POINTER(DeepSeekV3WeightsCStruct), + ] + lib.createDeepSeekV3Model.restype = POINTER(DeepSeekV3ModelCStruct) + + lib.destroyDeepSeekV3Model.argtypes = [POINTER(DeepSeekV3ModelCStruct)] + + lib.createDeepSeekV3Cache.argtypes = [POINTER(DeepSeekV3ModelCStruct)] + lib.createDeepSeekV3Cache.restype = POINTER(DeepSeekV3CacheCStruct) + + lib.dropDeepSeekV3Cache.argtypes = [ + POINTER(DeepSeekV3ModelCStruct), + POINTER(DeepSeekV3CacheCStruct), + ] + + lib.inferBatchDeepSeekV3.argtypes = [ + POINTER(DeepSeekV3ModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(DeepSeekV3CacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + def create_weight_loader(self): + return self.lib.createDeepSeekV3WeightLoader() + + def create_weights(self, meta, device_type, ndev, dev_ids): + return self.lib.createDeepSeekV3Weights(meta, device_type, ndev, dev_ids) + + def create_model(self, meta, weights): + return self.lib.createDeepSeekV3Model(meta, weights) + + def destroy_model(self, model): + self.lib.destroyDeepSeekV3Model(model) + + def create_cache(self, model): + 
return self.lib.createDeepSeekV3Cache(model) + + def drop_cache(self, model, cache): + self.lib.dropDeepSeekV3Cache(model, cache) + + def infer_batch( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + caches, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchDeepSeekV3( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + caches, + temperature, + topk, + topp, + output, + ) diff --git a/python/icinfer/models/libinfinicore_infer/jiuge.py b/python/icinfer/models/libinfinicore_infer/jiuge.py new file mode 100644 index 00000000..f05db15f --- /dev/null +++ b/python/icinfer/models/libinfinicore_infer/jiuge.py @@ -0,0 +1,211 @@ +from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model +from ctypes import c_size_t, c_uint, c_int, c_float, c_void_p, POINTER, Structure, byref, c_bool + + +class JiugeMetaCStruct(Structure): + _fields_ = [ + ("dt_logits", DataType), + ("nlayer", c_size_t), + ("d", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("dctx", c_size_t), + ("dvoc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_uint), + ] + + +class JiugeWeightsCStruct(Structure): + _fields_ = [ + ("nlayer", c_size_t), + ("dt_norm", DataType), + ("dt_mat", DataType), + ("transpose_linear_weights", c_int), + ("input_embd", c_void_p), + ("output_norm", c_void_p), + ("output_embd", c_void_p), + ("attn_norm", POINTER(c_void_p)), + ("attn_qkv", POINTER(c_void_p)), + ("attn_qkv_b", POINTER(c_void_p)), + ("attn_o", POINTER(c_void_p)), + ("ffn_norm", POINTER(c_void_p)), + ("ffn_gate_up", POINTER(c_void_p)), + ("ffn_down", POINTER(c_void_p)), + ] + + +class JiugeModelCStruct(Structure): + pass + + +@register_model +class JiugeModel(BaseModel): + @classmethod + def register_lib(cls, lib): + lib.createJiugeModel.restype = POINTER(JiugeModelCStruct) + lib.createJiugeModel.argtypes = [ + POINTER(JiugeMetaCStruct), + POINTER(JiugeWeightsCStruct), + DeviceType, + c_int, + POINTER(c_int), + ] + + lib.destroyJiugeModel.argtypes = [POINTER(JiugeModelCStruct)] + + lib.createKVCache.argtypes = [ + c_size_t, + c_size_t, + c_size_t, + c_size_t, + c_size_t, + DataType, + DeviceType, + POINTER(c_int), + c_size_t, + ] + lib.createKVCache.restype = POINTER(KVCacheCStruct) + + lib.createPagedKVCache.argtypes = [ + c_size_t, + c_size_t, + c_size_t, + c_size_t, + DataType, + DeviceType, + POINTER(c_int), + c_size_t, + c_size_t, + c_size_t, + ] + lib.createPagedKVCache.restype = POINTER(KVCacheCStruct) + + lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)] + + # lib.inferBatchJiuge.argtypes = [ + # POINTER(JiugeModelCStruct), + # POINTER(c_uint), + # c_uint, + # POINTER(c_uint), + # c_uint, + # POINTER(c_uint), + # POINTER(POINTER(KVCacheCStruct)), + # POINTER(c_float), + # POINTER(c_uint), + # POINTER(c_float), + # POINTER(c_uint), + # ] + + # lib.forwardBatchJiuge.argtypes = [ + # POINTER(JiugeModelCStruct), + # POINTER(c_uint), + # c_uint, + # POINTER(c_uint), + # c_uint, + # POINTER(c_uint), + # POINTER(POINTER(KVCacheCStruct)), + # c_void_p, + # ] + + lib.inferBatchJiuge.argtypes = [ + POINTER(JiugeModelCStruct), # struct JiugeModel const * + POINTER(c_uint), # unsigned int const *tokens + c_uint, # unsigned int ntok + POINTER(c_uint), # unsigned int const *req_lens + c_uint, # unsigned int nreq + POINTER(c_uint), # unsigned int const *req_pos + POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches + POINTER(c_int), # unsigned int const *block_tables + POINTER(c_int), # unsigned int 
const *slot_mapping + POINTER(c_float), # float temperature + POINTER(c_uint), # unsigned int topk + POINTER(c_float), # float topp + c_uint, # unsigned int is_prefill + c_bool, # bool enable_paged_attn + POINTER(c_uint), # unsigned int *output + ] + lib.forwardBatchJiuge.argtypes = [ + POINTER(JiugeModelCStruct), # struct JiugeModel const * + POINTER(c_uint), # unsigned int const *tokens + c_uint, # unsigned int ntok + POINTER(c_uint), # unsigned int const *req_lens + c_uint, # unsigned int nreq + POINTER(c_uint), # unsigned int const *req_pos + POINTER(POINTER(KVCacheCStruct)), # struct KVCache **kv_caches + POINTER(c_int), # unsigned int const *block_tables + POINTER(c_int), # unsigned int const *slot_mapping + c_uint, # unsigned int is_prefill + c_bool, # bool enable_paged_attn + c_void_p, # void *logits + ] + + def create_model(self, meta, weights, device_type, ndev, dev_ids): + return self.lib.createJiugeModel(meta, weights, device_type, ndev, dev_ids) + + def destroy_model(self, model): + self.lib.destroyJiugeModel(model) + + def create_kv_cache( + self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ): + return self.lib.createKVCache( + nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ) + + def create_paged_kv_cache( + self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev, max_kvcache_tokens + ): + return self.lib.createPagedKVCache( + nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev, max_kvcache_tokens + ) + + def drop_kv_cache(self, kv_cache): + self.lib.dropKVCache(kv_cache) + + def infer_batch( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + block_tables, + slot_mapping, + temperature, + topk, + topp, + is_prefill, + enable_paged_attn, + output, + ): + self.lib.inferBatchJiuge( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + block_tables, + slot_mapping, + temperature, + topk, + topp, + is_prefill, + enable_paged_attn, + output, + ) + + def forward_batch( + self, model, tokens, ntok, req_lens, nreq, req_pos, + kv_caches, block_tables, slot_mapping, is_prefill, enable_paged_attn, logits + ): + self.lib.forwardBatchJiuge( + model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, block_tables, slot_mapping, is_prefill, enable_paged_attn, logits + ) diff --git a/python/icinfer/models/libinfinicore_infer/jiuge_awq.py b/python/icinfer/models/libinfinicore_infer/jiuge_awq.py new file mode 100644 index 00000000..2f47ca8c --- /dev/null +++ b/python/icinfer/models/libinfinicore_infer/jiuge_awq.py @@ -0,0 +1,167 @@ +from .base import BaseModel, DataType, DeviceType, KVCacheCStruct, register_model +from ctypes import ( + c_size_t, + c_uint, + c_int, + c_float, + c_void_p, + POINTER, + Structure, + c_char, + c_char_p, +) + + +class JiugeAWQMetaCStruct(Structure): + _fields_ = [ + ("dt_logits", DataType), + ("dt_linear_w", DataType), + ("dt_norm_w", DataType), + ("nlayer", c_size_t), + ("d", c_size_t), + ("nh", c_size_t), + ("nkvh", c_size_t), + ("dh", c_size_t), + ("di", c_size_t), + ("dctx", c_size_t), + ("dvoc", c_size_t), + ("epsilon", c_float), + ("theta", c_float), + ("end_token", c_uint), + ("nbit", c_size_t), + ("quant_group_size", c_size_t), + ("has_qkv_bias", c_char), + ] + + +class ModelWeightsCStruct(Structure): + pass + + +class JiugeAWQModelCStruct(Structure): + pass + + +@register_model +class JiugeAWQModel(BaseModel): + @classmethod + def register_lib(cls, lib): + """Register JiugeAWQ model functions with the library""" + lib.createJiugeAWQWeights.restype = 
POINTER(ModelWeightsCStruct) + lib.createJiugeAWQWeights.argtypes = [ + POINTER(JiugeAWQMetaCStruct), + DeviceType, + c_int, + POINTER(c_int), + ] + + lib.createJiugeAWQModel.restype = POINTER(JiugeAWQModelCStruct) + lib.createJiugeAWQModel.argtypes = [ + POINTER(JiugeAWQMetaCStruct), + POINTER(ModelWeightsCStruct), + ] + + lib.destroyJiugeAWQModel.argtypes = [POINTER(JiugeAWQModelCStruct)] + + lib.createKVCache.argtypes = [ + c_size_t, + c_size_t, + c_size_t, + c_size_t, + c_size_t, + DataType, + DeviceType, + POINTER(c_int), + c_size_t, + ] + lib.createKVCache.restype = POINTER(KVCacheCStruct) + + lib.dropKVCache.argtypes = [POINTER(KVCacheCStruct)] + + lib.inferBatchJiugeAWQ.argtypes = [ + POINTER(JiugeAWQModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + POINTER(c_float), + POINTER(c_uint), + POINTER(c_float), + POINTER(c_uint), + ] + + lib.forwardBatchJiugeAWQ.argtypes = [ + POINTER(JiugeAWQModelCStruct), + POINTER(c_uint), + c_uint, + POINTER(c_uint), + c_uint, + POINTER(c_uint), + POINTER(POINTER(KVCacheCStruct)), + c_void_p, + ] + + lib.loadModelWeight.argtypes = [ + POINTER(ModelWeightsCStruct), + c_char_p, + c_void_p, + ] + + def create_weights(self, meta, device_type, ndev, dev_ids): + return self.lib.createJiugeAWQWeights(meta, device_type, ndev, dev_ids) + + def create_model(self, meta, weights): + return self.lib.createJiugeAWQModel(meta, weights) + + def destroy_model(self, model): + self.lib.destroyJiugeAWQModel(model) + + def create_kv_cache( + self, nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ): + return self.lib.createKVCache( + nlayer, max_len, nkvh, dk, dv, dtype, device, dev_ids, ndev + ) + + def drop_kv_cache(self, kv_cache): + self.lib.dropKVCache(kv_cache) + + def load_weight(self, weights, name, data): + self.lib.loadModelWeight(weights, name.encode("utf-8"), data) + + def infer_batch( + self, + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ): + self.lib.inferBatchJiugeAWQ( + model, + tokens, + ntok, + req_lens, + nreq, + req_pos, + kv_caches, + temperature, + topk, + topp, + output, + ) + + def forward_batch( + self, model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits + ): + self.lib.forwardBatchJiugeAWQ( + model, tokens, ntok, req_lens, nreq, req_pos, kv_caches, logits + ) diff --git a/python/icinfer/sampling_params.py b/python/icinfer/sampling_params.py new file mode 100644 index 00000000..38733f2e --- /dev/null +++ b/python/icinfer/sampling_params.py @@ -0,0 +1,10 @@ +from dataclasses import dataclass + + +@dataclass +class SamplingParams: + temperature: float = 1.0 + topp: float = 1.0 + topk: int = 1 + max_tokens: int = 64 + ignore_eos: bool = False diff --git a/python/icinfer/utils/context.py b/python/icinfer/utils/context.py new file mode 100644 index 00000000..2281888f --- /dev/null +++ b/python/icinfer/utils/context.py @@ -0,0 +1,27 @@ +from dataclasses import dataclass +import torch + + +@dataclass +class Context: + is_prefill: bool = False + cu_seqlens_q: torch.Tensor | None = None + cu_seqlens_k: torch.Tensor | None = None + max_seqlen_q: int = 0 + max_seqlen_k: int = 0 + slot_mapping: torch.Tensor | None = None + context_lens: torch.Tensor | None = None + block_tables: torch.Tensor | None = None + +_CONTEXT = Context() + +def get_context(): + return _CONTEXT + +def set_context(is_prefill, cu_seqlens_q=None, cu_seqlens_k=None, max_seqlen_q=0, max_seqlen_k=0, slot_mapping=None, 
context_lens=None, block_tables=None): + global _CONTEXT + _CONTEXT = Context(is_prefill, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, slot_mapping, context_lens, block_tables) + +def reset_context(): + global _CONTEXT + _CONTEXT = Context() diff --git a/python/pyproject.toml b/python/pyproject.toml new file mode 100644 index 00000000..5dad46e6 --- /dev/null +++ b/python/pyproject.toml @@ -0,0 +1,26 @@ +[build-system] +requires = ["setuptools>=61"] +build-backend = "setuptools.build_meta" + +[project] +name = "icinfer" +version = "0.1.0" +authors = [{ name = "" }] +license = "MIT" +license-files = ["LICENSE"] +readme = "README.md" +description = "a lightweight, hardware-agnostic, unified inference engine implementation built from scratch, based on InfiniCore" +requires-python = ">=3.10,<3.13" +dependencies = [ + "torch>=2.4.0", + "triton>=3.0.0", + "transformers>=4.51.0", + "xxhash", +] + +[project.urls] +Homepage="https://github.com/InfiniTensor/InfiniLM" + +[tool.setuptools.packages.find] +where = ["."] +include = ["icinfer*"] diff --git a/python/tests/test_attention.py b/python/tests/test_attention.py new file mode 100644 index 00000000..16590ab3 --- /dev/null +++ b/python/tests/test_attention.py @@ -0,0 +1,522 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +import random +from typing import Optional + +import pytest +import torch + +from tests.kernels.allclose_default import get_default_atol, get_default_rtol +from tests.kernels.utils import opcheck +from vllm import _custom_ops as ops +from vllm.attention.layer import Attention, MultiHeadAttention +from vllm.platforms import current_platform +from vllm.utils import get_max_shared_memory_bytes + +if not current_platform.is_rocm(): + from xformers import ops as xops + from xformers.ops.fmha.attn_bias import BlockDiagonalCausalMask + + from vllm.attention.backends.xformers import _make_alibi_bias + +FLOAT32_BYTES = torch.finfo(torch.float).bits // 8 +# This will change depending on the compute capability. +# - 512 as a buffer +MAX_SEQ_LEN = get_max_shared_memory_bytes() // FLOAT32_BYTES - 512 +# There may not be enough gpu memory due to large NUM_BLOCKS. +# Reduce NUM_BLOCKS when it happens. 
+NUM_BLOCKS = 4321 # Arbitrary values for testing +PARTITION_SIZE = 512 +PARTITION_SIZE_ROCM = 256 +DTYPES = [torch.bfloat16] +NUM_GEN_SEQS = [7] # Arbitrary values for testing +NUM_PREFILL_SEQS = [3] # Arbitrary values for testing +NUM_HEADS = [(40, 40), (64, 8)] # Arbitrary values for testing + +# This should be sync with get_supported_head_sizes() in +# vllm.attention.ops.paged_attn.PagedAttention +HEAD_SIZES = [32, 80, 128, 256] + +BLOCK_SIZES = [16, 32] +USE_ALIBI = [False, True] +KV_CACHE_DTYPE = ["auto", "fp8"] +SEEDS = [0] +CUDA_DEVICES = [ + f"cuda:{i}" for i in range(1 if torch.cuda.device_count() == 1 else 2) +] + + +def ref_masked_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + attn_mask: Optional[torch.Tensor] = None, +) -> torch.Tensor: + attn_weights = scale * torch.einsum("qhd,khd->hqk", query, key).float() + if attn_mask is not None: + attn_weights = attn_weights + attn_mask.float() + attn_weights = torch.softmax(attn_weights, dim=-1).to(value.dtype) + out = torch.einsum("hqk,khd->qhd", attn_weights, value) + return out + + +def ref_single_query_cached_kv_attention( + output: torch.Tensor, + query: torch.Tensor, + num_queries_per_kv: int, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + seq_lens: torch.Tensor, + scale: float, + alibi_slopes: Optional[torch.Tensor], +) -> None: + num_query_heads = query.shape[1] + num_kv_heads = value_cache.shape[1] + head_size = value_cache.shape[2] + block_size = value_cache.shape[3] + num_seqs = query.shape[0] + + block_tables_lst = block_tables.cpu().tolist() + seq_lens_lst = seq_lens.cpu().tolist() + for i in range(num_seqs): + q = query[i].unsqueeze(0) + block_table = block_tables_lst[i] + seq_len = int(seq_lens_lst[i]) + + keys_lst: list[torch.Tensor] = [] + values_lst: list[torch.Tensor] = [] + for j in range(seq_len): + block_number = int(block_table[j // block_size]) + block_offset = j % block_size + + k = key_cache[block_number, :, :, block_offset, :] + k = k.reshape(num_kv_heads, head_size) + keys_lst.append(k) + + v = value_cache[block_number, :, :, block_offset] + values_lst.append(v) + keys = torch.stack(keys_lst, dim=0) + values = torch.stack(values_lst, dim=0) + if num_queries_per_kv > 1: + # Handle MQA and GQA + keys = torch.repeat_interleave(keys, num_queries_per_kv, dim=1) + values = torch.repeat_interleave(values, num_queries_per_kv, dim=1) + + alibi_bias = None + if alibi_slopes is not None: + # Create the ALiBi bias used in the paged attention kernel. 
+ position_ids = torch.arange(seq_len).int() + alibi_bias = (position_ids - seq_len + 1).float() + alibi_bias = alibi_slopes.view(-1, 1, 1) * alibi_bias.view( + 1, 1, -1) + + out = ref_masked_attention(q, keys, values, scale, alibi_bias) + out = out.view(num_query_heads, head_size) + output[i].copy_(out, non_blocking=True) + + +@pytest.mark.parametrize( + "version", + ["v1", "v2"] if not current_platform.is_rocm() else ["v1", "v2", "rocm"]) +@pytest.mark.parametrize("num_seqs", NUM_GEN_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("use_alibi", USE_ALIBI) +@pytest.mark.parametrize("block_size", BLOCK_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPE) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +def test_paged_attention( + kv_cache_factory, + version: str, + num_seqs: int, + num_heads: tuple[int, int], + head_size: int, + use_alibi: bool, + block_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + seed: int, + device: str, +) -> None: + if ((kv_cache_dtype == "fp8" and head_size % 16) + or (version == "rocm" and head_size not in (64, 128))): + pytest.skip() + + if (version == "rocm" and current_platform.is_navi() + and (kv_cache_dtype == "fp8" or head_size != 128 + or block_size != 16 or use_alibi)): + pytest.skip() + + global PARTITION_SIZE + + current_platform.seed_everything(seed) + torch.set_default_device(device) + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + query = torch.empty(num_seqs, num_query_heads, head_size, dtype=dtype) + query.uniform_(-scale, scale) + + assert num_query_heads % num_kv_heads == 0 + num_queries_per_kv = num_query_heads // num_kv_heads + alibi_slopes = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + + seq_lens = [random.randint(1, MAX_SEQ_LEN) for _ in range(num_seqs)] + seq_lens[-1] = MAX_SEQ_LEN + max_seq_len = max(seq_lens) + seq_lens = torch.tensor(seq_lens, dtype=torch.int) + + # Create the block tables. + max_num_blocks_per_seq = (max_seq_len + block_size - 1) // block_size + block_tables_lst: list[list[int]] = [] + for _ in range(num_seqs): + block_table = [ + random.randint(0, NUM_BLOCKS - 1) + for _ in range(max_num_blocks_per_seq) + ] + block_tables_lst.append(block_table) + + block_tables = torch.tensor(block_tables_lst, dtype=torch.int) + + # Create the KV caches. + key_caches, value_caches = kv_cache_factory(NUM_BLOCKS, block_size, 1, + num_kv_heads, head_size, + kv_cache_dtype, dtype, seed, + device) + key_cache, value_cache = key_caches[0], value_caches[0] + + # Using default kv_scale + k_scale = v_scale = torch.tensor(1.0, dtype=torch.float32, device=device) + + # Call the paged attention kernel. 
+ output = torch.empty_like(query) + if version == "v1": + ops.paged_attention( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._C.paged_attention, + (output, query, key_cache, value_cache, num_kv_heads, scale, + block_tables, seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) + + elif version in ("v2", "rocm"): + if current_platform.is_rocm() and version == "rocm": + PARTITION_SIZE = PARTITION_SIZE_ROCM + + num_partitions = ((max_seq_len + PARTITION_SIZE - 1) // PARTITION_SIZE) + assert PARTITION_SIZE % block_size == 0 + num_seqs, num_heads, head_size = output.shape + tmp_output = torch.empty( + size=(num_seqs, num_heads, num_partitions, head_size), + dtype=output.dtype, + ) + exp_sums = torch.empty( + size=(num_seqs, num_heads, num_partitions), + dtype=torch.float32, + ) + max_logits = torch.empty_like(exp_sums) + if version == "v2": + ops.paged_attention_v2( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._C.paged_attention_v2, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale, 0, 0, 0, 64, 0), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) + + else: + ops.paged_attention_rocm( + output, + exp_sums, + max_logits, + tmp_output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + seq_lens, + None, + block_size, + max_seq_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + ) + + opcheck(torch.ops._rocm_C.paged_attention, + (output, exp_sums, max_logits, tmp_output, query, + key_cache, value_cache, num_kv_heads, scale, block_tables, + seq_lens, None, block_size, max_seq_len, alibi_slopes, + kv_cache_dtype, k_scale, v_scale), + cond=(head_size == HEAD_SIZES[0] + and block_size == BLOCK_SIZES[0])) + + else: + raise AssertionError(f"Unknown version: {version}") + + # Run the reference implementation. + if kv_cache_dtype == "fp8": + # Convert cache data back to dtype. + x = 16 // torch.tensor([], dtype=dtype).element_size() + key_cache_shape = (NUM_BLOCKS, num_kv_heads, head_size // x, + block_size, x) + dequantized_key_cache = torch.empty(size=key_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_key_cache, key_cache) + key_cache = dequantized_key_cache + + value_cache_shape = value_cache.shape + dequantized_value_cache = torch.empty(size=value_cache_shape, + dtype=dtype, + device=device) + ops.convert_fp8(dequantized_value_cache, value_cache) + value_cache = dequantized_value_cache + + ref_output = torch.empty_like(query) + ref_single_query_cached_kv_attention( + ref_output, + query, + num_queries_per_kv, + key_cache, + value_cache, + block_tables, + seq_lens, + scale, + alibi_slopes, + ) + + # NOTE(woosuk): Due to the kernel-level differences in the two + # implementations, there is a small numerical difference in the two + # outputs. Thus, we use a relaxed tolerance for the test. 
+ atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 + rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 + + # NOTE(zhaoyang): FP8 KV Cache will introduce quantization error, + # so we use a relaxed tolerance for the test. + atol, rtol = 1e-3, 1e-5 + if kv_cache_dtype == "fp8": + atol, rtol = 1e-2, 1e-5 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + + +def ref_multi_query_kv_attention( + cu_seq_lens: list[int], + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + scale: float, + alibi_bias: Optional[list[torch.Tensor]], + dtype: torch.dtype, +) -> torch.Tensor: + num_seqs = len(cu_seq_lens) - 1 + ref_outputs: list[torch.Tensor] = [] + if alibi_bias: + assert len(alibi_bias) == num_seqs + for i in range(num_seqs): + start_idx = cu_seq_lens[i] + end_idx = cu_seq_lens[i + 1] + seq_len = end_idx - start_idx + + # Create attention mask. ALiBi already includes a tril causal mask. + if alibi_bias: + attn_mask = alibi_bias[i] + else: + attn_mask = torch.triu(torch.ones(seq_len, seq_len, dtype=dtype), + diagonal=1) + attn_mask = attn_mask * torch.finfo(dtype).min + attn_mask = attn_mask.to(dtype=dtype) + + ref_output = ref_masked_attention( + query[start_idx:end_idx], + key[start_idx:end_idx], + value[start_idx:end_idx], + scale, + attn_mask=attn_mask, + ) + ref_outputs.append(ref_output) + + return torch.cat(ref_outputs, dim=0) + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +@torch.inference_mode() +def test_multi_query_kv_attention( + num_seqs: int, + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + seed: int, + device: str, + use_alibi: bool = False, +) -> None: + current_platform.seed_everything(seed) + torch.set_default_device(device) + # MAX_SEQ_LEN sometimes causes OOM in the reference implementation. + # As the xformers library is already tested with its own tests, we can use + # a smaller MAX_SEQ_LEN here. + max_len = min(MAX_SEQ_LEN, 4096) + seq_lens = random.sample(range(1, max_len), num_seqs) + num_tokens = sum(seq_lens) + + scale = float(1.0 / (head_size**0.5)) + num_query_heads, num_kv_heads = num_heads + qkv = torch.empty(num_tokens, + num_query_heads + 2 * num_kv_heads, + head_size, + dtype=dtype) + qkv.uniform_(-scale, scale) + query, key, value = qkv.split( + [num_query_heads, num_kv_heads, num_kv_heads], dim=1) + + num_queries_per_kv = num_query_heads // num_kv_heads + if num_queries_per_kv > 1: + # Handle MQA and GQA + key = torch.repeat_interleave(key, num_queries_per_kv, dim=1) + value = torch.repeat_interleave(value, num_queries_per_kv, dim=1) + alibi_bias = None + if use_alibi: + alibi_slopes = torch.randn(num_query_heads, dtype=torch.float) + attn_bias = _make_alibi_bias(alibi_slopes, num_kv_heads, dtype, + seq_lens) + output = torch.empty_like(query) + start = 0 + # Dynamic sequence length not supported with custom attn_bias. 
+ for i, seq_len in enumerate(seq_lens): + end = start + seq_len + out = xops.memory_efficient_attention_forward( + query[None, start:end], + key[None, start:end], + value[None, start:end], + attn_bias=attn_bias[i], + p=0.0, + scale=scale) + output[start:end].copy_(out.view_as(query[start:end])) + start += seq_len + # xformers.AttentionBias to Tensor for use in reference impl. + alibi_bias = [ + b.materialize((1, num_query_heads, i, i), device=device).squeeze() + for b, i in zip(attn_bias, seq_lens) + ] + else: + attn_bias = BlockDiagonalCausalMask.from_seqlens(seq_lens) + output = xops.memory_efficient_attention_forward( + query.unsqueeze(0), + key.unsqueeze(0), + value.unsqueeze(0), + attn_bias=attn_bias, + p=0.0, + scale=scale, + ) + output = output.squeeze(0) + + cu_seq_lens = [0] + for seq_len in seq_lens: + cu_seq_lens.append(cu_seq_lens[-1] + seq_len) + ref_output = ref_multi_query_kv_attention( + cu_seq_lens, + query, + key, + value, + scale, + alibi_bias, + dtype, + ) + atol = get_default_atol(output) if current_platform.is_rocm() else 1e-3 + rtol = get_default_rtol(output) if current_platform.is_rocm() else 1e-5 + torch.testing.assert_close(output, ref_output, atol=atol, rtol=rtol) + + +@pytest.mark.parametrize("num_seqs", NUM_PREFILL_SEQS) +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("head_size", [64]) +@pytest.mark.parametrize("dtype", DTYPES) +@pytest.mark.parametrize("seed", SEEDS) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.skipif(current_platform.is_rocm(), + reason="Xformers backend is not supported on ROCm.") +@torch.inference_mode() +def test_multi_query_kv_attention_with_alibi( + num_seqs: int, + num_heads: tuple[int, int], + head_size: int, + dtype: torch.dtype, + seed: int, + device: str, +) -> None: + return test_multi_query_kv_attention( + num_seqs, + num_heads, + head_size, + dtype, + seed, + device, + use_alibi=True, + ) + + +@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention]) +def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None: + head_size = 64 + scale = float(1.0 / (head_size**0.5)) + num_heads = 16 + num_kv_heads = 5 + with pytest.raises(AssertionError): + _ = attention_cls( + num_heads=num_heads, + head_size=head_size, + scale=scale, + num_kv_heads=num_kv_heads, + ) diff --git a/scripts/jiuge.py b/scripts/jiuge.py index 6c5b07ad..d716901e 100644 --- a/scripts/jiuge.py +++ b/scripts/jiuge.py @@ -690,4 +690,4 @@ def test(): if __name__ == "__main__": - test() + test() \ No newline at end of file diff --git a/scripts/jiuge_ppl.py b/scripts/jiuge_ppl.py index 67dc2326..f836871d 100644 --- a/scripts/jiuge_ppl.py +++ b/scripts/jiuge_ppl.py @@ -1,7 +1,7 @@ import torch from transformers import AutoTokenizer, AutoModelForCausalLM from datasets import load_dataset -from jiuge import JiugeForCauslLM +from jiuge import JiugeForCausalLM from libinfinicore_infer import DeviceType DEVICE_TYPE_MAP = { @@ -25,7 +25,7 @@ def test_torch(input_ids_list, device_): device = TORCH_DEVICE_TYPE_MAP[device_] - model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True).to( + model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16).to( device ) model.eval() @@ -59,7 +59,7 @@ def test_torch(input_ids_list, device_): def test_infinicore(input_ids_list, device_, ndev_): device = DEVICE_TYPE_MAP[device_] - model = JiugeForCauslLM( + model = JiugeForCausalLM( model_path, device, max_tokens=len(input_ids_list[0]), 
ndev=ndev_ ) perplexity = model.perplexity(input_ids_list) @@ -99,9 +99,9 @@ def test_infinicore(input_ids_list, device_, ndev_): for i in range(0, len(ids) - seq_len + 1, seq_len): input_ids_list.append(ids[i : i + seq_len]) - perplexity = test_infinicore(input_ids_list, args.dev, args.ndev) - print(f"InfiniCore Perplexity: {perplexity:.2f}") + InfiniCore_perplexity = test_infinicore(input_ids_list, args.dev, args.ndev) + print(f"InfiniCore Perplexity: {InfiniCore_perplexity:.2f}") if args.ndev == 1: # Todo: support multi-device testing with torch - perplexity = test_torch(input_ids_list, args.dev) - print(f"Torch Perplexity: {perplexity.item():.2f}") + Torch_perplexity = test_torch(input_ids_list, args.dev) + print(f"Torch Perplexity: {Torch_perplexity.item():.2f}") diff --git a/scripts/test_jiuge.py b/scripts/test_jiuge.py new file mode 100644 index 00000000..9dbb7f6c --- /dev/null +++ b/scripts/test_jiuge.py @@ -0,0 +1,55 @@ +from jiuge import JiugeForCausalLM +import sys +import logging +import argparse + +from libinfinicore_infer import DeviceType +logger = logging.getLogger(__name__) + + + +def parse_args(): + parser = argparse.ArgumentParser() + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/Llama-2-7b-chat-hf") + # parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/FM9G_70B_SFT_MHA/") + parser.add_argument("--model-path", type=str, default="/home/wanghaojie/vllm/huggingface/9G7B_MHA/") + parser.add_argument("--device-type", type=str, default="nvidia") + parser.add_argument("--ndev", type=int, default=4) + args = parser.parse_args() + return args + +def test(): + args = parse_args() + model_path = args.model_path + device_type = DeviceType.DEVICE_TYPE_CPU + if args.device_type == "cpu": + device_type = DeviceType.DEVICE_TYPE_CPU + elif args.device_type == "nvidia": + device_type = DeviceType.DEVICE_TYPE_NVIDIA + elif args.device_type == "cambricon": + device_type = DeviceType.DEVICE_TYPE_CAMBRICON + elif args.device_type == "ascend": + device_type = DeviceType.DEVICE_TYPE_ASCEND + elif args.device_type == "metax": + device_type = DeviceType.DEVICE_TYPE_METAX + elif args.device_type == "moore": + device_type = DeviceType.DEVICE_TYPE_MOORE + elif args.device_type == "iluvatar": + device_type = DeviceType.DEVICE_TYPE_ILUVATAR + else: + logger.info( + # "Usage: python jiuge.py [--cpu | --nvidia| --cambricon | --ascend | --metax | --moore] [n_device]" + "Usage: python jiuge.py [cpu | nvidia| cambricon | ascend | metax | moore] [n_device]" + ) + sys.exit(1) + + ndev = args.ndev + model = JiugeForCausalLM(model_path, device_type, ndev) + # model.generate(["山东最高的山是?", "中国面积最大的省是?"], 500) + # model.generate(["山东最高的山是?"], 500) + model.generate("山东最高的山是?", 500) + model.destroy_model_instance() + + +if __name__ == "__main__": + test() diff --git a/scripts/test_perf.py b/scripts/test_perf.py index a6b26f3b..5ef34a31 100644 --- a/scripts/test_perf.py +++ b/scripts/test_perf.py @@ -28,8 +28,8 @@ "想象一下,如果每个人都能读懂他人的思想。" ] -NUM_REQUESTS = 10 -CONCURRENCY = 5 +NUM_REQUESTS = 20 +CONCURRENCY = 20 API_URL = "http://127.0.0.1:8000" MODEL = "FM9G-7B" @@ -122,6 +122,7 @@ async def run_benchmark(verbose=False): successful_requests = len(results) requests_per_second = successful_requests / total_elapsed_time if total_elapsed_time > 0 else 0 + throughput = sum(tokens_list) / total_elapsed_time if total_elapsed_time > 0 else 0 avg_latency = sum(latencies) / len(latencies) if latencies else 0 avg_tokens_per_second = 
sum(tokens_per_second_list) / len(tokens_per_second_list) if tokens_per_second_list else 0 avg_ttft = sum(ttft_list) / len(ttft_list) if ttft_list else 0 @@ -139,6 +140,7 @@ async def run_benchmark(verbose=False): print(f"{'总输出token数':<{width_label}}: {sum(tokens_list)}") print(f"{'请求速率 (RPS)':<{width_label}}: {requests_per_second:.2f} requests/s") print(sep) + print(f"{'吞吐量':<{width_label}}: {throughput:.2f} tok/s") print(f"{'Average latency':<{width_label}}: {avg_latency:.2f} s") print(f"{'Average TTFT':<{width_label}}: {avg_ttft:.2f} s") print(f"{'Avg time per token':<{width_label}}: {avg_ms_per_token:.2f} ms/token") diff --git a/scripts/test_ppl.py b/scripts/test_ppl.py index 268a9f7d..1278c569 100644 --- a/scripts/test_ppl.py +++ b/scripts/test_ppl.py @@ -11,17 +11,25 @@ parser = argparse.ArgumentParser() parser.add_argument("--model-path", type=str, required=True) parser.add_argument("--port", type=int, default=8000) - parser.add_argument("--endpoint", type=str, default="/completions") + parser.add_argument("--endpoint", type=str, default="/chat/completions") parser.add_argument("--chunk", type=int, default=512) args = parser.parse_args() API_URL = "http://localhost:" + str(args.port) + args.endpoint CHUNK_SIZE = args.chunk - - dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") + + print("Loading dataset...") + local_file_paths = { + # "train": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext_local_parquet/train.parquet", + # "validation": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext_local_parquet/validation.parquet", + "test": "/home/wanghaojie/vllm/huggingface/wikitext/wikitext-2-raw-v1/test-00000-of-00001.parquet" + } + dataset = load_dataset("parquet", data_files=local_file_paths, split="test") + print("Dataset loaded.") + # dataset = load_dataset("wikitext", "wikitext-2-raw-v1", split="test") # Local tokenizer used for chunking - tokenizer = AutoTokenizer.from_pretrained(args.model_path) + tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True) total_neg_log_likelihood = 0.0 total_tokens = 0 @@ -41,8 +49,10 @@ API_URL, headers={"Content-Type": "application/json"}, json={ + "messages": [ + {"role": "user", "content": chunk_text} + ], "model": "", - "prompt": chunk_text, "max_tokens": 0, "temperature": 1.0, "echo": True, diff --git a/setup.sh b/setup.sh new file mode 100755 index 00000000..a06bc889 --- /dev/null +++ b/setup.sh @@ -0,0 +1,589 @@ +#!/bin/bash + +set -e + +echo "===================================================================" +echo "Anthropic API 环境变量配置脚本" +echo "注意:本脚本需要在bash环境中运行" +echo "Windows用户请在git bash终端环境下使用" +echo "Mac/Linux用户可直接在终端中运行" +echo "===================================================================" + +# 1. 检查终端环境 +echo "正在检查运行环境..." + +# 检查是否为bash环境 +if [ -z "$BASH_VERSION" ]; then + echo "❌ 错误: 当前不是bash环境" + echo "请在bash终端中运行此脚本:" + echo " - Windows: 使用 Git Bash 或 WSL" + echo " - Mac/Linux: 使用系统终端" + exit 1 +fi + +# 检测操作系统 +OS_TYPE="unknown" +case "$(uname -s)" in + Linux*) OS_TYPE="Linux";; + Darwin*) OS_TYPE="Mac";; + CYGWIN*|MINGW*|MSYS*) OS_TYPE="Windows";; + *) OS_TYPE="unknown";; +esac + +echo "✓ 检测到操作系统: $OS_TYPE" +echo "✓ bash环境检查通过 (版本: $BASH_VERSION)" + +# Node.js 安装函数 +install_nodejs() { + local platform=$(uname -s) + + case "$platform" in + Linux|Darwin) + echo "🚀 正在安装 Node.js..." + + echo "📥 下载并安装 nvm..." + curl -o- https://raw.githubusercontent.com/nvm-sh/nvm/v0.40.3/install.sh | bash + + echo "🔄 加载 nvm 环境..." + \. "$HOME/.nvm/nvm.sh" + + echo "📦 下载并安装 Node.js v22..." 
+ nvm install 22 + + echo -n "✅ Node.js 安装完成!版本: " + node -v + echo -n "✅ npm 版本: " + npm -v + ;; + *) + echo "❌ 不支持的平台: $platform" + echo "请手动安装 Node.js: https://nodejs.org/" + exit 1 + ;; + esac +} + +# 检查 Node.js 环境 +echo "检查 Node.js 环境..." +if command -v node >/dev/null 2>&1; then + current_version=$(node -v | sed 's/v//') + major_version=$(echo $current_version | cut -d. -f1) + + if [ "$major_version" -ge 18 ]; then + echo "✓ Node.js 已安装: v$current_version" + else + echo "⚠️ Node.js v$current_version 版本过低 (需要 >= 18),正在升级..." + install_nodejs + fi +else + echo "📦 Node.js 未安装,正在安装..." + install_nodejs +fi + +# 检查 npm 环境 +if command -v npm >/dev/null 2>&1; then + echo "✓ npm 已安装: $(npm -v)" +else + echo "❌ npm 未找到,Node.js 安装可能有问题" + exit 1 +fi + +# 2. 确定环境变量配置文件 +echo "正在扫描所有可能的环境变量配置文件..." + +# 初始化配置文件数组 +CONFIG_FILES=() + +# 检测当前shell类型 +current_shell=$(basename "$SHELL") + +# 根据shell类型和操作系统,列出所有可能的配置文件 +case "$current_shell" in + bash) + # Bash配置文件优先级顺序 + if [ "$OS_TYPE" = "Mac" ]; then + # macOS上bash配置文件 + potential_files=( + "$HOME/.bash_profile" + "$HOME/.bashrc" + "$HOME/.profile" + ) + else + # Linux/Windows上bash配置文件 + potential_files=( + "$HOME/.bashrc" + "$HOME/.bash_profile" + "$HOME/.profile" + ) + fi + ;; + zsh) + # Zsh配置文件优先级顺序 + potential_files=( + "$HOME/.zshrc" + "$HOME/.zprofile" + "$HOME/.zshenv" + "$HOME/.profile" + ) + + # 检查是否使用Oh My Zsh,避免冲突 + if [ -n "$ZSH" ] && [ -d "$ZSH" ]; then + echo "⚠️ 检测到Oh My Zsh环境,将在配置文件末尾添加变量" + fi + ;; + fish) + # Fish shell配置文件 + potential_files=( + "$HOME/.config/fish/config.fish" + ) + + # 创建fish配置目录(如果不存在) + if [ ! -d "$HOME/.config/fish" ]; then + mkdir -p "$HOME/.config/fish" + echo "创建fish配置目录: ~/.config/fish/" + fi + ;; + *) + # 其他shell的通用配置文件 + potential_files=( + "$HOME/.profile" + "$HOME/.bashrc" + ) + ;; +esac + +# 检查每个可能的配置文件 +echo "检查以下配置文件:" +for file in "${potential_files[@]}"; do + if [ -f "$file" ]; then + CONFIG_FILES+=("$file") + echo " ✓ 找到: ${file/#$HOME/~}" + else + echo " × 不存在: ${file/#$HOME/~}" + fi +done + +# 如果没有找到任何配置文件,创建默认的 +if [ ${#CONFIG_FILES[@]} -eq 0 ]; then + # 根据shell类型创建默认配置文件 + case "$current_shell" in + bash) + if [ "$OS_TYPE" = "Mac" ]; then + DEFAULT_FILE="$HOME/.bash_profile" + else + DEFAULT_FILE="$HOME/.bashrc" + fi + ;; + zsh) + DEFAULT_FILE="$HOME/.zshrc" + ;; + fish) + DEFAULT_FILE="$HOME/.config/fish/config.fish" + ;; + *) + DEFAULT_FILE="$HOME/.profile" + ;; + esac + + touch "$DEFAULT_FILE" + CONFIG_FILES+=("$DEFAULT_FILE") + echo "创建新的配置文件: ${DEFAULT_FILE/#$HOME/~}" +fi + +echo "" +echo "✓ 将更新 ${#CONFIG_FILES[@]} 个配置文件" + +# 3. 检查现有配置(支持不同shell语法) +echo "" +echo "检查现有Anthropic配置..." 
+EXISTING_CONFIGS=() +BACKUP_FILES=() + +# 检查每个配置文件中的现有配置 +for config_file in "${CONFIG_FILES[@]}"; do + has_config=false + + # 根据文件名判断语法类型 + if [[ "$config_file" == *"fish"* ]]; then + # fish shell 语法: set -x ANTHROPIC_AUTH_TOKEN + if grep -q "set -x ANTHROPIC_AUTH_TOKEN\|set -x ANTHROPIC_BASE_URL" "$config_file" 2>/dev/null; then + has_config=true + fi + else + # bash/zsh 语法: export ANTHROPIC_AUTH_TOKEN + if grep -q "ANTHROPIC_AUTH_TOKEN\|ANTHROPIC_BASE_URL" "$config_file" 2>/dev/null; then + has_config=true + fi + fi + + if [ "$has_config" = true ]; then + EXISTING_CONFIGS+=("$config_file") + echo "⚠️ 在 ${config_file/#$HOME/~} 中检测到已存在的Anthropic配置:" + if [[ "$config_file" == *"fish"* ]]; then + grep -n "set -x ANTHROPIC_" "$config_file" | sed 's/^/ /' || true + else + grep -n "ANTHROPIC_" "$config_file" | sed 's/^/ /' || true + fi + fi +done + +# 如果有现有配置,询问是否覆盖 +if [ ${#EXISTING_CONFIGS[@]} -gt 0 ]; then + echo "" + echo "📋 在 ${#EXISTING_CONFIGS[@]} 个文件中发现现有配置" + read -p "是否要覆盖所有现有配置?(y/N): " overwrite + if [[ ! "$overwrite" =~ ^[Yy]$ ]]; then + echo "操作已取消" + exit 0 + fi + + # 备份所有包含配置的文件 + echo "" + echo "正在备份现有配置文件..." + for config_file in "${EXISTING_CONFIGS[@]}"; do + backup_file="${config_file}.backup.$(date +%Y%m%d_%H%M%S)" + cp "$config_file" "$backup_file" + BACKUP_FILES+=("$backup_file") + echo " ✓ 已备份: ${backup_file/#$HOME/~}" + done +fi + +# 颜色定义 +colorReset='\033[0m' +colorBright='\033[1m' +colorCyan='\033[36m' +colorYellow='\033[33m' +colorMagenta='\033[35m' +colorRed='\033[31m' +colorBlue='\033[34m' +colorWhite='\033[37m' +colorGreen='\033[32m' + +# 显示API密钥获取横幅 +show_api_banner() { + printf "${colorBright}${colorRed} █████╗ ██╗ ${colorBlue}██████╗ ██████╗ ██████╗ ███████╗${colorMagenta} ██╗ ██╗██╗████████╗██╗ ██╗${colorReset}\n" + printf "${colorBright}${colorRed} ██╔══██╗██║ ${colorBlue}██╔════╝██╔═══██╗██╔══██╗██╔════╝${colorMagenta} ██║ ██║██║╚══██╔══╝██║ ██║${colorReset}\n" + printf "${colorBright}${colorRed} ███████║██║ ${colorBlue}██║ ██║ ██║██║ ██║█████╗ ${colorMagenta} ██║ █╗ ██║██║ ██║ ███████║${colorReset}\n" + printf "${colorBright}${colorRed} ██╔══██║██║ ${colorBlue}██║ ██║ ██║██║ ██║██╔══╝ ${colorMagenta} ██║███╗██║██║ ██║ ██╔══██║${colorReset}\n" + printf "${colorBright}${colorRed} ██║ ██║██║ ${colorBlue}╚██████╗╚██████╔╝██████╔╝███████╗${colorMagenta} ╚███╔███╔╝██║ ██║██╗██║ ██║${colorReset}\n" + printf "${colorBright}${colorRed} ╚═╝ ╚═╝╚═╝ ${colorBlue} ╚═════╝ ╚═════╝ ╚═════╝ ╚══════╝${colorMagenta} ╚══╝╚══╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝${colorReset}\n" + printf "\n" + printf "${colorBright}${colorYellow}🌐 请从以下网址获取您的API密钥:${colorReset}\n" + printf "${colorBright}${colorCyan}📋 https://aicodewith.com/dashboard/api-keys${colorReset}\n" + printf "\n" + printf "${colorBright}${colorGreen}📝 API密钥格式: sk-acw-********-****************${colorReset}\n" + printf "\n" +} + +# 4. 获取API密钥 +echo "" +show_api_banner + +# 输入API密钥并验证 +while true; do + read -p "请输入ANTHROPIC_AUTH_TOKEN: " auth_token + echo "" + + # 基本格式验证 + if [[ "$auth_token" =~ ^sk-acw-.{8}-.{16}$ ]]; then + echo "✓ API密钥格式验证通过" + break + else + echo "❌ API密钥格式不正确" + echo " 正确格式: sk-acw-********-****************" + echo " 请重新输入" + fi +done + +# 5. 更新配置文件 +echo "" +echo "正在更新配置文件..." 
+UPDATE_COUNT=0 +FAILED_FILES=() + +# 处理每个配置文件 +for config_file in "${CONFIG_FILES[@]}"; do + echo " 📝 处理: ${config_file/#$HOME/~}" + + # 判断文件类型和语法 + is_fish=false + if [[ "$config_file" == *"fish"* ]]; then + is_fish=true + fi + + # 移除旧的Anthropic配置 + if grep -q "ANTHROPIC_AUTH_TOKEN\|ANTHROPIC_BASE_URL" "$config_file" 2>/dev/null || \ + grep -q "set -x ANTHROPIC_AUTH_TOKEN\|set -x ANTHROPIC_BASE_URL" "$config_file" 2>/dev/null; then + + # 创建临时文件,移除旧配置 + temp_file=$(mktemp) + if [ "$is_fish" = true ]; then + # 移除fish语法的配置行 + grep -v "set -x ANTHROPIC_AUTH_TOKEN\|set -x ANTHROPIC_BASE_URL" "$config_file" > "$temp_file" + else + # 移除bash/zsh语法的配置行 + grep -v "ANTHROPIC_AUTH_TOKEN\|ANTHROPIC_BASE_URL" "$config_file" > "$temp_file" + fi + mv "$temp_file" "$config_file" + fi + + # 添加新配置 + if [ "$is_fish" = true ]; then + # fish shell 语法 + { + echo "" + echo "# Anthropic API Configuration - $(date '+%Y-%m-%d %H:%M:%S')" + echo "set -x ANTHROPIC_AUTH_TOKEN $auth_token" + echo "set -x ANTHROPIC_BASE_URL https://api.jiuwanliguoxue.com/" + } >> "$config_file" + else + # bash/zsh 语法 + { + echo "" + echo "# Anthropic API Configuration - $(date '+%Y-%m-%d %H:%M:%S')" + echo "export ANTHROPIC_AUTH_TOKEN=$auth_token" + echo "export ANTHROPIC_BASE_URL=https://api.jiuwanliguoxue.com/" + } >> "$config_file" + fi + + # 验证是否写入成功 + if [ "$is_fish" = true ]; then + if grep -q "set -x ANTHROPIC_AUTH_TOKEN $auth_token" "$config_file" && \ + grep -q "set -x ANTHROPIC_BASE_URL" "$config_file"; then + echo " ✓ 配置成功写入" + ((UPDATE_COUNT++)) + else + echo " ❌ 配置写入失败" + FAILED_FILES+=("$config_file") + fi + else + if grep -q "ANTHROPIC_AUTH_TOKEN=$auth_token" "$config_file" && \ + grep -q "ANTHROPIC_BASE_URL=" "$config_file"; then + echo " ✓ 配置成功写入" + ((UPDATE_COUNT++)) + else + echo " ❌ 配置写入失败" + FAILED_FILES+=("$config_file") + fi + fi +done + +echo "" +echo "✓ 成功更新 $UPDATE_COUNT/${#CONFIG_FILES[@]} 个配置文件" + +# 如果有失败的文件,显示错误信息 +if [ ${#FAILED_FILES[@]} -gt 0 ]; then + echo "" + echo "❌ 以下文件更新失败:" + for failed_file in "${FAILED_FILES[@]}"; do + echo " - ${failed_file/#$HOME/~}" + done +fi + +# 6. 加载环境变量并验证 +echo "" +echo "正在加载和验证环境变量..." + +# 尝试从非fish配置文件加载环境变量 +if [[ "$current_shell" != "fish" ]]; then + # 从所有非fish配置文件中提取并加载Anthropic环境变量 + for config_file in "${CONFIG_FILES[@]}"; do + if [[ "$config_file" != *"fish"* ]]; then + eval "$(grep "^export ANTHROPIC_" "$config_file" 2>/dev/null || true)" + fi + done +else + echo "⚠️ Fish shell配置文件不兼容bash,跳过自动加载" +fi + +# 手动设置环境变量用于当前会话 +export ANTHROPIC_AUTH_TOKEN=$auth_token +export ANTHROPIC_BASE_URL=https://api.jiuwanliguoxue.com/ + +# 验证配置是否成功 +if [ "$UPDATE_COUNT" -eq "${#CONFIG_FILES[@]}" ]; then + echo "✅ 所有配置文件更新成功!" + echo "" + echo "📊 当前配置:" + echo " ANTHROPIC_BASE_URL: $ANTHROPIC_BASE_URL" + echo " ANTHROPIC_AUTH_TOKEN: ${ANTHROPIC_AUTH_TOKEN:0:12}...(已隐藏)" + echo "" + + # 显示更新的配置文件列表 + echo "📋 已更新的配置文件:" + for config_file in "${CONFIG_FILES[@]}"; do + echo " - ${config_file/#$HOME/~}" + done + echo "" + echo "🎉 配置完成!" + echo "" + + # 7. 检查并安装/更新Claude Code客户端 + echo "🔍 检查Claude Code客户端..." + if command -v claude >/dev/null 2>&1; then + echo "✓ Claude Code已安装: $(claude --version)" + echo "" + echo "🚀 是否要更新Claude Code客户端到最新版本?" + read -p "这将执行: npm uninstall/install -g @anthropic-ai/claude-code (y/N): " update_claude + + if [[ "$update_claude" =~ ^[Yy]$ ]]; then + echo "🔄 正在更新Claude Code客户端..." + + echo "步骤1: 卸载旧版本..." + npm uninstall -g @anthropic-ai/claude-code + + echo "步骤2: 安装最新版本..." 
+ if npm install -g @anthropic-ai/claude-code --registry=https://registry.npmmirror.com; then + echo "✅ Claude Code客户端更新成功!" + else + echo "❌ Claude Code客户端安装失败,请手动执行:" + echo " npm install -g @anthropic-ai/claude-code --registry=https://registry.npmmirror.com" + fi + fi + else + echo "📦 Claude Code未安装,正在安装..." + if npm install -g @anthropic-ai/claude-code --registry=https://registry.npmmirror.com; then + echo "✅ Claude Code客户端安装成功!" + else + echo "❌ Claude Code客户端安装失败,请手动执行:" + echo " npm install -g @anthropic-ai/claude-code --registry=https://registry.npmmirror.com" + exit 1 + fi + fi + + # 8. 配置Claude Code跳过引导 + echo "" + echo "🔧 配置Claude Code跳过引导..." + node --eval " + const fs = require('fs'); + const os = require('os'); + const path = require('path'); + + const homeDir = os.homedir(); + const filePath = path.join(homeDir, '.claude.json'); + + try { + if (fs.existsSync(filePath)) { + const content = JSON.parse(fs.readFileSync(filePath, 'utf-8')); + fs.writeFileSync(filePath, JSON.stringify({ ...content, hasCompletedOnboarding: true }, null, 2), 'utf-8'); + console.log('✅ 已更新现有Claude配置文件'); + } else { + fs.writeFileSync(filePath, JSON.stringify({ hasCompletedOnboarding: true }, null, 2), 'utf-8'); + console.log('✅ 已创建Claude配置文件并跳过引导'); + } + } catch (error) { + console.log('⚠️ 配置Claude引导跳过时出错:', error.message); + } + " + echo "" + + # 9. 检测并清理Claude配置文件中的代理设置 + echo "" + echo "🔍 检测Claude配置文件中的代理设置..." + # Claude配置文件可能的路径(优先检查settings.json) + CLAUDE_SETTING_FILE="" + if [ -f "$HOME/.claude/settings.json" ]; then + CLAUDE_SETTING_FILE="$HOME/.claude/settings.json" + elif [ -f "$HOME/.claude/settings.local.json" ]; then + CLAUDE_SETTING_FILE="$HOME/.claude/settings.local.json" + elif [ -f "$HOME/.claude/setting.json" ]; then + CLAUDE_SETTING_FILE="$HOME/.claude/setting.json" + fi + + if [ -n "$CLAUDE_SETTING_FILE" ]; then + echo "✓ 找到Claude配置文件: ${CLAUDE_SETTING_FILE/#$HOME/~}" + + # 检测是否存在代理设置 + PROXY_FOUND=false + PROXY_SETTINGS="" + + # 检查是否有HTTP代理设置(不区分大小写) + if grep -iq "http_proxy\|https_proxy\|httpproxy\|httpsproxy" "$CLAUDE_SETTING_FILE" 2>/dev/null; then + PROXY_FOUND=true + echo "" + echo "⚠️ 检测到残留的代理配置:" + PROXY_SETTINGS=$(grep -in "http_proxy\|https_proxy\|httpproxy\|httpsproxy" "$CLAUDE_SETTING_FILE" | sed 's/^/ /') + echo "$PROXY_SETTINGS" + echo "" + echo "📝 这些代理设置可能会影响Claude Code的正常使用" + echo " 建议删除这些设置以避免连接问题" + echo "" + + read -p "是否要删除这些代理设置?(y/N): " remove_proxy + if [[ "$remove_proxy" =~ ^[Yy]$ ]]; then + # 备份原配置文件 + backup_claude_file="${CLAUDE_SETTING_FILE}.backup.$(date +%Y%m%d_%H%M%S)" + cp "$CLAUDE_SETTING_FILE" "$backup_claude_file" + echo "✓ 已备份Claude配置到: ${backup_claude_file/#$HOME/~}" + + # 删除代理设置行(不区分大小写) + # 使用sed删除包含代理相关设置的行 + if [[ "$OS_TYPE" = "Mac" ]]; then + # Mac版本的sed需要备份文件参数 + sed -i '' '/[Hh][Tt][Tt][Pp]_[Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Ss]_[Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Ss][Pp][Rr][Oo][Xx][Yy]/d' "$CLAUDE_SETTING_FILE" + else + # Linux/Windows版本的sed + sed -i '/[Hh][Tt][Tt][Pp]_[Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Ss]_[Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Pp][Rr][Oo][Xx][Yy]\|[Hh][Tt][Tt][Pp][Ss][Pp][Rr][Oo][Xx][Yy]/d' "$CLAUDE_SETTING_FILE" + fi + + # 验证删除结果(不区分大小写) + if ! 
grep -iq "http_proxy\|https_proxy\|httpproxy\|httpsproxy" "$CLAUDE_SETTING_FILE" 2>/dev/null; then + echo "✅ 代理设置已成功删除" + echo "📋 Claude Code现在应该能正常使用默认网络连接" + else + echo "❌ 代理设置删除失败" + echo " 请手动编辑文件: $CLAUDE_SETTING_FILE" + echo " 或恢复备份: cp $backup_claude_file $CLAUDE_SETTING_FILE" + fi + else + echo "跳过代理设置清理" + fi + else + echo "✓ 未发现代理设置,配置文件正常" + fi + else + echo "ℹ️ 未找到Claude配置文件(${CLAUDE_SETTING_FILE/#$HOME/~})" + echo " 这是正常的,配置文件会在首次使用Claude Code时自动创建" + fi + echo "" + +# 显示配置完成横幅 +show_complete_banner() { + printf "\n" + printf "${colorBright}${colorRed} █████╗ ██╗ ${colorBlue}██████╗ ██████╗ ██████╗ ███████╗${colorMagenta} ██╗ ██╗██╗████████╗██╗ ██╗${colorReset}\n" + printf "${colorBright}${colorRed} ██╔══██╗██║ ${colorBlue}██╔════╝██╔═══██╗██╔══██╗██╔════╝${colorMagenta} ██║ ██║██║╚══██╔══╝██║ ██║${colorReset}\n" + printf "${colorBright}${colorRed} ███████║██║ ${colorBlue}██║ ██║ ██║██║ ██║█████╗ ${colorMagenta} ██║ █╗ ██║██║ ██║ ███████║${colorReset}\n" + printf "${colorBright}${colorRed} ██╔══██║██║ ${colorBlue}██║ ██║ ██║██║ ██║██╔══╝ ${colorMagenta} ██║███╗██║██║ ██║ ██╔══██║${colorReset}\n" + printf "${colorBright}${colorRed} ██║ ██║██║ ${colorBlue}╚██████╗╚██████╔╝██████╔╝███████╗${colorMagenta} ╚███╔███╔╝██║ ██║██╗██║ ██║${colorReset}\n" + printf "${colorBright}${colorRed} ╚═╝ ╚═╝╚═╝ ${colorBlue} ╚═════╝ ╚═════╝ ╚═════╝ ╚══════╝${colorMagenta} ╚══╝╚══╝ ╚═╝ ╚═╝╚═╝╚═╝ ╚═╝${colorReset}\n" + printf "\n" + printf "${colorBright}${colorYellow}📌 请执行以下命令使配置立即生效:${colorReset}\n" + printf "${colorBright}${colorCyan} source ${CONFIG_FILE/#$HOME/~}${colorReset}\n" + printf "\n" + printf "${colorBright}${colorGreen}🔄 或者重启终端让配置自动生效${colorReset}\n" + printf "\n" +} + + show_complete_banner + echo "" + echo "🔧 如需修改配置,可编辑: ${CONFIG_FILE/#$HOME/~}" +else + # 方案3: 改进错误提示,说明可能的原因 + echo "❌ 配置文件验证失败,可能的原因:" + echo " 1. 配置文件写入过程中出现错误" + echo " 2. 磁盘空间不足或权限问题" + echo " 3. API密钥格式在写入时被意外修改" + echo "" + echo "🔍 调试信息:" + echo " 配置文件路径: $CONFIG_FILE" + echo " API密钥长度: ${#auth_token}" + echo " 配置文件末尾内容:" + tail -5 "$CONFIG_FILE" 2>/dev/null || echo " 无法读取配置文件" + echo "" + echo "💡 建议解决方案:" + echo " 1. 检查磁盘空间: df -h $HOME" + echo " 2. 检查文件权限: ls -la $CONFIG_FILE" + echo " 3. 手动验证配置: cat $CONFIG_FILE | grep ANTHROPIC" + echo " 4. 
重新运行脚本" + exit 1 +fi \ No newline at end of file diff --git a/src/cache_manager/kvcache.cpp b/src/cache_manager/kvcache.cpp index db3093c9..cdcceafc 100644 --- a/src/cache_manager/kvcache.cpp +++ b/src/cache_manager/kvcache.cpp @@ -1,4 +1,5 @@ #include "../cache.hpp" +#include <cassert> __C struct KVCache *createKVCache( size_t nlayers, @@ -31,6 +32,37 @@ __C struct KVCache *createKVCache( return cache; } +__C struct KVCache *createPagedKVCache(size_t nlayers, + size_t nkvh_, + size_t dk, + size_t dv, + infiniDtype_t dtype, + infiniDevice_t device, + int *dev_ids, + size_t ndev, + size_t kvcache_block_size, + size_t max_kvcache_tokens) { + KVCache *cache = new KVCache(); + assert(kvcache_block_size > 0); + auto max_num_blocks = max_kvcache_tokens / kvcache_block_size; + auto shape_k = std::vector<size_t>{max_num_blocks, nkvh_, kvcache_block_size, dk}; + auto shape_v = std::vector<size_t>{max_num_blocks, nkvh_, kvcache_block_size, dv}; + for (unsigned int idev = 0; idev < ndev; idev++) { + RUN_INFINI(infinirtSetDevice(device, dev_ids[idev])); + auto kcache = std::vector<std::shared_ptr<Tensor>>(); + auto vcache = std::vector<std::shared_ptr<Tensor>>(); + for (unsigned int layer = 0; layer < nlayers; layer++) { + kcache.push_back(std::move(Tensor::buffer(dtype, shape_k))); + vcache.push_back(std::move(Tensor::buffer(dtype, shape_v))); + } + cache->k.push_back(kcache); + cache->v.push_back(vcache); + } + + return cache; +} + + __C struct KVCache *duplicateKVCache(const KVCache *kv_cache, size_t seq_len) { auto ndev = kv_cache->k.size(); auto nlayers = kv_cache->k[0].size(); diff --git a/src/cache_manager/opcache_manager.hpp b/src/cache_manager/opcache_manager.hpp index 333583e8..669dc265 100644 --- a/src/cache_manager/opcache_manager.hpp +++ b/src/cache_manager/opcache_manager.hpp @@ -162,6 +162,8 @@ class CacheManager { DECLARE_OP_CACHE(SwiGLU) DECLARE_OP_CACHE(RandomSample) DECLARE_OP_CACHE(Dequantize) + DECLARE_OP_CACHE(PagedCaching) + DECLARE_OP_CACHE(PagedAttention) CacheManager(size_t capacity = 100) : Add_cache(capacity, DESTROY_FUNC(Add)), @@ -173,7 +175,9 @@ class CacheManager { Topkrouter_cache(capacity, DESTROY_FUNC(Topkrouter)), SwiGLU_cache(capacity, DESTROY_FUNC(SwiGLU)), RandomSample_cache(capacity, DESTROY_FUNC(RandomSample)), - Dequantize_cache(capacity, DESTROY_FUNC(Dequantize)) {} + Dequantize_cache(capacity, DESTROY_FUNC(Dequantize)), + PagedCaching_cache(capacity, DESTROY_FUNC(PagedCaching)), + PagedAttention_cache(capacity, DESTROY_FUNC(PagedAttention)) {} template <typename... Tensors> static size_t createDescriptorKey(Tensors...
tensors) { diff --git a/src/models/inference_context.cpp b/src/models/inference_context.cpp index e41e4bb3..9ffca8d2 100644 --- a/src/models/inference_context.cpp +++ b/src/models/inference_context.cpp @@ -259,6 +259,68 @@ void InferenceContext::linear(std::shared_ptr c, } } +void InferenceContext::pagedCaching(std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + std::shared_ptr slot_mapping) { + size_t key = CacheManager::createDescriptorKey(k, v, k_cache, v_cache, slot_mapping); + + infiniopPagedCachingDescriptor_t desc; + if (!cache_manager->getPagedCachingDescriptor(key, desc)) { + RUN_INFINI(infiniopCreatePagedCachingDescriptor( + op_handle, &desc, k->desc(), v->desc(), + k_cache->desc(), v_cache->desc(), slot_mapping->desc())); + cache_manager->putPagedCachingDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetPagedCachingWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + RUN_INFINI(infiniopPagedCaching( + desc, workspace, workspace_size, + k->data(), v->data(), + k_cache->data(), v_cache->data(), + slot_mapping->data(), stream)); +} + +void InferenceContext::pagedAttention(std::shared_ptr out, + std::shared_ptr q, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + std::shared_ptr block_tables, + std::shared_ptr seq_lens, + std::shared_ptr alibi_slopes, // can be nullptr + float scale) { + + size_t key = CacheManager::createDescriptorKey(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes); + + infiniopPagedAttentionDescriptor_t desc; + if (!cache_manager->getPagedAttentionDescriptor(key, desc)) { + infiniopTensorDescriptor_t alibi_desc = alibi_slopes ? alibi_slopes->desc() : nullptr; + RUN_INFINI(infiniopCreatePagedAttentionDescriptor( + op_handle, &desc, out->desc(), q->desc(), + k_cache->desc(), v_cache->desc(), block_tables->desc(), + seq_lens->desc(), alibi_desc, scale)); + cache_manager->putPagedAttentionDescriptor(key, desc); + } + + size_t workspace_size = 0; + RUN_INFINI(infiniopGetPagedAttentionWorkspaceSize(desc, &workspace_size)); + ensure_workspace(workspace_size); + void *workspace = workspace_storage->memory(); + + const void* alibi_data = alibi_slopes ? 
alibi_slopes->data() : nullptr; + RUN_INFINI(infiniopPagedAttention( + desc, workspace, workspace_size, + out->data(), q->data(), k_cache->data(), v_cache->data(), + block_tables->data(), seq_lens->data(), alibi_data, + stream)); +} + + void InferenceContext::dequant(std::shared_ptr weight, std::shared_ptr in_w, std::shared_ptr in_s, @@ -279,5 +341,5 @@ void InferenceContext::dequant(std::shared_ptr weight, RUN_INFINI(infiniopDequantize( desc, workspace, workspace_size, - weight->data(), in_w->data(), in_s->data(), in_z->data(), 0, 0, 0, stream)); + weight->data(), in_w->data(), in_s->data(), in_z->data(), stream)); } diff --git a/src/models/inference_context.hpp b/src/models/inference_context.hpp index 0cf93f6f..fdad0a17 100644 --- a/src/models/inference_context.hpp +++ b/src/models/inference_context.hpp @@ -58,6 +58,21 @@ struct InferenceContext { float alpha, float beta, std::shared_ptr residual, std::shared_ptr bias); + + void pagedCaching(std::shared_ptr k, + std::shared_ptr v, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + std::shared_ptr slot_mapping); + + void pagedAttention(std::shared_ptr out, + std::shared_ptr q, + std::shared_ptr k_cache, + std::shared_ptr v_cache, + std::shared_ptr block_tables, + std::shared_ptr seq_lens, + std::shared_ptr alibi_slopes, // can be nullptr + float scale); void dequant(std::shared_ptr weight, std::shared_ptr in_w, std::shared_ptr in_s, @@ -140,8 +155,24 @@ inline void linear(std::shared_ptr c, std::shared_ptr a, std::shared_ptr b, float alpha, float beta, std::shared_ptr residual, std::shared_ptr bias) { getInferenceContext().linear(c, a, b, alpha, beta, residual, bias); + +} + +inline void pagedCaching(std::shared_ptr k, std::shared_ptr v, + std::shared_ptr k_cache, std::shared_ptr v_cache, + std::shared_ptr slot_mapping) { + getInferenceContext().pagedCaching(k, v, k_cache, v_cache, slot_mapping); } +inline void pagedAttention(std::shared_ptr out, std::shared_ptr q, + std::shared_ptr k_cache, std::shared_ptr v_cache, + std::shared_ptr block_tables, std::shared_ptr seq_lens, + std::shared_ptr alibi_slopes, float scale) { + getInferenceContext().pagedAttention(out, q, k_cache, v_cache, block_tables, seq_lens, alibi_slopes, scale); +} + + + inline void dequant_linear(std::shared_ptr out, std::shared_ptr x, std::shared_ptr w_w, std::shared_ptr w_s, std::shared_ptr w_z, float alpha, float beta, std::shared_ptr residual, std::shared_ptr bias) { diff --git a/src/models/jiuge/jiuge.cpp b/src/models/jiuge/jiuge.cpp index 059842cc..52cc6eb4 100644 --- a/src/models/jiuge/jiuge.cpp +++ b/src/models/jiuge/jiuge.cpp @@ -9,6 +9,8 @@ #include #include #include +#include +#include void createDeviceResource(JiugeDeviceResource *rsrc, const JiugeMeta *meta, const JiugeWeights *weights, @@ -116,7 +118,10 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, const uint32_t *tokens, uint32_t ntok, const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos, struct KVCache **kv_caches, + const int32_t *block_tables, + const int32_t *slot_mapping, const float *temperature, const uint32_t *topk, const float *topp, + const uint32_t is_prefill, const bool enable_paged_attn, uint32_t *output, void *last_logits) { auto nlayer = meta.nlayer; auto nkvh = meta.nkvh / ndev; @@ -130,13 +135,13 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto dvoc = meta.dvoc; auto stream = rsrc.stream; bool has_qkv_bias = rsrc.b_attn_qkv.size() > 0; - // Allocate buffers auto logits_in = Tensor::buffer(dt_logits, {ntok, d}, 
rsrc.memory_pool); auto logits_out = Tensor::buffer(dt_logits, {ntok, d}, rsrc.memory_pool); auto qkv_buf = Tensor::buffer(dt_logits, {ntok, (nh + nkvh * 2) * dh}, rsrc.memory_pool); auto gate_up_buf = Tensor::buffer(dt_logits, {ntok, 2 * di}, rsrc.memory_pool); auto o_buf = Tensor::buffer(dt_logits, {ntok, nh * dh}, rsrc.memory_pool); + auto q_buf = Tensor::buffer(dt_logits, {ntok, nh , dh}, rsrc.memory_pool); auto prob_buf = Tensor::buffer(dt_logits, {nreq, dvoc}, rsrc.memory_pool); auto result_buf = Tensor::buffer(INFINI_DTYPE_I64, {nreq}, rsrc.memory_pool); auto result_cpu = std::vector(nreq); @@ -145,11 +150,14 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, // Prepare inputs auto batch_pos_ids = std::vector(ntok); + auto batch_seq_lens = std::vector(nreq); + size_t req_start = 0; for (uint32_t req = 0; req < nreq; req++) { for (uint32_t i = 0; i < req_lens[req]; i++) { batch_pos_ids[req_start + i] = req_pos[req] + i; } + batch_seq_lens[req] = req_lens[req] + req_pos[req]; req_start += req_lens[req]; } @@ -167,6 +175,27 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, dsize(dt_logits) * d, INFINIRT_MEMCPY_D2D, stream)); } + std::shared_ptr slot_mapping_buf, block_tables_buf, seq_lens_buf; + size_t max_seq_len_in_batch = 0; + if (enable_paged_attn) { + max_seq_len_in_batch = *std::max_element(batch_seq_lens.begin(), batch_seq_lens.end()); + // Assuming block_size is a known constant, e.g., 16. The max_blocks_per_seq can be calculated. + // Let's assume a reasonable upper bound for simplicity. This might need to be passed in. + // TODO: get block_size from meta + size_t block_size = meta.kvcache_block_size; + size_t max_blocks_per_seq = (max_seq_len_in_batch + block_size - 1) / block_size; + + + slot_mapping_buf = Tensor::buffer(INFINI_DTYPE_I32, {ntok}, rsrc.memory_pool); + block_tables_buf = Tensor::buffer(INFINI_DTYPE_I32, {(uint32_t)nreq, (uint32_t)max_blocks_per_seq}, rsrc.memory_pool); + seq_lens_buf = Tensor::buffer(INFINI_DTYPE_I32, {nreq}, rsrc.memory_pool); + + RUN_INFINI(infinirtMemcpyAsync(slot_mapping_buf->data(), slot_mapping, sizeof(int32_t) * ntok, INFINIRT_MEMCPY_H2D, stream)); + RUN_INFINI(infinirtMemcpyAsync(block_tables_buf->data(), block_tables, sizeof(int32_t) * (nreq * max_blocks_per_seq), INFINIRT_MEMCPY_H2D, stream)); + RUN_INFINI(infinirtMemcpyAsync(seq_lens_buf->data(), batch_seq_lens.data(), sizeof(int32_t) * nreq, INFINIRT_MEMCPY_H2D, stream)); + + } + // Attention // attention inner size_t max_qk_size = 0; @@ -187,11 +216,12 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, auto attn_val_buf = Tensor::buffer(dt_logits, {nkvh, ngroup * max_seq_len, dh}, rsrc.memory_pool); auto attn_val_gemm = attn_val_buf->view({nkvh, ngroup, max_seq_len, dh}); + // MLP buffers auto gate_buf = gate_up_buf->slice(1, 0, di); auto up_buf = gate_up_buf->slice(1, di, di); - // Compute + for (uint32_t layer = 0; layer < nlayer; layer++) { // 1. 
Attention // rms norm @@ -202,34 +232,97 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc, rope(qkv_rope->slice(1, 0, nh), qkv_rope->slice(1, 0, nh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); rope(qkv_rope->slice(1, nh, nkvh), qkv_rope->slice(1, nh, nkvh), pos_ids_buf, rsrc.sin_table, rsrc.cos_table); - size_t token_offset = 0; - for (uint32_t req = 0; req < nreq; req++) { - auto past_len = req_pos[req]; - auto seq_len = req_lens[req]; - auto total_len = past_len + seq_len; - auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); - auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); - auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); - auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); - - // self attention - // concat - rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); - rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); - // qk - rearrange(q_rearrange->slice(2, 0, seq_len), q); - auto qk_gemm = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); - auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); - linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); - // softmax - auto qk_softmax = qk_gemm->view({nh, seq_len, total_len}); - causalSoftmax(qk_softmax, qk_softmax); - auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2}); - linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); - // rearrange attn val - rearrange(o, attn_val_gemm->slice(2, 0, seq_len)); - - token_offset += seq_len; + if (enable_paged_attn) { + auto k = qkv_rope->slice({ {0, 0, ntok}, {1, nh, nkvh} }); + auto v = qkv_rope->slice({ {0, 0, ntok}, {1, nh + nkvh, nkvh} }); + + // Assuming kv_caches[0] gives access to the entire cache pool for this device. + // This part may need adjustment based on the actual KVCache struct definition. 
+ auto k_cache_pool = kv_caches[0]->k[idev][layer]; + auto v_cache_pool = kv_caches[0]->v[idev][layer]; + pagedCaching(k, v, k_cache_pool, v_cache_pool, slot_mapping_buf); + // printf("o_buf: pass pagedCaching\n"); + + if (is_prefill) { + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = past_len + seq_len; + auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); + auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); + auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + // qk + // std::cout << "rearrange q" << std::endl; + // std::cout << "q shape: " << q->info() << std::endl; + rearrange(q_rearrange->slice(2, 0, seq_len), q); + // std::cout << "qk_gemm" << std::endl; + // std::cout << "qk_buf: " << qk_buf->info() << std::endl; + auto qk_gemm = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = k->permute({1, 2, 0}); + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + // std::cout << "qk_softmax" << std::endl; + auto qk_softmax = qk_gemm->view({nh, seq_len, total_len}); + causalSoftmax(qk_softmax, qk_softmax); + // std::cout << "v_gemm" << std::endl; + auto v_gemm = v->permute({1, 0, 2}); + // std::cout << "attn_val_buf" << std::endl; + linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr); + // rearrange attn val + rearrange(o, attn_val_gemm->slice(2, 0, seq_len)); + + token_offset += seq_len; + } + } else { + auto o = o_buf->slice({{0, 0, ntok}})->view({ntok, nh, dh}); + auto q_batch = qkv_rope->slice({ {0, 0, ntok}, {1, 0, nh} })->view({ntok, nh, dh}); + + float scale = 1.f / float(sqrt(dh)); + pagedAttention(o, q_batch, k_cache_pool, v_cache_pool, + block_tables_buf, seq_lens_buf, nullptr /* alibi_slopes */, scale); + + + } + + } else { + size_t token_offset = 0; + for (uint32_t req = 0; req < nreq; req++) { + auto past_len = req_pos[req]; + auto seq_len = req_lens[req]; + auto total_len = past_len + seq_len; + auto o = o_buf->slice({{0, token_offset, seq_len}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto q = qkv_rope->slice({{0, token_offset, seq_len}, {1, 0, nh}})->view({seq_len, nkvh, ngroup, dh})->permute({1, 2, 0, 3}); + auto k = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh, nkvh}}); + auto v = qkv_rope->slice({{0, token_offset, seq_len}, {1, nh + nkvh, nkvh}}); + + // self attention + // concat + rearrange(kv_caches[req]->k[idev][layer]->slice(0, past_len, seq_len), k); + rearrange(kv_caches[req]->v[idev][layer]->slice(0, past_len, seq_len), v); + // qk + // std::cout << "rearrange q" << std::endl; + // std::cout << "q shape: " << q->info() << std::endl; + rearrange(q_rearrange->slice(2, 0, seq_len), q); + // std::cout << "qk_gemm" << std::endl; + // std::cout << "qk_buf: " << qk_buf->info() << std::endl; + auto qk_gemm = qk_buf->slice(0, 0, nh * seq_len * total_len)->view({nkvh, ngroup * seq_len, total_len}); + auto k_gemm = kv_caches[req]->k[idev][layer]->slice(0, 0, total_len)->permute({1, 2, 0}); + linear(qk_gemm, rearrange_q_buf->slice(1, 0, ngroup * seq_len), k_gemm, 1.f / float(sqrt(dh)), 0.f, nullptr, nullptr); + // softmax + // 
+                auto qk_softmax = qk_gemm->view({nh, seq_len, total_len});
+                causalSoftmax(qk_softmax, qk_softmax);
+                // std::cout << "v_gemm" << std::endl;
+                auto v_gemm = kv_caches[req]->v[idev][layer]->slice(0, 0, total_len)->permute({1, 0, 2});
+                // std::cout << "attn_val_buf" << std::endl;
+                linear(attn_val_buf->slice(1, 0, ngroup * seq_len), qk_gemm, v_gemm, 1.f, 0.f, nullptr, nullptr);
+                // rearrange attn val
+                rearrange(o, attn_val_gemm->slice(2, 0, seq_len));
+
+                token_offset += seq_len;
+            }
         }
 
         // o_proj
@@ -255,7 +348,10 @@ void inferDeviceBatch(const JiugeMeta &meta, JiugeDeviceResource &rsrc,
                                    INFINICCL_SUM, rsrc.comm, stream));
             RUN_INFINI(infinirtStreamSynchronize(stream));
         }
+        // printf("o_buf: pass layer %d\n", layer);
     }
+    // printf("o_buf: pass all layers\n");
+
     // Sample and Output
     if (idev == 0) {
         if (last_logits != nullptr) {
@@ -296,13 +392,15 @@
         }
     }
 }
-
 __C void
 inferBatchJiuge(struct JiugeModel *model,
                 const uint32_t *tokens, uint32_t ntok,
                 const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
-                struct KVCache **kv_caches,
+                struct KVCache **kv_caches,
+                const int32_t *block_tables,
+                const int32_t *slot_mapping,
                 const float *temperature, const uint32_t *topk, const float *topp,
+                const uint32_t is_prefill, const bool enable_paged_attn,
                 uint32_t *output) {
     model->req.tokens = tokens;
     model->req.ntok = ntok;
@@ -310,11 +408,15 @@ inferBatchJiuge(struct JiugeModel *model,
     model->req.nreq = nreq;
     model->req.req_pos = req_pos;
     model->req.kv_caches = kv_caches;
+    model->req.block_tables = block_tables;
+    model->req.slot_mapping = slot_mapping;
     model->req.output = output;
     model->req.logits = nullptr;
     model->req.temperature = temperature;
     model->req.topk = topk;
     model->req.topp = topp;
+    model->req.is_prefill = is_prefill;
+    model->req.enable_paged_attn = enable_paged_attn;
 
     for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
         std::unique_lock lock(model->states[idev].mtx);
@@ -335,6 +437,9 @@ forwardBatchJiuge(struct JiugeModel *model,
                   const uint32_t *tokens, uint32_t ntok,
                   const uint32_t *req_lens, uint32_t nreq, const uint32_t *req_pos,
                   struct KVCache **kv_caches,
+                  const int32_t *block_tables,
+                  const int32_t *slot_mapping,
+                  const uint32_t is_prefill, const bool enable_paged_attn,
                   void *logits) {
     model->req.tokens = tokens;
     model->req.ntok = ntok;
@@ -342,11 +447,15 @@ forwardBatchJiuge(struct JiugeModel *model,
     model->req.nreq = nreq;
     model->req.req_pos = req_pos;
     model->req.kv_caches = kv_caches;
+    model->req.block_tables = block_tables;
+    model->req.slot_mapping = slot_mapping;
     model->req.output = nullptr;
    model->req.logits = logits;
     model->req.temperature = nullptr;
     model->req.topk = nullptr;
     model->req.topp = nullptr;
+    model->req.is_prefill = is_prefill;
+    model->req.enable_paged_attn = enable_paged_attn;
 
     for (size_t idev = 0; idev < model->dev_ids.size(); idev++) {
         std::unique_lock lock(model->states[idev].mtx);
@@ -391,7 +500,10 @@ void launchDevice(const JiugeMeta &meta, const JiugeWeights *weights, JiugeDevic
             inferDeviceBatch(meta, *rsrc, idev, ndev,
                              req.tokens, req.ntok,
                              req.req_lens, req.nreq, req.req_pos, req.kv_caches,
-                             req.temperature, req.topk, req.topp, req.output, req.logits);
+                             req.block_tables, req.slot_mapping,
+                             req.temperature, req.topk, req.topp,
+                             req.is_prefill, req.enable_paged_attn,
+                             req.output, req.logits);
 
             state.proceed = false;
             lock.unlock();
diff --git a/src/models/jiuge/jiuge_impl.hpp b/src/models/jiuge/jiuge_impl.hpp
index 55800a37..4a673bbb 100644
--- a/src/models/jiuge/jiuge_impl.hpp
+++ b/src/models/jiuge/jiuge_impl.hpp
@@ -45,10 +45,14 @@ struct InferRequest {
     uint32_t nreq;
     const uint32_t *req_pos;
     struct KVCache **kv_caches;
+    const int32_t *block_tables;
+    const int32_t *slot_mapping;
     const float *temperature;
     const uint32_t *topk;
     const float *topp;
     uint32_t *output;
+    uint32_t is_prefill;
+    bool enable_paged_attn;
     void *logits;
 };
 
diff --git a/src/tensor.hpp b/src/tensor.hpp
index 320d871c..129e8d21 100644
--- a/src/tensor.hpp
+++ b/src/tensor.hpp
@@ -137,6 +137,12 @@ class Tensor : public std::enable_shared_from_this<Tensor> {
     std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape) const;
     std::shared_ptr<Tensor> view_as(const std::vector<size_t> &new_shape, const std::vector<ptrdiff_t> &new_strides) const;
 
+    // template <typename T>
+    // void init_value(T value, infiniopHandle_t handle, infinirtStream_t stream);
+
+    // template <typename T>
+    // void init_value_simple(T value, infiniopHandle_t handle, infinirtStream_t stream);
+
     ~Tensor();
 };
 
diff --git a/src/tensor/tensor.cpp b/src/tensor/tensor.cpp
index edf0faeb..3b3e1c08 100644
--- a/src/tensor/tensor.cpp
+++ b/src/tensor/tensor.cpp
@@ -424,3 +424,61 @@ void Tensor::debug(const std::string &filename) const {
 }
 
 void Tensor::debug() const { this->debug(""); }
+
+// template <typename T>
+// void Tensor::init_value(T value, infiniopHandle_t handle,
+//                         infinirtStream_t stream) {
+//     ASSERT_EQ(dsize(this->dtype()), sizeof(T));
+
+//     size_t numel = 1;
+//     for (size_t dim : this->shape()) {
+//         numel *= dim;
+//     }
+//     if (numel == 0) {
+//         return;
+//     }
+
+//     RUN_INFINI(infinirtMemcpy(this->data(), &value, sizeof(T),
+//                               INFINIRT_MEMCPY_H2D));
+
+//     auto ndim = this->ndim();
+//     auto shape = this->shape();
+//     auto bcast_strides = std::vector<ptrdiff_t>(ndim, 0);
+//     auto src_desc = TensorDesc::create(this->dtype(), shape, bcast_strides);
+
+//     infiniopRearrangeDescriptor_t rearrange_desc;
+//     RUN_INFINI(infiniopCreateRearrangeDescriptor(
+//         handle, &rearrange_desc, this->desc(), src_desc->desc()));
+//     RUN_INFINI(infiniopRearrange(rearrange_desc, this->data(), this->data(),
+//                                  stream));
+
+//     RUN_INFINI(infiniopDestroyRearrangeDescriptor(rearrange_desc));
+// }
+// template <typename T>
+// void Tensor::init_value_simple(T value, infiniopHandle_t handle,
+//                                infinirtStream_t stream) {
+//     // 1. Safety check: make sure the element type matches the tensor dtype.
+//     ASSERT_EQ(dsize(this->dtype()), sizeof(T));
+
+//     // 2. Compute the total number of elements in the tensor.
+//     size_t numel = 1;
+//     for (size_t dim : this->shape()) {
+//         numel *= dim;
+//     }
+//     if (numel == 0) {
+//         return;
+//     }
+
+//     // 3. Create a temporary host (CPU) buffer filled with the target value.
+//     std::vector<T> host_data(numel, value);
+
+//     // 4. Use Tensor::weight to create a temporary source tensor on the device
+//     //    with the correct contents. It has the same shape as this tensor but
+//     //    contiguous memory; Tensor::weight handles the host-to-device copy.
+//     auto src_tensor = Tensor::weight(host_data.data(), this->dtype(), this->shape());
+
+//     // 5. Use the existing, safe copyFrom function to perform the assignment.
+//     //    copyFrom correctly handles any non-contiguous layout (strides)
+//     //    that this tensor may have.
+//     this->copyFrom(src_tensor, handle, stream);
+// }
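
Note on the new block_tables / slot_mapping arguments: inferBatchJiuge and forwardBatchJiuge now take, per request, a table of physical cache blocks and, per new token, the flat slot in the paged KV pool that pagedCaching writes to. The snippet below is a minimal, self-contained sketch (not part of the patch) of how a caller could derive a slot mapping from a block table. It assumes the common vLLM-style convention slot = block_tables[pos / block_size] * block_size + pos % block_size; the helper name buildSlotMapping is hypothetical, and the actual convention is defined by the pagedCaching / pagedAttention kernels, which are not shown in this diff.

// Hypothetical helper, not part of the patch: maps the new tokens of one
// request to flat slots in the paged KV cache pool.
#include <cstdint>
#include <cstdio>
#include <vector>

std::vector<int32_t> buildSlotMapping(const std::vector<int32_t> &block_table,
                                      uint32_t past_len, uint32_t seq_len,
                                      uint32_t block_size) {
    std::vector<int32_t> slots;
    slots.reserve(seq_len);
    for (uint32_t i = 0; i < seq_len; ++i) {
        uint32_t pos = past_len + i;                               // absolute position in the sequence
        int32_t block_id = block_table[pos / block_size];          // physical block holding this position
        int32_t offset = static_cast<int32_t>(pos % block_size);   // offset inside that block
        slots.push_back(block_id * static_cast<int32_t>(block_size) + offset);
    }
    return slots;
}

int main() {
    // One request backed by physical blocks 7 and 3, with 16 tokens per block;
    // 5 tokens are already cached and 3 new tokens arrive this step.
    std::vector<int32_t> block_table = {7, 3};
    auto slots = buildSlotMapping(block_table, /*past_len=*/5, /*seq_len=*/3, /*block_size=*/16);
    for (int32_t s : slots) {
        std::printf("%d ", s); // prints 117 118 119 (all inside block 7)
    }
    std::printf("\n");
    return 0;
}

Under this assumption, the per-request slot arrays would be concatenated in the same order as the flattened token batch before being passed to inferBatchJiuge, which matches how slot_mapping_buf is consumed for all ntok tokens in inferDeviceBatch.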