Commit ba0c3d1

[NEW Model] Add BIT DPT Model (#4202)
* add bit dpt
* update autobackbone example
* paddle.to_tensor
* typo
* update dpt tests image file
* update docs
* test_training_gradient_checkpointing
1 parent: e9f785b, commit: ba0c3d1

27 files changed: +7116 −87 lines

paddlenlp/transformers/__init__.py

Lines changed: 8 additions & 0 deletions
@@ -26,6 +26,7 @@
 )
 from .processing_utils import ProcessorMixin
 from .feature_extraction_utils import BatchFeature, FeatureExtractionMixin
+from .image_processing_utils import ImageProcessingMixin
 from .attention_utils import create_bigbird_rand_mask_idx_list
 from .export import export_model

@@ -47,6 +48,9 @@
 from .electra.converter import *
 from .albert.modeling import *
 from .albert.tokenizer import *
+from .bit.modeling import *
+from .bit.configuration import *
+from .bit.image_processing import *
 from .bart.modeling import *
 from .bart.tokenizer import *
 from .bart.configuration import *

@@ -63,6 +67,9 @@
 from .convbert.tokenizer import *
 from .ctrl.modeling import *
 from .ctrl.tokenizer import *
+from .dpt.modeling import *
+from .dpt.configuration import *
+from .dpt.image_processing import *
 from .distilbert.modeling import *
 from .distilbert.tokenizer import *
 from .ernie.configuration import *

@@ -160,6 +167,7 @@
 from .clip.tokenizer import *
 from .clip.procesing import *
 from .clip.converter import *
+from .clip.image_processing import *
 from .gptj.modeling import *
 from .gptj.tokenizer import *
 from .pegasus.modeling import *
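
With these wildcard re-exports in place, the public classes of the new `bit` and `dpt` packages become importable directly from `paddlenlp.transformers`. A minimal sketch of the effect; `BitBackbone` is the only class name confirmed elsewhere in this commit (see the `AutoBackbone` docstring below), while the commented names are assumptions inferred from the new module filenames:

```python
# Top-level import enabled by the wildcard re-exports added above.
# BitBackbone is referenced in the AutoBackbone docstring later in this
# commit; the commented-out names are assumed from the new module files
# (bit/configuration.py, dpt/image_processing.py) and may differ.
from paddlenlp.transformers import BitBackbone

# Hypothetical, for illustration only:
# from paddlenlp.transformers import BitConfig, DPTImageProcessor
```
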
Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from collections import OrderedDict

import paddle
import paddle.nn.functional as F
from paddle import Tensor, nn


class NewGELUActivation(nn.Layer):
    """
    Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see
    the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def forward(self, input: Tensor) -> Tensor:
        return (
            0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0))))
        )


class GELUActivation(nn.Layer):
    """
    Original Implementation of the GELU activation function in Google BERT repo when initially created. For
    information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). This is now written in C in nn.functional.
    Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415
    """

    def __init__(self, use_gelu_python: bool = False):
        super().__init__()
        if use_gelu_python:
            self.act = self._gelu_python
        else:
            self.act = nn.functional.gelu

    def _gelu_python(self, input: Tensor) -> Tensor:
        return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0)))

    def forward(self, input: Tensor) -> Tensor:
        return self.act(input)


class FastGELUActivation(nn.Layer):
    """
    Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input)))


class QuickGELUActivation(nn.Layer):
    """
    Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs
    """

    def forward(self, input: Tensor) -> Tensor:
        return input * F.sigmoid(1.702 * input)


class ClippedGELUActivation(nn.Layer):
    """
    Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purposes, as
    it allows mapping negative values in the GeLU spectrum. For more information on this trick, please refer to
    https://arxiv.org/abs/2004.09602.

    Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when
    initially created.

    For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 +
    torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415
    """

    def __init__(self, min: float, max: float):
        if min > max:
            raise ValueError(f"min should be < max (got min: {min}, max: {max})")

        super().__init__()
        self.min = min
        self.max = max

    def forward(self, x: Tensor) -> Tensor:
        return paddle.clip(gelu(x), self.min, self.max)


class SiLUActivation(nn.Layer):
    """
    See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear
    Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function
    Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated
    Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with
    later.
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.silu(input)


class MishActivation(nn.Layer):
    """
    See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra, https://arxiv.org/abs/1908.08681). Also
    visit the official repository for the paper: https://github.com/digantamisra98/Mish
    """

    def forward(self, input: Tensor) -> Tensor:
        return F.mish(input)


class LinearActivation(nn.Layer):
    """
    Applies the linear activation function, i.e. forwarding input directly to output.
    """

    def forward(self, input: Tensor) -> Tensor:
        return input


class ClassInstantier(OrderedDict):
    def __getitem__(self, key):
        content = super().__getitem__(key)
        cls, kwargs = content if isinstance(content, tuple) else (content, {})
        return cls(**kwargs)


ACT2CLS = {
    "gelu": GELUActivation,
    "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}),
    "gelu_fast": FastGELUActivation,
    "gelu_new": NewGELUActivation,
    "gelu_python": (GELUActivation, {"use_gelu_python": True}),
    "linear": LinearActivation,
    "mish": MishActivation,
    "quick_gelu": QuickGELUActivation,
    "relu": nn.ReLU,
    "relu6": nn.ReLU6,
    "sigmoid": nn.Sigmoid,
    "silu": SiLUActivation,
    "swish": SiLUActivation,
    "tanh": nn.Tanh,
}
ACT2FN = ClassInstantier(ACT2CLS)


def get_activation(activation_string):
    if activation_string in ACT2FN:
        return ACT2FN[activation_string]
    else:
        raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}")


# For backwards compatibility with: from activations import gelu_python
gelu_python = get_activation("gelu_python")
gelu_new = get_activation("gelu_new")
gelu = get_activation("gelu")
gelu_fast = get_activation("gelu_fast")
quick_gelu = get_activation("quick_gelu")
silu = get_activation("silu")
mish = get_activation("mish")
linear_act = get_activation("linear")
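
The registry at the bottom of this file is the piece downstream models rely on: `ACT2CLS` maps activation names to classes (optionally paired with bound keyword arguments), `ClassInstantier.__getitem__` instantiates a fresh layer on every lookup, and `get_activation` adds a clearer error for unknown names. A short usage sketch follows; the import path `paddlenlp.transformers.activations` is an assumption (the new file's path is not shown in this excerpt), everything else uses only the definitions above:

```python
# Usage sketch for the activation registry above. The import path is an
# assumption; the classes and helpers themselves are defined verbatim above.
import paddle

from paddlenlp.transformers.activations import ACT2FN, get_activation  # assumed path

x = paddle.linspace(-3.0, 3.0, 7)

# Indexing ACT2FN (a ClassInstantier) builds a fresh layer on each lookup,
# applying any bound kwargs, e.g. {"min": -10, "max": 10} for "gelu_10".
gelu = get_activation("gelu")      # erf-based GELU (GELUActivation)
gelu_new = ACT2FN["gelu_new"]      # tanh approximation (NewGELUActivation)

print(gelu(x).numpy())
print(gelu_new(x).numpy())

# The tanh approximation tracks the exact erf-based GELU closely.
print(paddle.allclose(gelu(x), gelu_new(x), atol=1e-3))  # True (0-D bool Tensor)
```

The closing `paddle.allclose` check illustrates why `gelu_new` is commonly treated as a drop-in approximation of the erf-based GELU.
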

paddlenlp/transformers/auto/modeling.py

Lines changed: 48 additions & 1 deletion
@@ -33,6 +33,7 @@
 from paddlenlp.utils.log import logger

 __all__ = [
+    "AutoBackbone",
     "AutoModel",
     "AutoModelForPretraining",
     "AutoModelForSequenceClassification",

@@ -107,11 +108,14 @@
         ("OPT", "opt"),
         ("ErnieViL", "ernie_vil"),
         ("Pegasus", "pegasus"),
+        ("DPT", "dpt"),
+        ("Bit", "bit"),
     ]
 )

 MAPPING_TASKS = OrderedDict(
     [
+        ("Backbone", "AutoBackbone"),
         ("Model", "AutoModel"),
         ("ForPretraining", "AutoModelForPretraining"),
         ("ForSequenceClassification", "AutoModelForSequenceClassification"),

@@ -132,7 +136,7 @@

 def get_name_mapping(task="Model"):
     """
-    Task can be 'Model', 'ForPretraining', 'ForSequenceClassification', 'ForTokenClassification',
+    Task can be 'Backbone', 'Model', 'ForPretraining', 'ForSequenceClassification', 'ForTokenClassification',
     'ForQuestionAnswering', 'ForMultipleChoice', 'ForMaskedLM', 'ForCausalLM', 'Encoder', 'Decoder',
     'Generator', 'Discriminator', 'ForConditionalGeneration', 'ForImageGeneration'.
     """

@@ -339,6 +343,49 @@ def _from_pretrained(cls, pretrained_model_name_or_path, task=None, from_hf_hub=
             logger.warning(f"{resolved_vocab_file} is not a valid path to a model config file")


+class AutoBackbone(_BaseAutoModelClass):
+    """
+    AutoBackbone.
+    """
+
+    CONFIGURATION_MODEL_MAPPING = get_init_configurations()
+    _pretrained_model_dict = CONFIGURATION_MODEL_MAPPING
+    _name_mapping = get_name_mapping("Backbone")
+
+    @classmethod
+    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
+        """
+        Creates an instance of `AutoBackbone`. Model weights are loaded
+        by specifying the name of a built-in pretrained model, a community-contributed model,
+        or a local file directory path.
+
+        Args:
+            pretrained_model_name_or_path (str): See :class:`AutoModel`.
+            *args (tuple): See :class:`AutoModel`.
+            **kwargs (dict): See :class:`AutoModel`.
+
+        Returns:
+            PretrainedModel: An instance of `AutoBackbone`.
+
+        Example:
+            .. code-block::
+
+                from paddlenlp.transformers import AutoBackbone
+
+                # Name of built-in pretrained model
+                model = AutoBackbone.from_pretrained("google/bit-50")
+                print(type(model))
+                # <class 'paddlenlp.transformers.bit.modeling.BitBackbone'>
+
+
+                # Load from local directory path
+                model = AutoBackbone.from_pretrained("./bit-50")
+                print(type(model))
+                # <class 'paddlenlp.transformers.bit.modeling.BitBackbone'>
+        """
+        return cls._from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
+
+
 class AutoModel(_BaseAutoModelClass):
     """
     AutoClass can help you automatically retrieve the relevant model given the provided
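
Taken together, the new `("DPT", "dpt")` / `("Bit", "bit")` name entries and the `"Backbone"` task suffix are what let `AutoBackbone.from_pretrained` resolve a checkpoint such as `google/bit-50` to `BitBackbone`. Below is a forward-pass sketch under stated assumptions: the dummy input shape and the `feature_maps` output field follow the HF-style backbone convention and should be checked against the actual `BitBackbone` implementation in this commit:

```python
# Forward-pass sketch for the new AutoBackbone class. The from_pretrained
# call mirrors the docstring example above; the `feature_maps` output field
# is an assumption (HF-style backbone outputs) and may differ in
# paddlenlp.transformers.bit.modeling.BitBackbone.
import paddle

from paddlenlp.transformers import AutoBackbone

model = AutoBackbone.from_pretrained("google/bit-50")  # resolves to BitBackbone
model.eval()

pixel_values = paddle.randn([1, 3, 224, 224])  # dummy NCHW image batch
with paddle.no_grad():
    outputs = model(pixel_values)

# Multi-scale feature maps, i.e. the kind of input a DPT-style decoder consumes.
for feature_map in outputs.feature_maps:
    print(feature_map.shape)
```

The multi-scale feature maps are the interface a DPT-style depth head builds on, which is presumably why the BiT backbone and DPT land in the same commit.
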
Lines changed: 13 additions & 0 deletions
@@ -0,0 +1,13 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.