 from transformers import LlamaConfig

 from vllm.compilation.decorators import support_torch_compile
-from vllm.config import VllmConfig
+from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config
 from vllm.logger import init_logger
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -33,10 +33,14 @@ class LlamaDecoderLayer(LlamaDecoderLayer):
     def __init__(
         self,
         config: LlamaConfig,
+        cache_config: Optional[CacheConfig] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
-        super().__init__(config, quant_config=quant_config, prefix=prefix)
+        super().__init__(config,
+                         cache_config=cache_config,
+                         quant_config=quant_config,
+                         prefix=prefix)

         # override qkv
         self.self_attn.qkv_proj = QKVParallelLinear(
@@ -114,6 +118,8 @@ def __init__(
             speculative_config.draft_model_config.hf_config
         self.vocab_size = self.config.vocab_size

+        current_vllm_config = get_current_vllm_config()
+
         self.embed_tokens = VocabParallelEmbedding(
             self.config.vocab_size,
             self.config.hidden_size,
@@ -123,6 +129,7 @@ def __init__(
         self.layers = nn.ModuleList([
             LlamaDecoderLayer(
                 config=self.config,
+                cache_config=current_vllm_config.cache_config,
                 prefix=maybe_prefix(prefix, f"layers.{start_layer_id}"),
             )
         ])
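
In effect, the diff plumbs the cache configuration into the Eagle draft model's decoder layers: the model reads the active vLLM config via get_current_vllm_config() and hands its cache_config to each LlamaDecoderLayer it builds. Below is a minimal sketch of the resulting constructor call, not part of the diff; draft_hf_config is a hypothetical stand-in for the draft model's LlamaConfig, while the other names mirror the diff above.

```python
# Sketch only: constructing the Eagle draft LlamaDecoderLayer after this change.
# `draft_hf_config` is hypothetical; get_current_vllm_config and the new
# cache_config keyword come straight from the diff.
from vllm.config import get_current_vllm_config

current_vllm_config = get_current_vllm_config()

layer = LlamaDecoderLayer(
    config=draft_hf_config,                              # draft model's LlamaConfig
    cache_config=current_vllm_config.cache_config,       # newly threaded-through argument
    quant_config=None,
    prefix="layers.0",
)
```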