File tree Expand file tree Collapse file tree 3 files changed +11
-12
lines changed
tensorrt_llm/_torch/models
integration/defs/accuracy/references Expand file tree Collapse file tree 3 files changed +11
-12
lines changed Original file line number Diff line number Diff line change 2323from tensorrt_llm ._torch .attention_backend .interface import PositionalEmbeddingParams , RopeParams
2424from tensorrt_llm ._torch .model_config import ModelConfig
2525from tensorrt_llm ._torch .models .modeling_utils import (
26- _load_weights_impl ,
2726 DecoderModel ,
2827 DecoderModelForCausalLM ,
28+ _load_weights_impl ,
2929 register_auto_model ,
3030)
3131from tensorrt_llm ._torch .modules .attention import Attention
3838from tensorrt_llm .functional import PositionEmbeddingType
3939
4040
41-
4241class Starcoder2Attention (Attention ):
4342 """
4443 StarCoder2 Attention with Grouped Query Attention and Sliding Window support.
@@ -122,7 +121,9 @@ def __init__(
122121 config = model_config ,
123122 )
124123 else :
125- raise ValueError (f"Unsupported mlp_type: { config .mlp_type } . Only default (linear) MLP is supported." )
124+ raise ValueError (
125+ f"Unsupported mlp_type: { config .mlp_type } . Only default (linear) MLP is supported."
126+ )
126127
127128 norm_eps = getattr (config , "norm_epsilon" , 1e-5 )
128129 self .input_layernorm = LayerNorm (
@@ -219,9 +220,7 @@ def forward(
219220 lora_params = None ,
220221 ) -> torch .Tensor :
221222 if (input_ids is None ) ^ (inputs_embeds is not None ):
222- raise ValueError (
223- "You must specify exactly one of input_ids or inputs_embeds."
224- )
223+ raise ValueError ("You must specify exactly one of input_ids or inputs_embeds." )
225224
226225 if inputs_embeds is None :
227226 inputs_embeds = self .embed_tokens (input_ids )
Original file line number Diff line number Diff line change @@ -275,4 +275,4 @@ bigcode/starcoder2-3b:
275275bigcode/starcoder2-7b :
276276 - accuracy : 26.5
277277bigcode/starcoder2-15b :
278- - accuracy : 54.5
278+ - accuracy : 54.5
Original file line number Diff line number Diff line change 1- import pytest
21from copy import deepcopy
32from dataclasses import dataclass
4- from typing import Any
53
4+ import pytest
65import torch
76from transformers import Starcoder2Config
87from transformers import Starcoder2ForCausalLM as HFStarcoder2ForCausalLM
@@ -123,6 +122,7 @@ def get_kv_cache_manager(
123122 )
124123 return kv_cache_manager
125124
125+
126126@pytest .mark .parametrize (
127127 "scenario" ,
128128 [
@@ -173,13 +173,13 @@ def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
173173 model_config = ModelConfig (pretrained_config = hf_config , attn_backend = backend )
174174 starcoder2 = Starcoder2ForCausalLM (model_config ).to (dtype ).to (device ).eval ()
175175 starcoder2 .load_weights (hf_starcoder2 .state_dict ())
176-
176+
177177 # Convert LayerNorm random weights to FP32 for numerical stability
178178 for name , module in starcoder2 .named_modules ():
179179 if isinstance (module , LayerNorm ):
180- if hasattr (module , ' weight' ) and module .weight is not None :
180+ if hasattr (module , " weight" ) and module .weight is not None :
181181 module .weight .data = module .weight .data .to (torch .float32 )
182- if hasattr (module , ' bias' ) and module .bias is not None :
182+ if hasattr (module , " bias" ) and module .bias is not None :
183183 module .bias .data = module .bias .data .to (torch .float32 )
184184
185185 num_blocks = 1
You can’t perform that action at this time.
0 commit comments