diff --git a/src/aiconfigurator/cli/main.py b/src/aiconfigurator/cli/main.py
index 444c9ad2..e7accda0 100644
--- a/src/aiconfigurator/cli/main.py
+++ b/src/aiconfigurator/cli/main.py
@@ -145,6 +145,12 @@ def _add_default_mode_arguments(parser):
         help="Optional end-to-end request latency target (ms). Enables request-latency optimization mode.",
     )
     parser.add_argument("--prefix", type=int, default=0, help="Prefix cache length. Default to 0.")
+    parser.add_argument(
+        "--enable-wideep",
+        action="store_true",
+        default=False,
+        help="Enable wide expert-parallelism search space (effective for DeepSeek models with trtllm/sglang backends).",
+    )
 
 
 def _add_experiments_mode_arguments(parser):
@@ -576,6 +582,7 @@ def build_default_task_configs(
     tpot: float = 30.0,
     request_latency: float | None = None,
     prefix: int = 0,
+    enable_wideep: bool = False,
 ) -> dict[str, TaskConfig]:
     """Build agg and disagg task configs for default mode comparison.
 
@@ -594,6 +601,7 @@
         tpot: Time per output token target in ms.
         request_latency: Optional end-to-end request latency target (ms).
         prefix: Prefix cache length.
+        enable_wideep: Enable wide expert-parallelism search space.
 
     Returns:
         Dict with TaskConfig objects. When backend='auto', returns 6 configs
@@ -621,6 +629,7 @@
         "request_latency": request_latency,
         "prefix": prefix,
         "database_mode": database_mode,
+        "enable_wideep": enable_wideep,
     }
 
     task_configs: dict[str, TaskConfig] = {}
@@ -1332,6 +1341,7 @@ def main(args):
             tpot=args.tpot,
             request_latency=args.request_latency,
             prefix=args.prefix,
+            enable_wideep=args.enable_wideep,
         )
     elif args.mode == "exp":
         try:
diff --git a/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Instruct_config.json b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Instruct_config.json
new file mode 100644
index 00000000..5b8c925e
--- /dev/null
+++ b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Instruct_config.json
@@ -0,0 +1,26 @@
+{
+  "architectures": ["DeepseekV3ForCausalLM"],
+  "model_type": "kimi_k2",
+  "num_hidden_layers": 61,
+  "hidden_size": 7168,
+  "num_attention_heads": 64,
+  "num_key_value_heads": 64,
+  "intermediate_size": 18432,
+  "vocab_size": 163840,
+  "max_position_embeddings": 131072,
+  "n_routed_experts": 384,
+  "n_shared_experts": 1,
+  "num_experts_per_tok": 8,
+  "moe_intermediate_size": 2048,
+  "moe_layer_freq": 1,
+  "first_k_dense_replace": 1,
+  "kv_lora_rank": 512,
+  "q_lora_rank": 1536,
+  "qk_nope_head_dim": 128,
+  "qk_rope_head_dim": 64,
+  "v_head_dim": 128,
+  "routed_scaling_factor": 2.827,
+  "torch_dtype": "bfloat16",
+  "use_cache": true,
+  "tie_word_embeddings": false
+}
diff --git a/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Thinking_config.json b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Thinking_config.json
new file mode 100644
index 00000000..18589494
--- /dev/null
+++ b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2-Thinking_config.json
@@ -0,0 +1,26 @@
+{
+  "architectures": ["DeepseekV3ForCausalLM"],
+  "model_type": "kimi_k2",
+  "num_hidden_layers": 61,
+  "hidden_size": 7168,
+  "num_attention_heads": 64,
+  "num_key_value_heads": 64,
+  "intermediate_size": 18432,
+  "vocab_size": 163840,
+  "max_position_embeddings": 262144,
+  "n_routed_experts": 384,
+  "n_shared_experts": 1,
+  "num_experts_per_tok": 8,
+  "moe_intermediate_size": 2048,
+  "moe_layer_freq": 1,
+  "first_k_dense_replace": 1,
+  "kv_lora_rank": 512,
+  "q_lora_rank": 1536,
+  "qk_nope_head_dim": 128,
+  "qk_rope_head_dim": 64,
+  "v_head_dim": 128,
+  "routed_scaling_factor": 2.827,
+  "torch_dtype": "bfloat16",
+  "use_cache": true,
+  "tie_word_embeddings": false
+}
diff --git a/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json
new file mode 100644
index 00000000..56cdd2b5
--- /dev/null
+++ b/src/aiconfigurator/model_configs/moonshotai--Kimi-K2.5_config.json
@@ -0,0 +1,39 @@
+{
+  "architectures": ["KimiK25ForConditionalGeneration"],
+  "model_type": "kimi_k25",
+  "num_hidden_layers": 61,
+  "hidden_size": 7168,
+  "num_attention_heads": 64,
+  "num_key_value_heads": 64,
+  "intermediate_size": 18432,
+  "vocab_size": 163840,
+  "max_position_embeddings": 262144,
+  "n_routed_experts": 384,
+  "n_shared_experts": 1,
+  "num_experts_per_tok": 8,
+  "moe_intermediate_size": 2048,
+  "moe_layer_freq": 1,
+  "first_k_dense_replace": 1,
+  "kv_lora_rank": 512,
+  "q_lora_rank": 1536,
+  "qk_nope_head_dim": 128,
+  "qk_rope_head_dim": 64,
+  "v_head_dim": 128,
+  "routed_scaling_factor": 2.827,
+  "torch_dtype": "bfloat16",
+  "use_cache": true,
+  "tie_word_embeddings": false,
+  "vision_config": {
+    "vt_num_hidden_layers": 27,
+    "vt_num_attention_heads": 16,
+    "vt_hidden_size": 1152,
+    "vt_intermediate_size": 4304,
+    "patch_size": 14,
+    "init_pos_emb_height": 64,
+    "init_pos_emb_width": 64,
+    "merge_kernel_size": [2, 2],
+    "mm_hidden_size": 1152,
+    "text_hidden_size": 7168,
+    "mm_projector_type": "patchmerger"
+  }
+}
diff --git a/src/aiconfigurator/sdk/backends/base_backend.py b/src/aiconfigurator/sdk/backends/base_backend.py
index ba92849c..7f7c9194 100644
--- a/src/aiconfigurator/sdk/backends/base_backend.py
+++ b/src/aiconfigurator/sdk/backends/base_backend.py
@@ -77,13 +77,19 @@ def _run_context(batch_size: int, isl: int, prefix) -> tuple[dict[str, float], d
 
         for op in model.context_ops:
             # query latency and store the latency
-            x = batch_size * isl if "logits_gemm" not in op._name else batch_size
+            if "logits_gemm" in op._name:
+                x = batch_size
+            elif hasattr(op, "_vision_num_tokens"):
+                x = batch_size * op._vision_num_tokens
+            else:
+                x = batch_size * isl
+            s_val = op._vision_num_tokens if hasattr(op, "_vision_num_tokens") else isl
             result = op.query(
                 database,
                 x=x,
                 batch_size=batch_size,
                 beam_width=1,
-                s=isl,
+                s=s_val,
                 prefix=prefix,
                 model_name=getattr(model, "model_name", ""),
             )
diff --git a/src/aiconfigurator/sdk/common.py b/src/aiconfigurator/sdk/common.py
index c5cd81af..219c99bf 100644
--- a/src/aiconfigurator/sdk/common.py
+++ b/src/aiconfigurator/sdk/common.py
@@ -235,6 +235,10 @@ def get_default_models() -> set[str]:
         # DeepSeek Models
         "deepseek-ai/DeepSeek-V3",
         "nvidia/DeepSeek-V3.1-NVFP4",
+        # Kimi Models
+        "moonshotai/Kimi-K2-Instruct",
+        "moonshotai/Kimi-K2-Thinking",
+        "moonshotai/Kimi-K2.5",
         # Qwen 2.5 Models
         "Qwen/Qwen2.5-1.5B",
         "Qwen/Qwen2.5-7B",
@@ -288,6 +292,7 @@ def get_default_models() -> set[str]:
     "Qwen3ForCausalLM": "LLAMA",
     "DeepSeekForCausalLM": "DEEPSEEK",
     "DeepseekV3ForCausalLM": "DEEPSEEK",
+    "KimiK25ForConditionalGeneration": "DEEPSEEK",
     "NemotronForCausalLM": "NEMOTRONNAS",
     "DeciLMForCausalLM": "NEMOTRONNAS",
     "NemotronHForCausalLM": "NEMOTRONH",
diff --git a/src/aiconfigurator/sdk/inference_session.py b/src/aiconfigurator/sdk/inference_session.py
index 2ebe0df8..022aaf42 100644
--- a/src/aiconfigurator/sdk/inference_session.py
+++ b/src/aiconfigurator/sdk/inference_session.py
@@ -373,9 +373,10 @@ def get_worker_candidates(
             continue
     if summary_df.empty:
         if exceptions:
+            last = exceptions[-1]
             raise RuntimeError(
-                f"No results found for any parallel configuration. Showing last exception: {exceptions[-1]}"
-            ) from exceptions[-1]
+                f"No results found for any parallel configuration. Showing last exception: {last}"
+            ) from last
         if all_configs_oom:
             raise RuntimeError(
                 "No results found: the model does not fit in GPU memory for any parallel "
diff --git a/src/aiconfigurator/sdk/models.py b/src/aiconfigurator/sdk/models.py
index 3d8e147d..9046f194 100755
--- a/src/aiconfigurator/sdk/models.py
+++ b/src/aiconfigurator/sdk/models.py
@@ -322,6 +322,11 @@ def get_model(
         # extra_params is NemotronHConfig with hybrid layer configuration
        model.set_hybrid_config(extra_params)
 
+    # Add vision encoder ops if model has vision config
+    vision_config = raw_config.get("vision_config")
+    if vision_config is not None and hasattr(model, "_add_vision_encoder_ops"):
+        model._add_vision_encoder_ops(vision_config)
+
     return model
 
 
@@ -1137,27 +1142,27 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GEMM(
                     "context_q_b_proj_gemm",
                     self._num_layers,
-                    24576 // tp_size,
+                    self._num_heads * 192 // tp_size,  # num_heads * (qk_nope_head_dim + qk_rope_head_dim)
                     1536,
                     gemm_quant_mode,
                 ),
                 ops.GEMM(
                     "context_kv_b_proj_gemm",
                     self._num_layers,
-                    32768 // tp_size,
+                    self._num_heads * 256 // tp_size,  # num_heads * (qk_nope_head_dim + v_head_dim)
                     512,
                     gemm_quant_mode,
                 ),  # agg ctx attn part
                 ops.ContextMLA(
                     "context_attention",
                     self._num_layers,
-                    128 // tp_size,
+                    self._num_heads // tp_size,
                     kvcache_quant_mode,
                     fmha_quant_mode,
                 ),  # agg ctx attn part
                 ops.GEMM(
-                    "context_proj_gemm", self._num_layers, h, 128 * 128 // tp_size, gemm_quant_mode
-                ),  # agg ctx attn part
+                    "context_proj_gemm", self._num_layers, h, self._num_heads * 128 // tp_size, gemm_quant_mode
+                ),  # agg ctx attn part; 128 = v_head_dim
                 ops.ElementWise("context_add_norm_2", self._num_layers, 2 * h, 2 * h, 0.8),
             ]
         )
@@ -1294,7 +1299,7 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GEMM(
                     "generation_q_b_proj_gemm",
                     self._num_layers * self._mtp_scale_factor,
-                    24576 // tp_size,
+                    self._num_heads * 192 // tp_size,  # num_heads * (qk_nope_head_dim + qk_rope_head_dim)
                     1536,
                     gemm_quant_mode,
                 ),
@@ -1308,7 +1313,7 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GenerationMLA(
                     "generation_attention",
                     self._num_layers * self._mtp_scale_factor,
-                    128 // tp_size,
+                    self._num_heads // tp_size,
                     kvcache_quant_mode,
                 ),  # agg gen attn part
                 ops.MLABmm(
@@ -1465,6 +1470,56 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
         # TODO
         # a lot of quantization ops
 
+    def _add_vision_encoder_ops(self, vision_config: dict) -> None:
+        """Add vision encoder (ViT + patch merger + projector) ops to context_ops.
+
+        Vision ops are prepended so the vision encoder cost is accounted for
+        before the text decoder ops. Each vision op carries a
+        ``_vision_num_tokens`` attribute so the backend can use the correct
+        token count instead of ``isl``.
+        """
+        vt_layers = vision_config["vt_num_hidden_layers"]  # 27
+        vt_heads = vision_config["vt_num_attention_heads"]  # 16
+        vt_hidden = vision_config["vt_hidden_size"]  # 1152
+        vt_inter = vision_config["vt_intermediate_size"]  # 4304
+
+        init_h = vision_config["init_pos_emb_height"]  # 64
+        init_w = vision_config["init_pos_emb_width"]  # 64
+        num_patches = init_h * init_w  # 4096
+
+        merge_kernel = vision_config["merge_kernel_size"]  # [2, 2]
+        merge_h, merge_w = merge_kernel[0], merge_kernel[1]
+        num_merged_patches = num_patches // (merge_h * merge_w)  # 1024
+
+        mm_hidden = vision_config["mm_hidden_size"]  # 1152
+        text_hidden = vision_config["text_hidden_size"]  # 7168
+
+        fp16 = common.GEMMQuantMode.float16
+
+        # --- ViT transformer layers (pre-merge, num_patches tokens) ---
+        pre_merge_ops = [
+            ops.ElementWise("vision_norm_1", vt_layers, vt_hidden, vt_hidden, 0.8),
+            ops.GEMM("vision_qkv_gemm", vt_layers, 3 * vt_hidden, vt_hidden, fp16),
+            ops.GEMM("vision_attn_proj_gemm", vt_layers, vt_hidden, vt_hidden, fp16),
+            ops.ElementWise("vision_norm_2", vt_layers, vt_hidden, vt_hidden, 0.8),
+            ops.GEMM("vision_ffn1_gemm", vt_layers, vt_inter, vt_hidden, fp16),
+            ops.ElementWise("vision_act", vt_layers, vt_inter, vt_inter, 0.8),
+            ops.GEMM("vision_ffn2_gemm", vt_layers, vt_hidden, vt_inter, fp16),
+        ]
+        for op in pre_merge_ops:
+            op._vision_num_tokens = num_patches
+
+        # --- Patch merger + projector (post-merge, num_merged_patches tokens) ---
+        post_merge_ops = [
+            ops.GEMM("vision_merge_gemm", 1, mm_hidden, merge_h * merge_w * mm_hidden, fp16),
+            ops.GEMM("vision_projector_gemm", 1, text_hidden, mm_hidden, fp16),
+        ]
+        for op in post_merge_ops:
+            op._vision_num_tokens = num_merged_patches
+
+        # Prepend vision ops before text decoder ops
+        self.context_ops = pre_merge_ops + post_merge_ops + self.context_ops
+
 
 class TrtllmWideEPDeepSeekModel(BaseModel):
     """
@@ -1596,25 +1651,26 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GEMM(
                     "context_q_b_proj_gemm",
                     self._num_layers,
-                    24576 // tp_size,
+                    self._num_heads * 192 // tp_size,  # num_heads * (qk_nope_head_dim + qk_rope_head_dim)
                     1536,
                     gemm_quant_mode,
                 ),
                 ops.GEMM(
                     "context_kv_b_proj_gemm",
                     self._num_layers,
-                    32768 // tp_size,
+                    self._num_heads * 256 // tp_size,  # num_heads * (qk_nope_head_dim + v_head_dim)
                     512,
                     gemm_quant_mode,
                 ),
                 ops.ContextMLA(
                     "context_attention",
                     self._num_layers,
-                    128 // tp_size,
+                    self._num_heads // tp_size,
                     kvcache_quant_mode,
                     fmha_quant_mode,
                 ),
-                ops.GEMM("context_proj_gemm", self._num_layers, h, 128 * 128 // tp_size, gemm_quant_mode),
+                # 128 = v_head_dim
+                ops.GEMM("context_proj_gemm", self._num_layers, h, self._num_heads * 128 // tp_size, gemm_quant_mode),
                 ops.ElementWise("context_add_norm_2", self._num_layers, 2 * h, 2 * h, 0.8),
             ]
         )
@@ -1755,7 +1811,7 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GEMM(
                     "generation_q_b_proj_gemm",
                     self._num_layers * self._mtp_scale_factor,
-                    24576 // tp_size,
+                    self._num_heads * 192 // tp_size,  # num_heads * (qk_nope_head_dim + qk_rope_head_dim)
                     1536,
                     gemm_quant_mode,
                 ),
@@ -1769,7 +1825,7 @@ def __init__(self, topk: int, num_experts: int, moe_inter_size: int, *args) -> N
                 ops.GenerationMLA(
                     "generation_attention",
                     self._num_layers * self._mtp_scale_factor,
-                    128 // tp_size,
+                    self._num_heads // tp_size,
                     kvcache_quant_mode,
                 ),
                 ops.MLABmm(
diff --git a/src/aiconfigurator/sdk/task.py b/src/aiconfigurator/sdk/task.py
index 8adacf47..aa9b386a 100644
--- a/src/aiconfigurator/sdk/task.py
+++ b/src/aiconfigurator/sdk/task.py
@@ -303,7 +303,8 @@ def _mode_layers(cls, ctx: TaskContext) -> list[ConfigLayer]:
 
     @staticmethod
     def _base_common_layer(ctx: TaskContext) -> dict:
-        nextn = 1 if ctx.model_family == "DEEPSEEK" else 0
+        raw_config = get_model_config_from_model_path(ctx.model_path).get("raw_config", {})
+        nextn = raw_config.get("num_nextn_predict_layers", 0)
         return {
             "serving_mode": ctx.serving_mode,
             "model_path": ctx.model_path,
diff --git a/src/aiconfigurator/sdk/utils.py b/src/aiconfigurator/sdk/utils.py
index 69bbc92d..ba15fd9b 100644
--- a/src/aiconfigurator/sdk/utils.py
+++ b/src/aiconfigurator/sdk/utils.py
@@ -442,21 +442,33 @@ def _parse_hf_config_json(config: dict) -> dict:
             f"Supported architectures: {', '.join(ARCHITECTURE_TO_MODEL_FAMILY.keys())}"
         )
 
-    layers = config["num_hidden_layers"]
-    hidden_size = config["hidden_size"]
-    n = config["num_attention_heads"]
-    vocab = config["vocab_size"]
-    context = config["max_position_embeddings"]
+    # For multimodal VLMs (e.g., KimiK25ForConditionalGeneration), model params are
+    # nested under "text_config"; fall back to top-level config for pure text models.
+    effective_config = config.get("text_config", config)
+
+    layers = effective_config["num_hidden_layers"]
+    hidden_size = effective_config["hidden_size"]
+    n = effective_config["num_attention_heads"]
+    vocab = effective_config["vocab_size"]
+    context = effective_config["max_position_embeddings"]
 
     # Handle nullable fields (e.g., Nemotron has null for these)
-    n_kv = config.get("num_key_value_heads") or 0
-    inter_size = config.get("intermediate_size") or 0
-    d = config.get("head_dim") or config.get("attention_head_dim") or (hidden_size // n if n > 0 else 0)
+    n_kv = effective_config.get("num_key_value_heads") or 0
+    inter_size = effective_config.get("intermediate_size") or 0
+    d = (
+        effective_config.get("head_dim")
+        or effective_config.get("attention_head_dim")
+        or (hidden_size // n if n > 0 else 0)
+    )
 
     # MoE parameters
-    topk = config.get("num_experts_per_tok", 0)
-    num_experts = config.get("num_local_experts") or config.get("n_routed_experts") or config.get("num_experts", 0)
-    moe_inter_size = config.get("moe_intermediate_size", 0) or config.get("intermediate_size", 0)
+    topk = effective_config.get("num_experts_per_tok", 0)
+    num_experts = (
+        effective_config.get("num_local_experts")
+        or effective_config.get("n_routed_experts")
+        or effective_config.get("num_experts", 0)
+    )
+    moe_inter_size = effective_config.get("moe_intermediate_size", 0) or effective_config.get("intermediate_size", 0)
 
     # Handle NemotronH-specific configuration (only fields unique to NemotronH)
     extra_params = None
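
Illustrative sketch (not part of the patch; FakeOp and pick_x_and_s are hypothetical names): a standalone Python snippet mirroring the token-count selection added to _run_context in base_backend.py, under the assumption that vision ops expose _vision_num_tokens, the logits GEMM scales with batch size only, and all other context ops scale with batch_size * isl.

    # Standalone illustration of the x/s selection branch added above.
    class FakeOp:
        def __init__(self, name, vision_num_tokens=None):
            self._name = name
            if vision_num_tokens is not None:
                self._vision_num_tokens = vision_num_tokens

    def pick_x_and_s(op, batch_size, isl):
        # Same branch order as the patched _run_context loop.
        if "logits_gemm" in op._name:
            x = batch_size
        elif hasattr(op, "_vision_num_tokens"):
            x = batch_size * op._vision_num_tokens
        else:
            x = batch_size * isl
        s_val = op._vision_num_tokens if hasattr(op, "_vision_num_tokens") else isl
        return x, s_val

    print(pick_x_and_s(FakeOp("vision_qkv_gemm", 4096), 2, 8192))  # (8192, 4096)
    print(pick_x_and_s(FakeOp("context_qkv_gemm"), 2, 8192))       # (16384, 8192)
    print(pick_x_and_s(FakeOp("logits_gemm"), 2, 8192))            # (2, 8192)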