Skip to content

Commit c1d4a6f

Browse files
committed
simplify load_weights logic
Signed-off-by: Yibin Li <109242046+yibinl-nvidia@users.noreply.github.com>
1 parent e0826be commit c1d4a6f

File tree

1 file changed

+10
-26
lines changed

1 file changed

+10
-26
lines changed

tensorrt_llm/_torch/models/modeling_starcoder2.py

Lines changed: 10 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
 from tensorrt_llm._torch.attention_backend.interface import PositionalEmbeddingParams, RopeParams
 from tensorrt_llm._torch.model_config import ModelConfig
 from tensorrt_llm._torch.models.modeling_utils import (
+    _load_weights_impl,
     DecoderModel,
     DecoderModelForCausalLM,
     register_auto_model,
@@ -37,6 +38,7 @@
 from tensorrt_llm.functional import PositionEmbeddingType
 
 
+
 class Starcoder2Attention(Attention):
     """
     StarCoder2 Attention with Grouped Query Attention and Sliding Window support.
@@ -276,29 +278,11 @@ def load_weights(self, weights, weight_mapper=None, skip_modules=None):
             r"(.*?)\.mlp\.c_fc\.(.*)": r"\1.mlp.up_proj.\2",
             r"(.*?)\.mlp\.c_proj\.(.*)": r"\1.mlp.down_proj.\2",
         }
-
-        if weight_mapper is None:
-            # Use _load_weights_impl for non-weight-mapper path
-            from tensorrt_llm._torch.models.modeling_utils import _load_weights_impl
-
-            preload_weight_modules = getattr(self, "preload_weight_modules", None)
-            _load_weights_impl(
-                self,
-                weights,
-                skip_modules,
-                params_map=params_map,
-                preload_weight_modules=preload_weight_modules,
-            )
-        else:
-            # Use _load_weights_impl_v2 for weight-mapper path
-            from tensorrt_llm._torch.models.modeling_utils import _load_weights_impl_v2
-
-            preload_weight_modules = getattr(self, "preload_weight_modules", None)
-            _load_weights_impl_v2(
-                self,
-                weights,
-                weight_mapper,
-                skip_modules,
-                params_map=params_map,
-                preload_weight_modules=preload_weight_modules,
-            )
+        preload_weight_modules = getattr(self, "preload_weight_modules", None)
+        _load_weights_impl(
+            self,
+            weights,
+            skip_modules,
+            params_map=params_map,
+            preload_weight_modules=preload_weight_modules,
+        )

0 commit comments

Comments
 (0)