From 997a02f711679bb20ebad32d8d0952571278f74c Mon Sep 17 00:00:00 2001
From: Jason Zhou
Date: Thu, 19 Feb 2026 12:41:54 -0800
Subject: [PATCH 1/7] feat: add Kimi-K2.5 (moonshotai/Kimi-K2.5) model support

- Add model config for Kimi-K2.5 (MLA-based MoE, 61 layers, 384 routed
  experts, 64 attention heads, 262k context)
- Register the KimiK25ForConditionalGeneration architecture under the
  DEEPSEEK model family and add moonshotai/Kimi-K2.5 to DefaultHFModels
- Fix _parse_hf_config_json to fall back to the top-level config when model
  params are nested under "text_config" (required for VLM-style HF configs
  like Kimi-K2.5)
- Extend MLA collector test cases and the TRT-LLM collect_mla n_list to
  cover num_heads=64 (Kimi-K2.5) in addition to the existing 128
  (DeepSeek-V3)

Fix DeepSeekModel / TrtllmWideEPDeepSeekModel hardcoded 128-head ops:

DeepSeekModel and TrtllmWideEPDeepSeekModel hardcoded DeepSeek-V3's 128
attention heads in several MLA GEMM / attention ops, making them produce
incorrect weight-size and latency estimates for any DEEPSEEK model with a
different head count (e.g. Kimi-K2.5 with 64 heads). Replace every affected
hardcode with self._num_heads:

- context/generation q_b_proj_gemm: n = num_heads * 192 // tp
- context kv_b_proj_gemm:           n = num_heads * 256 // tp
- context/generation_attention:     n_heads = num_heads // tp
- context_proj_gemm:                k = num_heads * 128 // tp

Fix nextn (MTP) auto-assigned to all DEEPSEEK models (task.py):

nextn was unconditionally set to 1 for every DEEPSEEK model, adding a
spurious (nextn+1) activation-memory multiplier and incorrect MTP latency
scaling for models without Multi-Token Prediction support. nextn is now
read from num_nextn_predict_layers in the raw model config (default 0), so
DeepSeek-V3/V3.1 still get nextn=1 while Kimi-K2.5 gets nextn=0.

Fix IndexError in get_worker_candidates() when all configs OOM
(inference_session.py):

The same exceptions[-1]-on-empty-list crash that #378 fixed in agg_pareto()
is now also fixed in DisaggInferenceSession.get_worker_candidates().

Fix disagg per-worker GPU search space not scaling with --total-gpus
(task.py):

_finalize_disagg used total_gpus only to cap max_gpu_per_replica (replica
scaling) but never updated num_gpu_per_worker / tp_list / dp_list /
moe_ep_list in the prefill and decode worker configs. Those lists were
hardcoded to [1,2,4,8], so large MoE models like Kimi-K2.5 (which need
EP=32+ to avoid OOM) were never explored regardless of --total-gpus.
_finalize_disagg now extends each non-singleton parallel list with
powers of 2 up to total_gpus, so configurations like EP=32/64/128 are
included in the sweep when sufficient GPUs are available.
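For reference, plugging the head dims shared by both models
(qk_nope_head_dim=128, qk_rope_head_dim=64, v_head_dim=128) into the new
formulas (illustrative arithmetic only, at tp=1):

    # DeepSeek-V3 (num_heads=128): q_b_proj  n = 128 * (128 + 64)  = 24576 (old hardcode)
    # Kimi-K2.5   (num_heads=64):  q_b_proj  n =  64 * (128 + 64)  = 12288
    #                              kv_b_proj n =  64 * (128 + 128) = 16384 (vs hardcoded 32768)

Likewise, with --total-gpus 256 a default moe_ep_list of [1,2,4,8] is
extended to [1,2,4,8,16,32,64,128,256], so EP=32/64/128 disagg workers now
enter the sweep.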
Co-Authored-By: Claude Sonnet 4.6
---
 collector/common_test_cases.py                      |  1 +
 collector/trtllm/collect_mla.py                     |  4 +--
 .../moonshotai--Kimi-K2.5_config.json               | 26 +++++++++++++++++++
 src/aiconfigurator/sdk/common.py                    |  3 +++
 src/aiconfigurator/sdk/models.py                    | 26 +++++++++----------
 src/aiconfigurator/sdk/task.py                      | 21 ++++++++++++++-
 src/aiconfigurator/sdk/utils.py                     | 26 +++++++++++--------
 7 files changed, 80 insertions(+), 27 deletions(-)
 create mode 100644 src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json

diff --git a/collector/common_test_cases.py b/collector/common_test_cases.py
index c9e9b4a2..c25c63c8 100644
--- a/collector/common_test_cases.py
+++ b/collector/common_test_cases.py
@@ -226,6 +226,7 @@ def _get_mla_common_test_cases(is_context: bool):
     # num_heads, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
     model_config_list = [
         [128, 1536, 512, 128, 64, 128, "deepseek-ai/DeepSeek-V3"],
+        [64, 1536, 512, 128, 64, 128, "moonshotai/Kimi-K2.5"],
     ]
 
     if is_context:
diff --git a/collector/trtllm/collect_mla.py b/collector/trtllm/collect_mla.py
index 7fc604c1..7774da64 100644
--- a/collector/trtllm/collect_mla.py
+++ b/collector/trtllm/collect_mla.py
@@ -24,7 +24,7 @@ def get_context_mla_test_cases():
     dtype_list = [tensorrt_llm.bindings.DataType.BF16]  # not support f8 for trt < v1.1
     test_cases = []
-    n_list = [128]
+    n_list = [64, 128]
     b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     s_list = [1, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 10240, 12288, 16384, 32768]
     for n in n_list:
@@ -59,7 +59,7 @@ def get_generation_mla_test_cases():
     dtype_list = [tensorrt_llm.bindings.DataType.BF16]  # not support f8 for trt < v1.1
     test_cases = []
-    n_list = [128]
+    n_list = [64, 128]
     for n in n_list:
         for b in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
             for s in [
diff --git a/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json
new file mode 100644
index 00000000..cef60d82
--- /dev/null
+++ b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json
@@ -0,0 +1,26 @@
+{
+    "architectures": ["KimiK25ForConditionalGeneration"],
+    "model_type": "kimi_k25",
+    "num_hidden_layers": 61,
+    "hidden_size": 7168,
+    "num_attention_heads": 64,
+    "num_key_value_heads": 64,
+    "intermediate_size": 18432,
+    "vocab_size": 163840,
+    "max_position_embeddings": 262144,
+    "n_routed_experts": 384,
+    "n_shared_experts": 1,
+    "num_experts_per_tok": 8,
+    "moe_intermediate_size": 2048,
+    "moe_layer_freq": 1,
+    "first_k_dense_replace": 1,
+    "kv_lora_rank": 512,
+    "q_lora_rank": 1536,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "v_head_dim": 128,
+    "routed_scaling_factor": 2.827,
+    "torch_dtype": "bfloat16",
+    "use_cache": true,
+    "tie_word_embeddings": false
+}
diff --git a/src/aiconfigurator/sdk/common.py b/src/aiconfigurator/sdk/common.py
index c5cd81af..67a59a4d 100644
--- a/src/aiconfigurator/sdk/common.py
+++ b/src/aiconfigurator/sdk/common.py
@@ -235,6 +235,8 @@ def get_default_models() -> set[str]:
     # DeepSeek Models
     "deepseek-ai/DeepSeek-V3",
     "nvidia/DeepSeek-V3.1-NVFP4",
+    # Kimi Models
+    "moonshotai/Kimi-K2.5",
     # Qwen 2.5 Models
     "Qwen/Qwen2.5-1.5B",
     "Qwen/Qwen2.5-7B",
@@ -288,6 +290,7 @@ def get_default_models() -> set[str]:
     "Qwen3ForCausalLM": "LLAMA",
     "DeepSeekForCausalLM": "DEEPSEEK",
     "DeepseekV3ForCausalLM": "DEEPSEEK",
+    "KimiK25ForConditionalGeneration": "DEEPSEEK",
     "NemotronForCausalLM": "NEMOTRONNAS",
     "DeciLMForCausalLM": "NEMOTRONNAS",
     "NemotronHForCausalLM": "NEMOTRONH",
diff --git a/src/aiconfigurator/sdk/models.py b/src/aiconfigurator/sdk/models.py
index 3d8e147d..e1d46def 100755
--- a/src/aiconfigurator/sdk/models.py
+++ b/src/aiconfigurator/sdk/models.py
@@ -1137,27 +1137,27 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GEMM(
                     "context_q_b_proj_gemm",
                     self._num_layers,
-                    24576 // tp_size,
+                    self._num_heads * 192 // tp_size,  # num_heads * (qk_nope_head_dim + qk_rope_head_dim)
                     1536,
                     gemm_quant_mode,
                 ),
                 ops.GEMM(
                     "context_kv_b_proj_gemm",
                     self._num_layers,
-                    32768 // tp_size,
+                    self._num_heads * 256 // tp_size,  # num_heads * (qk_nope_head_dim + v_head_dim)
                     512,
                     gemm_quant_mode,
                 ),  # agg ctx attn part
                 ops.ContextMLA(
                     "context_attention",
                     self._num_layers,
-                    128 // tp_size,
+                    self._num_heads // tp_size,
                     kvcache_quant_mode,
                     fmha_quant_mode,
                 ),  # agg ctx attn part
                 ops.GEMM(
-                    "context_proj_gemm", self._num_layers, h, 128 * 128 // tp_size, gemm_quant_mode
-                ),  # agg ctx attn part
+                    "context_proj_gemm", self._num_layers, h, self._num_heads * 128 // tp_size, gemm_quant_mode
+                ),  # agg ctx attn part; 128 = v_head_dim
                 ops.ElementWise("context_add_norm_2", self._num_layers, 2 * h, 2 * h, 0.8),
             ]
         )
@@ -1294,7 +1294,7 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GEMM(
                     "generation_q_b_proj_gemm",
                     self._num_layers * self._mtp_scale_factor,
-                    24576 // tp_size,
+                    self._num_heads * 192 // tp_size,  # num_heads * (qk_nope_head_dim + qk_rope_head_dim)
                     1536,
                     gemm_quant_mode,
                 ),
@@ -1308,7 +1308,7 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GenerationMLA(
                     "generation_attention",
                     self._num_layers * self._mtp_scale_factor,
-                    128 // tp_size,
+                    self._num_heads // tp_size,
                     kvcache_quant_mode,
                 ),  # agg gen attn part
                 ops.MLABmm(
@@ -1596,25 +1596,25 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GEMM(
                     "context_q_b_proj_gemm",
                     self._num_layers,
-                    24576 // tp_size,
+                    self._num_heads * 192 // tp_size,  # num_heads * (qk_nope_head_dim + qk_rope_head_dim)
                     1536,
                     gemm_quant_mode,
                 ),
                 ops.GEMM(
                     "context_kv_b_proj_gemm",
                     self._num_layers,
-                    32768 // tp_size,
+                    self._num_heads * 256 // tp_size,  # num_heads * (qk_nope_head_dim + v_head_dim)
                     512,
                     gemm_quant_mode,
                 ),
                 ops.ContextMLA(
                     "context_attention",
                     self._num_layers,
-                    128 // tp_size,
+                    self._num_heads // tp_size,
                     kvcache_quant_mode,
                     fmha_quant_mode,
                 ),
-                ops.GEMM("context_proj_gemm", self._num_layers, h, 128 * 128 // tp_size, gemm_quant_mode),
+                ops.GEMM("context_proj_gemm", self._num_layers, h, self._num_heads * 128 // tp_size, gemm_quant_mode),  # 128 = v_head_dim
                 ops.ElementWise("context_add_norm_2", self._num_layers, 2 * h, 2 * h, 0.8),
             ]
         )
@@ -1755,7 +1755,7 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GEMM(
                     "generation_q_b_proj_gemm",
                     self._num_layers * self._mtp_scale_factor,
-                    24576 // tp_size,
+                    self._num_heads * 192 // tp_size,  # num_heads * (qk_nope_head_dim + qk_rope_head_dim)
                     1536,
                     gemm_quant_mode,
                 ),
@@ -1769,7 +1769,7 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GenerationMLA(
                     "generation_attention",
                     self._num_layers * self._mtp_scale_factor,
-                    128 // tp_size,
+                    self._num_heads // tp_size,
                     kvcache_quant_mode,
                 ),
                 ops.MLABmm(
diff --git a/src/aiconfigurator/sdk/task.py b/src/aiconfigurator/sdk/task.py
index 8adacf47..22dd5993 100644
--- a/src/aiconfigurator/sdk/task.py
+++ b/src/aiconfigurator/sdk/task.py
@@ -303,7 +303,8 @@ def _mode_layers(cls, ctx: TaskContext) -> list[ConfigLayer]:
 
     @staticmethod
     def _base_common_layer(ctx: TaskContext) -> dict:
-        nextn = 1 if ctx.model_family == "DEEPSEEK" else 0
+        raw_config = get_model_config_from_model_path(ctx.model_path).get("raw_config", {})
+        nextn = raw_config.get("num_nextn_predict_layers", 0)
         return {
             "serving_mode": ctx.serving_mode,
             "model_path": ctx.model_path,
@@ -488,6 +489,24 @@ def _finalize_disagg(cls, config: DefaultMunch, ctx: TaskContext) -> None:
         replica_cfg.max_gpu_per_replica = min(ctx.total_gpus, replica_cfg.get("max_gpu_per_replica"))
         logger.debug("Using max gpu per replica %s", replica_cfg.max_gpu_per_replica)
 
+        # Extend per-worker parallel config lists with powers-of-2 up to total_gpus.
+        # The default search space is [1,2,4,8], which is insufficient for large MoE
+        # models (e.g. Kimi-K2.5 with 384 experts, DeepSeek-V3 with 256 experts) that
+        # require 32+ GPUs per worker to avoid OOM. Only extend lists that already
+        # contain values > 1 (lists pinned to [1] are intentionally single-valued).
+        for worker_cfg in (config.prefill_worker_config, config.decode_worker_config):
+            for key in ("num_gpu_per_worker", "tp_list", "dp_list", "moe_ep_list"):
+                current = list(getattr(worker_cfg, key, None) or [])
+                if not current or max(current) <= 1:
+                    continue
+                v = max(current) * 2
+                while v <= ctx.total_gpus:
+                    if v not in current:
+                        current.append(v)
+                    v *= 2
+                setattr(worker_cfg, key, current)
+                logger.debug("Extended worker %s to %s", key, current)
+
 
 _quants = {
     "fp8": {
diff --git a/src/aiconfigurator/sdk/utils.py b/src/aiconfigurator/sdk/utils.py
index 69bbc92d..c3d69bb3 100644
--- a/src/aiconfigurator/sdk/utils.py
+++ b/src/aiconfigurator/sdk/utils.py
@@ -442,21 +442,25 @@ def _parse_hf_config_json(config: dict) -> dict:
             f"Supported architectures: {', '.join(ARCHITECTURE_TO_MODEL_FAMILY.keys())}"
         )
 
-    layers = config["num_hidden_layers"]
-    hidden_size = config["hidden_size"]
-    n = config["num_attention_heads"]
-    vocab = config["vocab_size"]
-    context = config["max_position_embeddings"]
+    # For multimodal VLMs (e.g., KimiK25ForConditionalGeneration), model params are
+    # nested under "text_config"; fall back to top-level config for pure text models.
+    effective_config = config.get("text_config", config)
+
+    layers = effective_config["num_hidden_layers"]
+    hidden_size = effective_config["hidden_size"]
+    n = effective_config["num_attention_heads"]
+    vocab = effective_config["vocab_size"]
+    context = effective_config["max_position_embeddings"]
 
     # Handle nullable fields (e.g., Nemotron has null for these)
-    n_kv = config.get("num_key_value_heads") or 0
-    inter_size = config.get("intermediate_size") or 0
-    d = config.get("head_dim") or config.get("attention_head_dim") or (hidden_size // n if n > 0 else 0)
+    n_kv = effective_config.get("num_key_value_heads") or 0
+    inter_size = effective_config.get("intermediate_size") or 0
+    d = effective_config.get("head_dim") or effective_config.get("attention_head_dim") or (hidden_size // n if n > 0 else 0)
 
     # MoE parameters
-    topk = config.get("num_experts_per_tok", 0)
-    num_experts = config.get("num_local_experts") or config.get("n_routed_experts") or config.get("num_experts", 0)
-    moe_inter_size = config.get("moe_intermediate_size", 0) or config.get("intermediate_size", 0)
+    topk = effective_config.get("num_experts_per_tok", 0)
+    num_experts = effective_config.get("num_local_experts") or effective_config.get("n_routed_experts") or effective_config.get("num_experts", 0)
+    moe_inter_size = effective_config.get("moe_intermediate_size", 0) or effective_config.get("intermediate_size", 0)
 
     # Handle NemotronH-specific configuration (only fields unique to NemotronH)
     extra_params = None

From 5fcd29ab154c73b316bf55594a0ecb5fb06d92c7 Mon Sep 17 00:00:00 2001
From: Jason Zhou
Date: Thu, 19 Feb 2026 17:58:41 -0800
Subject: [PATCH 2/7] fix

---
 src/aiconfigurator/sdk/models.py |  3 ++-
 src/aiconfigurator/sdk/task.py   | 26 ++++++++++++------------
 src/aiconfigurator/sdk/utils.py  | 12 ++++++++++--
 3 files changed, 26 insertions(+), 15 deletions(-)

diff --git a/src/aiconfigurator/sdk/models.py b/src/aiconfigurator/sdk/models.py
index e1d46def..9bbd73eb 100755
--- a/src/aiconfigurator/sdk/models.py
+++ b/src/aiconfigurator/sdk/models.py
@@ -1614,7 +1614,8 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                     kvcache_quant_mode,
                     fmha_quant_mode,
                 ),
-                ops.GEMM("context_proj_gemm", self._num_layers, h, self._num_heads * 128 // tp_size, gemm_quant_mode),  # 128 = v_head_dim
+                # 128 = v_head_dim
+                ops.GEMM("context_proj_gemm", self._num_layers, h, self._num_heads * 128 // tp_size, gemm_quant_mode),
                 ops.ElementWise("context_add_norm_2", self._num_layers, 2 * h, 2 * h, 0.8),
             ]
         )
diff --git a/src/aiconfigurator/sdk/task.py b/src/aiconfigurator/sdk/task.py
index 22dd5993..0760cf7a 100644
--- a/src/aiconfigurator/sdk/task.py
+++ b/src/aiconfigurator/sdk/task.py
@@ -494,18 +494,20 @@ def _finalize_disagg(cls, config: DefaultMunch, ctx: TaskContext) -> None:
         # models (e.g. Kimi-K2.5 with 384 experts, DeepSeek-V3 with 256 experts) that
         # require 32+ GPUs per worker to avoid OOM. Only extend lists that already
         # contain values > 1 (lists pinned to [1] are intentionally single-valued).
-        for worker_cfg in (config.prefill_worker_config, config.decode_worker_config):
-            for key in ("num_gpu_per_worker", "tp_list", "dp_list", "moe_ep_list"):
-                current = list(getattr(worker_cfg, key, None) or [])
-                if not current or max(current) <= 1:
-                    continue
-                v = max(current) * 2
-                while v <= ctx.total_gpus:
-                    if v not in current:
-                        current.append(v)
-                    v *= 2
-                setattr(worker_cfg, key, current)
-                logger.debug("Extended worker %s to %s", key, current)
+        # Wideep paths have carefully curated per-worker GPU lists; skip extension.
+        if not ctx.enable_wideep:
+            for worker_cfg in (config.prefill_worker_config, config.decode_worker_config):
+                for key in ("num_gpu_per_worker", "tp_list", "dp_list", "moe_ep_list"):
+                    current = list(getattr(worker_cfg, key, None) or [])
+                    if not current or max(current) <= 1:
+                        continue
+                    v = max(current) * 2
+                    while v <= ctx.total_gpus:
+                        if v not in current:
+                            current.append(v)
+                        v *= 2
+                    setattr(worker_cfg, key, current)
+                    logger.debug("Extended worker %s to %s", key, current)
 
 
 _quants = {
diff --git a/src/aiconfigurator/sdk/utils.py b/src/aiconfigurator/sdk/utils.py
index c3d69bb3..ba15fd9b 100644
--- a/src/aiconfigurator/sdk/utils.py
+++ b/src/aiconfigurator/sdk/utils.py
@@ -455,11 +455,19 @@ def _parse_hf_config_json(config: dict) -> dict:
     # Handle nullable fields (e.g., Nemotron has null for these)
     n_kv = effective_config.get("num_key_value_heads") or 0
     inter_size = effective_config.get("intermediate_size") or 0
-    d = effective_config.get("head_dim") or effective_config.get("attention_head_dim") or (hidden_size // n if n > 0 else 0)
+    d = (
+        effective_config.get("head_dim")
+        or effective_config.get("attention_head_dim")
+        or (hidden_size // n if n > 0 else 0)
+    )
 
     # MoE parameters
     topk = effective_config.get("num_experts_per_tok", 0)
-    num_experts = effective_config.get("num_local_experts") or effective_config.get("n_routed_experts") or effective_config.get("num_experts", 0)
+    num_experts = (
+        effective_config.get("num_local_experts")
+        or effective_config.get("n_routed_experts")
+        or effective_config.get("num_experts", 0)
+    )
     moe_inter_size = effective_config.get("moe_intermediate_size", 0) or effective_config.get("intermediate_size", 0)
 
     # Handle NemotronH-specific configuration (only fields unique to NemotronH)
     extra_params = None

From 86e8b1d9f4be5be20782544fae2febc8965bd1ea Mon Sep 17 00:00:00 2001
From: Jason Zhou
Date: Thu, 19 Feb 2026 18:08:18 -0800
Subject: [PATCH 3/7] ruff

---
 src/aiconfigurator/sdk/inference_session.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/aiconfigurator/sdk/inference_session.py b/src/aiconfigurator/sdk/inference_session.py
index 2ebe0df8..022aaf42 100644
--- a/src/aiconfigurator/sdk/inference_session.py
+++ b/src/aiconfigurator/sdk/inference_session.py
@@ -373,9 +373,10 @@ def get_worker_candidates(
                 continue
         if summary_df.empty:
             if exceptions:
+                last = exceptions[-1]
                 raise RuntimeError(
-                    f"No results found for any parallel configuration. Showing last exception: {exceptions[-1]}"
-                ) from exceptions[-1]
+                    f"No results found for any parallel configuration. Showing last exception: {last}"
+                ) from last
             if all_configs_oom:
                 raise RuntimeError(
                     "No results found: the model does not fit in GPU memory for any parallel "

From d9219759f08e23642435057b0bd1e3d99e3940d5 Mon Sep 17 00:00:00 2001
From: Jason Zhou
Date: Thu, 19 Feb 2026 18:28:48 -0800
Subject: [PATCH 4/7] add k2

---
 .../moonshotai--Kimi-K2-Instruct_config.json | 26 +++++++++++++++++++
 .../moonshotai--Kimi-K2-Thinking_config.json | 26 +++++++++++++++++++
 src/aiconfigurator/sdk/common.py             |  2 ++
 3 files changed, 54 insertions(+)
 create mode 100644 src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Instruct_config.json
 create mode 100644 src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Thinking_config.json

diff --git a/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Instruct_config.json b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Instruct_config.json
new file mode 100644
index 00000000..5b8c925e
--- /dev/null
+++ b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Instruct_config.json
@@ -0,0 +1,26 @@
+{
+    "architectures": ["DeepseekV3ForCausalLM"],
+    "model_type": "kimi_k2",
+    "num_hidden_layers": 61,
+    "hidden_size": 7168,
+    "num_attention_heads": 64,
+    "num_key_value_heads": 64,
+    "intermediate_size": 18432,
+    "vocab_size": 163840,
+    "max_position_embeddings": 131072,
+    "n_routed_experts": 384,
+    "n_shared_experts": 1,
+    "num_experts_per_tok": 8,
+    "moe_intermediate_size": 2048,
+    "moe_layer_freq": 1,
+    "first_k_dense_replace": 1,
+    "kv_lora_rank": 512,
+    "q_lora_rank": 1536,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "v_head_dim": 128,
+    "routed_scaling_factor": 2.827,
+    "torch_dtype": "bfloat16",
+    "use_cache": true,
+    "tie_word_embeddings": false
+}
diff --git a/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Thinking_config.json b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Thinking_config.json
new file mode 100644
index 00000000..18589494
--- /dev/null
+++ b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Thinking_config.json
@@ -0,0 +1,26 @@
+{
+    "architectures": ["DeepseekV3ForCausalLM"],
+    "model_type": "kimi_k2",
+    "num_hidden_layers": 61,
+    "hidden_size": 7168,
+    "num_attention_heads": 64,
+    "num_key_value_heads": 64,
+    "intermediate_size": 18432,
+    "vocab_size": 163840,
+    "max_position_embeddings": 262144,
+    "n_routed_experts": 384,
+    "n_shared_experts": 1,
+    "num_experts_per_tok": 8,
+    "moe_intermediate_size": 2048,
+    "moe_layer_freq": 1,
+    "first_k_dense_replace": 1,
+    "kv_lora_rank": 512,
+    "q_lora_rank": 1536,
+    "qk_nope_head_dim": 128,
+    "qk_rope_head_dim": 64,
+    "v_head_dim": 128,
+    "routed_scaling_factor": 2.827,
+    "torch_dtype": "bfloat16",
+    "use_cache": true,
+    "tie_word_embeddings": false
+}
diff --git a/src/aiconfigurator/sdk/common.py b/src/aiconfigurator/sdk/common.py
index 67a59a4d..219c99bf 100644
--- a/src/aiconfigurator/sdk/common.py
+++ b/src/aiconfigurator/sdk/common.py
@@ -236,6 +236,8 @@ def get_default_models() -> set[str]:
     "deepseek-ai/DeepSeek-V3",
     "nvidia/DeepSeek-V3.1-NVFP4",
     # Kimi Models
+    "moonshotai/Kimi-K2-Instruct",
+    "moonshotai/Kimi-K2-Thinking",
     "moonshotai/Kimi-K2.5",
     # Qwen 2.5 Models
     "Qwen/Qwen2.5-1.5B",

From 5cca3e7fe7673f6a9d1f9f3b0a9b0143af2bd1dc Mon Sep 17 00:00:00 2001
From: Jason Zhou
Date: Fri, 20 Feb 2026 10:00:56 -0800
Subject: [PATCH 5/7] wideep

---
 src/aiconfigurator/cli/main.py | 10 ++++++++++
 src/aiconfigurator/sdk/task.py | 19 -------------------
 2 files changed, 10 insertions(+), 19 deletions(-)

diff --git a/src/aiconfigurator/cli/main.py b/src/aiconfigurator/cli/main.py
index 444c9ad2..e7accda0 100644
--- a/src/aiconfigurator/cli/main.py
+++ b/src/aiconfigurator/cli/main.py
@@ -145,6 +145,12 @@ def _add_default_mode_arguments(parser):
         help="Optional end-to-end request latency target (ms). Enables request-latency optimization mode.",
     )
     parser.add_argument("--prefix", type=int, default=0, help="Prefix cache length. Default to 0.")
+    parser.add_argument(
+        "--enable-wideep",
+        action="store_true",
+        default=False,
+        help="Enable wide expert-parallelism search space (effective for DeepSeek models with trtllm/sglang backends).",
+    )
 
 
 def _add_experiments_mode_arguments(parser):
@@ -576,6 +582,7 @@ def build_default_task_configs(
     tpot: float = 30.0,
     request_latency: float | None = None,
     prefix: int = 0,
+    enable_wideep: bool = False,
 ) -> dict[str, TaskConfig]:
     """Build agg and disagg task configs for default mode comparison.
 
@@ -594,6 +601,7 @@ def build_default_task_configs(
         tpot: Time per output token target in ms.
         request_latency: Optional end-to-end request latency target (ms).
         prefix: Prefix cache length.
+        enable_wideep: Enable wide expert-parallelism search space.
 
     Returns:
         Dict with TaskConfig objects. When backend='auto', returns 6 configs
@@ -621,6 +629,7 @@ def build_default_task_configs(
         "request_latency": request_latency,
         "prefix": prefix,
         "database_mode": database_mode,
+        "enable_wideep": enable_wideep,
     }
 
     task_configs: dict[str, TaskConfig] = {}
@@ -1332,6 +1341,7 @@ def main(args):
             tpot=args.tpot,
             request_latency=args.request_latency,
             prefix=args.prefix,
+            enable_wideep=args.enable_wideep,
         )
     elif args.mode == "exp":
         try:
diff --git a/src/aiconfigurator/sdk/task.py b/src/aiconfigurator/sdk/task.py
index 0760cf7a..9fadaba5 100644
--- a/src/aiconfigurator/sdk/task.py
+++ b/src/aiconfigurator/sdk/task.py
@@ -489,25 +489,6 @@ def _finalize_disagg(cls, config: DefaultMunch, ctx: TaskContext) -> None:
         replica_cfg.max_gpu_per_replica = min(ctx.total_gpus, replica_cfg.get("max_gpu_per_replica"))
         logger.debug("Using max gpu per replica %s", replica_cfg.max_gpu_per_replica)
 
-        # Extend per-worker parallel config lists with powers-of-2 up to total_gpus.
-        # The default search space is [1,2,4,8], which is insufficient for large MoE
-        # models (e.g. Kimi-K2.5 with 384 experts, DeepSeek-V3 with 256 experts) that
-        # require 32+ GPUs per worker to avoid OOM. Only extend lists that already
-        # contain values > 1 (lists pinned to [1] are intentionally single-valued).
-        # Wideep paths have carefully curated per-worker GPU lists; skip extension.
-        if not ctx.enable_wideep:
-            for worker_cfg in (config.prefill_worker_config, config.decode_worker_config):
-                for key in ("num_gpu_per_worker", "tp_list", "dp_list", "moe_ep_list"):
-                    current = list(getattr(worker_cfg, key, None) or [])
-                    if not current or max(current) <= 1:
-                        continue
-                    v = max(current) * 2
-                    while v <= ctx.total_gpus:
-                        if v not in current:
-                            current.append(v)
-                        v *= 2
-                    setattr(worker_cfg, key, current)
-                    logger.debug("Extended worker %s to %s", key, current)
 
 
 _quants = {

From 3270e649f58c7e75022ff772cbddf6a6f0f2a671 Mon Sep 17 00:00:00 2001
From: Jason Zhou
Date: Fri, 20 Feb 2026 13:31:53 -0800
Subject: [PATCH 6/7] lint

---
 src/aiconfigurator/sdk/task.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/aiconfigurator/sdk/task.py b/src/aiconfigurator/sdk/task.py
index 9fadaba5..aa9b386a 100644
--- a/src/aiconfigurator/sdk/task.py
+++ b/src/aiconfigurator/sdk/task.py
@@ -490,7 +490,6 @@ def _finalize_disagg(cls, config: DefaultMunch, ctx: TaskContext) -> None:
         logger.debug("Using max gpu per replica %s", replica_cfg.max_gpu_per_replica)
 
 
-
 _quants = {
     "fp8": {
         "gemm_quant_mode": "fp8",

From 71707e067f821891a107a6cf64351a60bb175578 Mon Sep 17 00:00:00 2001
From: Jason Zhou
Date: Fri, 27 Feb 2026 14:17:52 -0800
Subject: [PATCH 7/7] feat: add vision encoder ops for Kimi-K2.5 VLM

Model the ViT vision encoder (27 layers, hidden size 1152), patch merger,
and projector as GEMM/ElementWise ops prepended to context_ops. Each vision
op carries _vision_num_tokens so the backend uses the correct token count
(4096 pre-merge, 1024 post-merge) instead of isl: with a 64x64
positional-embedding grid, an image yields 64 * 64 = 4096 patch tokens,
which the 2x2 patch merger reduces to 4096 / 4 = 1024.

Also reverts the collector changes per review comments.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 collector/common_test_cases.py                  |  1 -
 collector/trtllm/collect_mla.py                 |  4 +-
 .../moonshotai--Kimi-K2.5_config.json           | 15 ++++-
 src/aiconfigurator/sdk/backends/base_backend.py | 10 +++-
 src/aiconfigurator/sdk/models.py                | 55 +++++++++++++++++++
 5 files changed, 79 insertions(+), 6 deletions(-)

diff --git a/collector/common_test_cases.py b/collector/common_test_cases.py
index c25c63c8..c9e9b4a2 100644
--- a/collector/common_test_cases.py
+++ b/collector/common_test_cases.py
@@ -226,7 +226,6 @@ def _get_mla_common_test_cases(is_context: bool):
     # num_heads, q_lora_rank, kv_lora_rank, qk_nope_head_dim, qk_rope_head_dim, v_head_dim
     model_config_list = [
         [128, 1536, 512, 128, 64, 128, "deepseek-ai/DeepSeek-V3"],
-        [64, 1536, 512, 128, 64, 128, "moonshotai/Kimi-K2.5"],
     ]
 
     if is_context:
diff --git a/collector/trtllm/collect_mla.py b/collector/trtllm/collect_mla.py
index 7774da64..7fc604c1 100644
--- a/collector/trtllm/collect_mla.py
+++ b/collector/trtllm/collect_mla.py
@@ -24,7 +24,7 @@ def get_context_mla_test_cases():
     dtype_list = [tensorrt_llm.bindings.DataType.BF16]  # not support f8 for trt < v1.1
     test_cases = []
-    n_list = [64, 128]
+    n_list = [128]
     b_list = [1, 2, 4, 8, 16, 32, 64, 128, 256]
     s_list = [1, 16, 32, 64, 128, 256, 512, 1024, 1536, 2048, 3072, 4096, 6144, 8192, 10240, 12288, 16384, 32768]
     for n in n_list:
@@ -59,7 +59,7 @@ def get_generation_mla_test_cases():
     dtype_list = [tensorrt_llm.bindings.DataType.BF16]  # not support f8 for trt < v1.1
     test_cases = []
-    n_list = [64, 128]
+    n_list = [128]
     for n in n_list:
         for b in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024]:
             for s in [
diff --git a/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json
index cef60d82..56cdd2b5 100644
--- a/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json
+++ b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json
@@ -22,5 +22,18 @@
     "routed_scaling_factor": 2.827,
     "torch_dtype": "bfloat16",
     "use_cache": true,
-    "tie_word_embeddings": false
+    "tie_word_embeddings": false,
+    "vision_config": {
+        "vt_num_hidden_layers": 27,
+        "vt_num_attention_heads": 16,
+        "vt_hidden_size": 1152,
+        "vt_intermediate_size": 4304,
+        "patch_size": 14,
+        "init_pos_emb_height": 64,
+        "init_pos_emb_width": 64,
+        "merge_kernel_size": [2, 2],
+        "mm_hidden_size": 1152,
+        "text_hidden_size": 7168,
+        "mm_projector_type": "patchmerger"
+    }
 }
diff --git a/src/aiconfigurator/sdk/backends/base_backend.py b/src/aiconfigurator/sdk/backends/base_backend.py
index ba92849c..7f7c9194 100644
--- a/src/aiconfigurator/sdk/backends/base_backend.py
+++ b/src/aiconfigurator/sdk/backends/base_backend.py
@@ -77,13 +77,19 @@ def _run_context(batch_size: int, isl: int, prefix) -> tuple[dict[str, float], d
         for op in model.context_ops:
             # query latency and store the latency
-            x = batch_size * isl if "logits_gemm" not in op._name else batch_size
+            if "logits_gemm" in op._name:
+                x = batch_size
+            elif hasattr(op, "_vision_num_tokens"):
+                x = batch_size * op._vision_num_tokens
+            else:
+                x = batch_size * isl
+            s_val = op._vision_num_tokens if hasattr(op, "_vision_num_tokens") else isl
             result = op.query(
                 database,
                 x=x,
                 batch_size=batch_size,
                 beam_width=1,
-                s=isl,
+                s=s_val,
                 prefix=prefix,
                 model_name=getattr(model, "model_name", ""),
             )
diff --git a/src/aiconfigurator/sdk/models.py b/src/aiconfigurator/sdk/models.py
index 9bbd73eb..9046f194 100755
--- a/src/aiconfigurator/sdk/models.py
+++ b/src/aiconfigurator/sdk/models.py
@@ -322,6 +322,11 @@ def get_model(
         # extra_params is NemotronHConfig with hybrid layer configuration
         model.set_hybrid_config(extra_params)
 
+    # Add vision encoder ops if model has vision config
+    vision_config = raw_config.get("vision_config")
+    if vision_config is not None and hasattr(model, "_add_vision_encoder_ops"):
+        model._add_vision_encoder_ops(vision_config)
+
     return model
 
 
@@ -1465,6 +1470,56 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
         # TODO
         # a lot of quantization ops
 
+    def _add_vision_encoder_ops(self, vision_config: dict) -> None:
+        """Add vision encoder (ViT + patch merger + projector) ops to context_ops.
+
+        Vision ops are prepended so the vision encoder cost is accounted for
+        before the text decoder ops. Each vision op carries a
+        ``_vision_num_tokens`` attribute so the backend can use the correct
+        token count instead of ``isl``.
+        """
+        vt_layers = vision_config["vt_num_hidden_layers"]  # 27
+        vt_heads = vision_config["vt_num_attention_heads"]  # 16
+        vt_hidden = vision_config["vt_hidden_size"]  # 1152
+        vt_inter = vision_config["vt_intermediate_size"]  # 4304
+
+        init_h = vision_config["init_pos_emb_height"]  # 64
+        init_w = vision_config["init_pos_emb_width"]  # 64
+        num_patches = init_h * init_w  # 4096
+
+        merge_kernel = vision_config["merge_kernel_size"]  # [2, 2]
+        merge_h, merge_w = merge_kernel[0], merge_kernel[1]
+        num_merged_patches = num_patches // (merge_h * merge_w)  # 1024
+
+        mm_hidden = vision_config["mm_hidden_size"]  # 1152
+        text_hidden = vision_config["text_hidden_size"]  # 7168
+
+        fp16 = common.GEMMQuantMode.float16
+
+        # --- ViT transformer layers (pre-merge, num_patches tokens) ---
+        pre_merge_ops = [
+            ops.ElementWise("vision_norm_1", vt_layers, vt_hidden, vt_hidden, 0.8),
+            ops.GEMM("vision_qkv_gemm", vt_layers, 3 * vt_hidden, vt_hidden, fp16),
+            ops.GEMM("vision_attn_proj_gemm", vt_layers, vt_hidden, vt_hidden, fp16),
+            ops.ElementWise("vision_norm_2", vt_layers, vt_hidden, vt_hidden, 0.8),
+            ops.GEMM("vision_ffn1_gemm", vt_layers, vt_inter, vt_hidden, fp16),
+            ops.ElementWise("vision_act", vt_layers, vt_inter, vt_inter, 0.8),
+            ops.GEMM("vision_ffn2_gemm", vt_layers, vt_hidden, vt_inter, fp16),
+        ]
+        for op in pre_merge_ops:
+            op._vision_num_tokens = num_patches
+
+        # --- Patch merger + projector (post-merge, num_merged_patches tokens) ---
+        post_merge_ops = [
+            ops.GEMM("vision_merge_gemm", 1, mm_hidden, merge_h * merge_w * mm_hidden, fp16),
+            ops.GEMM("vision_projector_gemm", 1, text_hidden, mm_hidden, fp16),
+        ]
+        for op in post_merge_ops:
+            op._vision_num_tokens = num_merged_patches
+
+        # Prepend vision ops before text decoder ops
+        self.context_ops = pre_merge_ops + post_merge_ops + self.context_ops
+
 
 class TrtllmWideEPDeepSeekModel(BaseModel):
     """