Skip to content

Commit aab4685

Browse files
Authored commit: Add starcoder2 model (#1066)
1 parent 0e4a168 · commit aab4685

File tree

9 files changed

+1348
-2
lines changed

9 files changed

+1348
-2
lines changed

mindone/transformers/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,6 +391,13 @@
391391
SpeechT5Model,
392392
SpeechT5PreTrainedModel,
393393
)
394+
from .models.starcoder2 import (
395+
Starcoder2ForCausalLM,
396+
Starcoder2ForSequenceClassification,
397+
Starcoder2ForTokenClassification,
398+
Starcoder2Model,
399+
Starcoder2PreTrainedModel,
400+
)
394401
from .models.switch_transformers import (
395402
SwitchTransformersEncoderModel,
396403
SwitchTransformersForConditionalGeneration,

mindone/transformers/integrations/sdpa_attention.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ def repeat_kv(hidden_states: ms.Tensor, n_rep: int) -> ms.Tensor:
1414
batch, num_key_value_heads, slen, head_dim = hidden_states.shape # BNSD format
1515
if n_rep == 1:
1616
return hidden_states
17-
hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
17+
hidden_states = hidden_states[:, :, None, :, :].expand((batch, num_key_value_heads, n_rep, slen, head_dim))
1818
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
1919

2020

@@ -33,7 +33,7 @@ def sdpa_attention_forward(
3333
value_states = repeat_kv(value, module.num_key_value_groups)
3434

3535
attn_weights = mint.matmul(query, key_states.transpose(2, 3)) * scaling
36-
if attention_mask is not None:
36+
if attention_mask is not None and attention_mask.dim() == 4:
3737
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
3838
attn_weights = attn_weights + causal_mask
3939

mindone/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@
6767
roberta,
6868
siglip,
6969
speecht5,
70+
starcoder2,
7071
switch_transformers,
7172
t5,
7273
umt5,

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@
7676
("mistral", "MistralConfig"),
7777
("mobilebert", "MobileBertConfig"),
7878
("mpt", "MptConfig"),
79+
("starcoder2", "Starcoder2Config"),
7980
("mt5", "MT5Config"),
8081
("megatron-bert", "MegatronBertConfig"),
8182
("mixtral", "MixtralConfig"),
@@ -119,6 +120,7 @@
119120
("chameleon", "Chameleon"),
120121
("clap", "CLAP"),
121122
("clip", "CLIP"),
123+
("starcoder2", "Starcoder2"),
122124
("clip_vision_model", "CLIPVisionModel"),
123125
("deberta", "DeBERTa"),
124126
("deberta-v2", "DeBERTa-v2"),

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
("bit", "BitModel"),
4242
("blip", "BlipModel"),
4343
("blip-2", "Blip2Model"),
44+
("starcoder2", "Starcoder2Model"),
4445
("chameleon", "ChameleonModel"),
4546
("clap", "ClapModel"),
4647
("clip", "CLIPModel"),
@@ -163,6 +164,7 @@
163164
("bert-generation", "BertGenerationDecoder"),
164165
("gemma", "GemmaForCausalLM"),
165166
("gemma2", "Gemma2ForCausalLM"),
167+
("starcoder2", "Starcoder2ForCausalLM"),
166168
("gemma3", "Gemma3ForCausalLM"),
167169
("gemma3_text", "Gemma3ForCausalLM"),
168170
("granite", "GraniteForCausalLM"),
@@ -355,6 +357,7 @@
355357
("glm", "GlmForSequenceClassification"),
356358
("helium", "HeliumForSequenceClassification"),
357359
("led", "LEDForSequenceClassification"),
360+
("starcoder2", "Starcoder2ForSequenceClassification"),
358361
("llama", "LlamaForSequenceClassification"),
359362
("persimmon", "PersimmonForSequenceClassification"),
360363
("mobilebert", "MobileBertForSequenceClassification"),
@@ -417,6 +420,7 @@
417420
("camembert", "CamembertForTokenClassification"),
418421
("deberta", "DebertaForTokenClassification"),
419422
("deberta-v2", "DebertaV2ForTokenClassification"),
423+
("starcoder2", "Starcoder2ForTokenClassification"),
420424
("glm", "GlmForTokenClassification"),
421425
("helium", "HeliumForTokenClassification"),
422426
("mistral", "MistralForTokenClassification"),
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# This code is adapted from https://github.com/huggingface/transformers
4+
# with modifications to run transformers on mindspore.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
from .modeling_starcoder2 import (
18+
Starcoder2ForCausalLM,
19+
Starcoder2ForSequenceClassification,
20+
Starcoder2ForTokenClassification,
21+
Starcoder2Model,
22+
Starcoder2PreTrainedModel,
23+
)

0 commit comments

Comments (0)