|
23 | 23 | from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams |
24 | 24 | from tensorrt_llm._torch.model_config import ModelConfig |
25 | 25 | from tensorrt_llm._torch.models.modeling_utils import ( |
| 26 | + _load_weights_impl, |
26 | 27 | DecoderModel, |
27 | 28 | DecoderModelForCausalLM, |
28 | 29 | register_auto_model, |
|
37 | 38 | from tensorrt_llm.functional import PositionEmbeddingType |
38 | 39 |
|
39 | 40 |
|
| 41 | + |
40 | 42 | class Starcoder2Attention(Attention): |
41 | 43 | """ |
42 | 44 | StarCoder2 Attention with Grouped Query Attention and Sliding Window support. |
@@ -276,29 +278,11 @@ def load_weights(self, weights, weight_mapper=None, skip_modules=None): |
276 | 278 | r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2", |
277 | 279 | r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2", |
278 | 280 | } |
279 | | - |
280 | | - if weight_mapper is None: |
281 | | - # Use _load_weights_impl for non-weight-mapper path |
282 | | - from tensorrt_llm._torch.models.modeling_utils import _load_weights_impl |
283 | | - |
284 | | - preload_weight_modules = getattr(self, "preload_weight_modules", None) |
285 | | - _load_weights_impl( |
286 | | - self, |
287 | | - weights, |
288 | | - skip_modules, |
289 | | - params_map=params_map, |
290 | | - preload_weight_modules=preload_weight_modules, |
291 | | - ) |
292 | | - else: |
293 | | - # Use _load_weights_impl_v2 for weight-mapper path |
294 | | - from tensorrt_llm._torch.models.modeling_utils import _load_weights_impl_v2 |
295 | | - |
296 | | - preload_weight_modules = getattr(self, "preload_weight_modules", None) |
297 | | - _load_weights_impl_v2( |
298 | | - self, |
299 | | - weights, |
300 | | - weight_mapper, |
301 | | - skip_modules, |
302 | | - params_map=params_map, |
303 | | - preload_weight_modules=preload_weight_modules, |
304 | | - ) |
| 281 | + preload_weight_modules = getattr(self, "preload_weight_modules", None) |
| 282 | + _load_weights_impl( |
| 283 | + self, |
| 284 | + weights, |
| 285 | + skip_modules, |
| 286 | + params_map=params_map, |
| 287 | + preload_weight_modules=preload_weight_modules, |
| 288 | + ) |
0 commit comments