@@ -784,9 +784,6 @@ def __init__(
                 "Speculative decoding is not supported with "
                 "contiguous PA, please set VLLM_CONTIGUOUS_PA=false")
         self.model_type = self.model_config.hf_config.model_type
-        if self.model_type in ("medusa", "mlp_speculator", "eagle",
-                               "deepseek_mtp"):
-            self.skip_warmup = True
 
         # For both multi-step scheduling and delayed sampling
         self.cached_step_outputs: List[torch.Tensor] = []
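Note on this hunk: the removed lines forced skip_warmup for every speculative draft-model type, so warmup never ran for them. With the decode-bucket change in the next hunk, warmup now runs for speculative decoding too, and these model types are instead only excluded from the extra speculative-token accounting. A minimal sketch of the behavioural difference; the helper names below are illustrative, not part of vLLM:

    # Illustrative only: how removing the override changes warmup gating.
    SPECULATIVE_DRAFT_TYPES = ("medusa", "mlp_speculator", "eagle", "deepseek_mtp")

    def skip_warmup_before(model_type: str, skip_flag: bool) -> bool:
        # Before: any draft-model type skipped warmup unconditionally.
        return skip_flag or model_type in SPECULATIVE_DRAFT_TYPES

    def skip_warmup_after(model_type: str, skip_flag: bool) -> bool:
        # After: only the explicit skip flag (e.g. VLLM_SKIP_WARMUP) applies.
        return skip_flag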
@@ -2218,7 +2215,13 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
         self.bucketing_ctx.generate_prompt_buckets()
         if not self.is_pooler:
             max_blocks = kv_caches[0][0].size(0)
-            self.bucketing_ctx.generate_decode_buckets(max_blocks)
+            num_speculative_tokens = 0
+            if (self.vllm_config.speculative_config is not None
+                    and self.model_type not in ("medusa", "mlp_speculator", "eagle",
+                                                "deepseek_mtp")):
+                num_speculative_tokens = self.vllm_config.speculative_config.num_speculative_tokens
+
+            self.bucketing_ctx.generate_decode_buckets(max_blocks, num_speculative_tokens)
         if not htorch.utils.internal.is_lazy() and not self.enforce_eager:
             multiplier = 3 if os.getenv('VLLM_REGIONAL_COMPILATION',
                                         'true').lower() == 'true' else 1
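Why the extra argument: with speculative decoding each decode step scores the regular next token plus num_speculative_tokens draft tokens per sequence, so the decode buckets captured at warmup must cover that larger query length or the graphs recorded here will not match runtime shapes. The real generate_decode_buckets lives in the HPU bucketing context; the sketch below is a hedged guess at how it might fold the new parameter in, with made-up bucket spacing:

    # Hedged sketch, not the actual bucketing implementation.
    def generate_decode_buckets(max_blocks: int,
                                num_speculative_tokens: int = 0):
        # Each decoded sequence now carries 1 + num_speculative_tokens
        # query tokens per step.
        query_len = 1 + num_speculative_tokens
        buckets, blocks = [], 1
        while blocks <= max_blocks:
            buckets.append((query_len, blocks))
            blocks *= 2  # illustrative exponential spacing
        return buckets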
@@ -2261,10 +2264,11 @@ def warmup_model(self, kv_caches: List[torch.Tensor]) -> None:
                                  False, kv_caches)
 
         if not self.enforce_eager and htorch.utils.internal.is_lazy():
-            if not self.is_pooler:
-                assert self.mem_margin is not None, \
-                    ("HabanaWorker.determine_num_available_blocks needs "
-                     "to be called before warming up the model.")
+            if not self.is_pooler and self.mem_margin is None:
+                free_hpu_memory = torch.hpu.mem_get_info()[0]
+                hpu_memory_margin = free_hpu_memory * (
+                    1 - self.cache_config.gpu_memory_utilization)
+                self.mem_margin = hpu_memory_margin
 
             free_mem = HabanaMemoryProfiler.current_free_device_memory()
             graph_free_mem = free_mem - self.mem_margin
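The replaced assert required HabanaWorker.determine_num_available_blocks to have run before warmup; the new code instead derives the margin on demand from currently free device memory (the diff takes index 0 of torch.hpu.mem_get_info as free memory, matching the (free, total) convention of torch.cuda.mem_get_info). A runnable sketch of the arithmetic; the helper name is illustrative:

    # Hedged sketch of the fallback margin computation.
    def compute_mem_margin(free_bytes: int, gpu_memory_utilization: float) -> int:
        # Keep the fraction of free memory that the utilization target
        # leaves untouched as headroom outside graph capture.
        return int(free_bytes * (1 - gpu_memory_utilization))

    margin = compute_mem_margin(64 * 2**30, 0.9)
    print(f"margin ~= {margin / 2**30:.1f} GiB")  # about 6.4 GiB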