 import torch
-from typing import final, Dict
 from typing_extensions import override
 from lightllm.models.registry import ModelRegistry
 from lightllm.models.qwen3_moe.model import Qwen3MOEModel
 from lightllm.models.qwen3next.layer_weights.transformer_layer_weight import Qwen3NextTransformerLayerWeight
 from lightllm.models.qwen3next.layer_infer.transformer_layer_infer import Qwen3NextTransformerLayerInfer
 from lightllm.utils.log_utils import init_logger
 from lightllm.distributed.communication_op import dist_group_manager
-from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
-from lightllm.common.mem_manager import MemoryManager
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.models.qwen3next.mem_manager import Qwen3NextMemoryManager, MambaStateBufferConfig
-from lightllm.models.llama.model import LlamaFlashInferStateExtraInfo

 logger = init_logger(__name__)

@@ -26,7 +22,6 @@ class Qwen3NextTpPartModel(Qwen3MOEModel):

     def __init__(self, kvargs) -> None:
         super().__init__(kvargs)
-        return

     @override
     def autotune_layers(self):
@@ -35,9 +30,7 @@ def autotune_layers(self):
     @override
     def _init_config(self):
         super()._init_config()
-        self.config["num_hidden_layers"] = 4
         self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)
-        return

     @override
     def _init_custom(self):
@@ -80,4 +73,3 @@ def _init_mem_manager(self):
             mamba_state_buffer_config=mamba_state_buffer_config,
             mem_fraction=self.mem_fraction,
         )
-        return
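
For reference, a minimal sketch of how _init_config reads once this diff is applied, reconstructed only from the hunks above (the rest of the class body is assumed unchanged): the hard-coded debug override of num_hidden_layers and the redundant trailing return are dropped, leaving only the tensor-parallel KV-head computation.

    @override
    def _init_config(self):
        # After the cleanup, only the TP-aware KV-head count remains:
        # at least one KV head per tensor-parallel rank.
        super()._init_config()
        self.num_kv_heads = max(self.config["num_key_value_heads"] // self.tp_world_size_, 1)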