Skip to content

Commit f95a148

Browse files
authored
Add convbert model (#1036)
1 parent c6d7c8c commit f95a148

File tree

9 files changed

+1710
-1
lines changed

9 files changed

+1710
-1
lines changed

mindone/transformers/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,15 @@
155155
CLIPVisionModelWithProjection,
156156
)
157157
from .models.cohere2 import Cohere2ForCausalLM, Cohere2Model, Cohere2PreTrainedModel
158+
from .models.convbert import (
159+
ConvBertForMaskedLM,
160+
ConvBertForMultipleChoice,
161+
ConvBertForQuestionAnswering,
162+
ConvBertForSequenceClassification,
163+
ConvBertForTokenClassification,
164+
ConvBertLayer,
165+
ConvBertModel,
166+
)
158167
from .models.deberta import (
159168
DebertaForMaskedLM,
160169
DebertaForQuestionAnswering,

mindone/transformers/generation/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1351,7 +1351,7 @@ def compute_transition_scores(
13511351
13521352
```python
13531353
>>> from transformers import GPT2Tokenizer
1354-
>>> from mindway.transformers import AutoModelForCausalLM
1354+
>>> from mindone.transformers import AutoModelForCausalLM
13551355
>>> import numpy as np
13561356
13571357
>>> tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

mindone/transformers/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
camembert,
3232
clap,
3333
clip,
34+
convbert,
3435
dpt,
3536
fuyu,
3637
gemma,

mindone/transformers/models/auto/configuration_auto.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
("helium", "HeliumConfig"),
6262
("hiera", "HieraConfig"),
6363
("camembert", "CamembertConfig"),
64+
("convbert", "ConvBertConfig"),
6465
("idefics", "IdeficsConfig"),
6566
("idefics2", "Idefics2Config"),
6667
("idefics3", "Idefics3Config"),
@@ -184,6 +185,7 @@
184185
("umt5", "UMT5"),
185186
("wav2vec2", "Wav2Vec2"),
186187
("whisper", "Whisper"),
188+
("convbert", "ConvBERT"),
187189
("xlm-roberta", "XLM-RoBERTa"),
188190
("xlm-roberta-xl", "XLM-RoBERTa-XL"),
189191
("cohere2", "Cohere2"),

mindone/transformers/models/auto/modeling_auto.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
("bart", "BartModel"),
4040
("camembert", "CamembertModel"),
4141
("mvp", "MvpModel"),
42+
("convbert", "ConvBertModel"),
4243
("bit", "BitModel"),
4344
("blip", "BlipModel"),
4445
("blip-2", "Blip2Model"),
@@ -143,6 +144,7 @@
143144
("bert", "BertForMaskedLM"),
144145
("deberta", "DebertaForMaskedLM"),
145146
("deberta-v2", "DebertaV2ForMaskedLM"),
147+
("convbert", "ConvBertForMaskedLM"),
146148
("gpt2", "GPT2LMHeadModel"),
147149
("led", "LEDForConditionalGeneration"),
148150
("camembert", "CamembertForMaskedLM"),
@@ -285,6 +287,7 @@
285287
("mvp", "MvpForConditionalGeneration"),
286288
("albert", "AlbertForMaskedLM"),
287289
("bart", "BartForConditionalGeneration"),
290+
("convbert", "ConvBertForMaskedLM"),
288291
("bert", "BertForMaskedLM"),
289292
("roberta", "RobertaForMaskedLM"),
290293
("camembert", "CamembertForMaskedLM"),
@@ -371,6 +374,7 @@
371374
("llama", "LlamaForSequenceClassification"),
372375
("persimmon", "PersimmonForSequenceClassification"),
373376
("mobilebert", "MobileBertForSequenceClassification"),
377+
("convbert", "ConvBertForSequenceClassification"),
374378
("mt5", "MT5ForSequenceClassification"),
375379
("megatron-bert", "MegatronBertForSequenceClassification"),
376380
("mistral", "MistralForSequenceClassification"),
@@ -398,6 +402,7 @@
398402
("deberta", "DebertaForQuestionAnswering"),
399403
("deberta-v2", "DebertaV2ForQuestionAnswering"),
400404
("led", "LEDForQuestionAnswering"),
405+
("convbert", "ConvBertForQuestionAnswering"),
401406
("llama", "LlamaForQuestionAnswering"),
402407
("mobilebert", "MobileBertForQuestionAnswering"),
403408
("megatron-bert", "MegatronBertForQuestionAnswering"),
@@ -446,6 +451,7 @@
446451
("qwen2", "Qwen2ForTokenClassification"),
447452
("roberta", "RobertaForTokenClassification"),
448453
("rembert", "RemBertForTokenClassification"),
454+
("convbert", "ConvBertForTokenClassification"),
449455
("t5", "T5ForTokenClassification"),
450456
("umt5", "UMT5ForTokenClassification"),
451457
("xlm-roberta", "XLMRobertaForTokenClassification"),
@@ -458,6 +464,7 @@
458464
# Model for Multiple Choice mapping
459465
("camembert", "CamembertForMultipleChoice"),
460466
("albert", "AlbertForMultipleChoice"),
467+
("convbert", "ConvBertForMultipleChoice"),
461468
("bert", "BertForMultipleChoice"),
462469
("roberta", "RobertaForMultipleChoice"),
463470
("deberta-v2", "DebertaV2ForMultipleChoice"),
Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# Copyright 2024 The HuggingFace Team. All rights reserved.
2+
#
3+
# This code is adapted from https://github.com/huggingface/transformers
4+
# with modifications to run transformers on mindspore.
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
from .modeling_convbert import (
18+
ConvBertForMaskedLM,
19+
ConvBertForMultipleChoice,
20+
ConvBertForQuestionAnswering,
21+
ConvBertForSequenceClassification,
22+
ConvBertForTokenClassification,
23+
ConvBertLayer,
24+
ConvBertModel,
25+
)

0 commit comments

Comments (0)