@@ -167,7 +167,7 @@ def update_weights_from_ipc(self, zmq_handles: dict[str, str]):
167167 self .device = torch .device (f"npu:{ self .local_rank } " )
168168 assert self .device is not None
169169
170- def _load_weights (weights : _WEIGHTS_TYPE ):
170+ def _update_weights (weights : _WEIGHTS_TYPE ):
171171 # Load main model weights
172172 self .model_runner .model .load_weights (weights )
173173 # Load drafter model weights if MTP/speculative decoding is enabled
@@ -177,7 +177,7 @@ def _load_weights(weights: _WEIGHTS_TYPE):
177177 ):
178178 self .model_runner .drafter .model .load_weights (weights = weights )
179179
180- def _post_hook ():
180+ def _process_weight_after_loading ():
181181 process_weights_after_loading (self .model_runner .model , self .model_config , self .device )
182182 # Also trigger drafter model's post processing if MTP is enabled
183183 if (
@@ -188,10 +188,15 @@ def _post_hook():
188188 self .model_runner .drafter .model , self .model_config , self .device
189189 )
190190
191+ torch .cuda .empty_cache ()
192+
191193 update_weights_from_ipc (
192194 self ._zmq_ctx ,
193195 zmq_handles [self ._device_uuid ],
194196 device_id = self .device .index ,
195- run = _load_weights ,
196- post_hook = _post_hook ,
197+ run = _update_weights ,
198+ post_hook = _process_weight_after_loading ,
197199 )
200+
201+ if getattr (self , "_sampler_warmup" , None ) is not None :
202+ self ._sampler_warmup ()
0 commit comments