Skip to content

Commit 5e7c490

Browse files
Add DebertaClassifier and DeBERTa Presets (#594)
* Add DeBERTa presets
* Add DebertaClassifier
* Add imports
* Fix initializer import
* Small changes
* Fix UTs
* Fix hash
* Fix UTs
* Remove link in disclaimer
* Address comments
* Remove extra newline
* Remove default value from backbone docstring
* Rename checkpoints

Co-authored-by: Matt Watson <[email protected]>
1 parent 9f9ac46 commit 5e7c490

File tree

8 files changed

+901
-13
lines changed

8 files changed

+901
-13
lines changed

keras_nlp/models/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from keras_nlp.models.bert.bert_preprocessor import BertPreprocessor
1818
from keras_nlp.models.bert.bert_tokenizer import BertTokenizer
1919
from keras_nlp.models.deberta.deberta_backbone import DebertaBackbone
20+
from keras_nlp.models.deberta.deberta_classifier import DebertaClassifier
2021
from keras_nlp.models.deberta.deberta_preprocessor import DebertaPreprocessor
2122
from keras_nlp.models.deberta.deberta_tokenizer import DebertaTokenizer
2223
from keras_nlp.models.distil_bert.distil_bert_backbone import DistilBertBackbone

keras_nlp/models/deberta/deberta_backbone.py

Lines changed: 62 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,19 @@
1414

1515
"""DeBERTa backbone model."""
1616

17+
import copy
18+
import os
19+
1720
import tensorflow as tf
1821
from tensorflow import keras
1922

23+
from keras_nlp.models.deberta.deberta_presets import backbone_presets
2024
from keras_nlp.models.deberta.disentangled_attention_encoder import (
2125
DisentangledAttentionEncoder,
2226
)
2327
from keras_nlp.models.deberta.relative_embedding import RelativeEmbedding
2428
from keras_nlp.utils.python_utils import classproperty
29+
from keras_nlp.utils.python_utils import format_docstring
2530

2631

2732
def deberta_kernel_initializer(stddev=0.02):
@@ -54,13 +59,12 @@ class DebertaBackbone(keras.Model):
5459
hidden_dim: int. The size of the transformer encoding layer.
5560
intermediate_dim: int. The output dimension of the first Dense layer in
5661
a two-layer feedforward network for each transformer.
57-
dropout: float, defaults to 0.1. Dropout probability for the
58-
DeBERTa model.
59-
max_sequence_length: int, defaults to 512. The maximum sequence length
60-
this encoder can consume. The sequence length of the input must be
61-
less than `max_sequence_length`.
62-
bucket_size: int, defaults to 256. The size of the relative position
63-
buckets. Generally equal to `max_sequence_length // 2`.
62+
dropout: float. Dropout probability for the DeBERTa model.
63+
max_sequence_length: int. The maximum sequence length this encoder can
64+
consume. The sequence length of the input must be less than
65+
`max_sequence_length`.
66+
bucket_size: int. The size of the relative position buckets. Generally
67+
equal to `max_sequence_length // 2`.
6468
6569
Example usage:
6670
```python
@@ -172,6 +176,7 @@ def __init__(
172176
self.dropout = dropout
173177
self.max_sequence_length = max_sequence_length
174178
self.bucket_size = bucket_size
179+
self.start_token_index = 0
175180

176181
def get_config(self):
177182
return {
@@ -193,13 +198,61 @@ def from_config(cls, config):
193198

194199
@classproperty
def presets(cls):
    """Return a fresh deep copy of the DeBERTa backbone preset table.

    A copy is returned so that changes made by the caller do not affect the
    module-level `backbone_presets` mapping.
    """
    return copy.deepcopy(backbone_presets)
197202

198203
@classmethod
@format_docstring(names=", ".join(backbone_presets))
def from_preset(
    cls,
    preset,
    load_weights=True,
    **kwargs,
):
    """Build a DeBERTa backbone from a named preset.

    The preset supplies the architecture config; any extra keyword
    arguments override entries of that config before the model is built.
    Unless `load_weights` is `False`, the preset's pre-trained weights are
    fetched with `keras.utils.get_file` and loaded into the model.

    Args:
        preset: string. Must be one of {{names}}.
        load_weights: Whether to load pre-trained weights into model.
            Defaults to `True`.

    Examples:
    ```python
    input_data = {
        "token_ids": tf.ones(shape=(1, 12), dtype=tf.int64),
        "padding_mask": tf.constant(
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0], shape=(1, 12)
        ),
    }

    # Load architecture and weights from preset
    model = keras_nlp.models.DebertaBackbone.from_preset("deberta_base")
    output = model(input_data)

    # Load randomly initialized model from preset architecture
    model = keras_nlp.models.DebertaBackbone.from_preset(
        "deberta_base", load_weights=False
    )
    output = model(input_data)
    ```
    """
    known_presets = cls.presets
    if preset not in known_presets:
        preset_names = ", ".join(known_presets)
        raise ValueError(
            "`preset` must be one of "
            f"{preset_names}. Received: {preset}."
        )

    preset_metadata = known_presets[preset]

    # Caller-supplied kwargs take precedence over the preset's config.
    merged_config = dict(preset_metadata["config"])
    merged_config.update(kwargs)
    model = cls.from_config(merged_config)

    if load_weights:
        # Weights are stored per preset under `models/<preset>/model.h5`
        # and verified against the hash recorded in the preset metadata.
        weights_path = keras.utils.get_file(
            "model.h5",
            preset_metadata["weights_url"],
            cache_subdir=os.path.join("models", preset),
            file_hash=preset_metadata["weights_hash"],
        )
        model.load_weights(weights_path)

    return model
keras_nlp/models/deberta/deberta_classifier.py

Lines changed: 270 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,270 @@
1+
# Copyright 2022 The KerasNLP Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""DeBERTa classification model."""
15+
16+
import copy
17+
18+
from tensorflow import keras
19+
20+
from keras_nlp.models.deberta.deberta_backbone import DebertaBackbone
21+
from keras_nlp.models.deberta.deberta_backbone import deberta_kernel_initializer
22+
from keras_nlp.models.deberta.deberta_preprocessor import DebertaPreprocessor
23+
from keras_nlp.models.deberta.deberta_presets import backbone_presets
24+
from keras_nlp.utils.pipeline_model import PipelineModel
25+
from keras_nlp.utils.python_utils import classproperty
26+
from keras_nlp.utils.python_utils import format_docstring
27+
28+
29+
@keras.utils.register_keras_serializable(package="keras_nlp")
class DebertaClassifier(PipelineModel):
    """An end-to-end DeBERTa model for classification tasks.

    This model attaches a classification head to a
    `keras_nlp.models.DebertaBackbone` model, mapping from the backbone
    outputs to logit output suitable for a classification task. For usage of
    this model with pre-trained weights, see the `from_preset()` method.

    This model can optionally be configured with a `preprocessor` layer, in
    which case it will automatically apply preprocessing to raw inputs during
    `fit()`, `predict()`, and `evaluate()`. This is done by default when
    creating the model with `from_preset()`.

    Disclaimer: Pre-trained models are provided on an "as is" basis, without
    warranties or conditions of any kind.

    Args:
        backbone: A `keras_nlp.models.DebertaBackbone` instance.
        num_classes: int. Number of classes to predict.
        hidden_dim: int. The size of the pooler layer. If `None`,
            `backbone.hidden_dim` is used.
        dropout: float. Dropout probability applied to the pooled output. For
            the second dropout layer, `backbone.dropout` is used.
        preprocessor: A `keras_nlp.models.DebertaPreprocessor` or `None`. If
            `None`, this model will not apply preprocessing, and inputs should
            be preprocessed before calling the model.

    Example usage:
    ```python
    preprocessed_features = {
        "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
        "padding_mask": tf.constant(
            [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
        ),
    }
    labels = [0, 3]

    # Randomly initialized DeBERTa encoder
    backbone = keras_nlp.models.DebertaBackbone(
        vocabulary_size=128100,
        num_layers=12,
        num_heads=12,
        hidden_dim=768,
        intermediate_dim=3072,
        max_sequence_length=12,
        bucket_size=6,
    )

    # Create a DeBERTa classifier and fit your data.
    classifier = keras_nlp.models.DebertaClassifier(
        backbone,
        num_classes=4,
        preprocessor=None,
    )
    classifier.compile(
        loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    )
    classifier.fit(x=preprocessed_features, y=labels, batch_size=2)

    # Access backbone programmatically (e.g., to change `trainable`)
    classifier.backbone.trainable = False
    ```
    """

    def __init__(
        self,
        backbone,
        num_classes=2,
        hidden_dim=None,
        dropout=0.0,
        preprocessor=None,
        **kwargs,
    ):
        inputs = backbone.input
        if hidden_dim is None:
            hidden_dim = backbone.hidden_dim

        # Pool the sequence by taking the backbone output at the start token.
        x = backbone(inputs)[:, backbone.start_token_index, :]
        x = keras.layers.Dropout(dropout, name="pooled_dropout")(x)
        x = keras.layers.Dense(
            hidden_dim,
            activation=lambda x: keras.activations.gelu(x, approximate=False),
            name="pooled_dense",
        )(x)
        x = keras.layers.Dropout(backbone.dropout, name="classifier_dropout")(x)
        outputs = keras.layers.Dense(
            num_classes,
            kernel_initializer=deberta_kernel_initializer(),
            name="logits",
        )(x)

        # Instantiate using Functional API Model constructor
        super().__init__(
            inputs=inputs,
            outputs=outputs,
            include_preprocessing=preprocessor is not None,
            **kwargs,
        )
        # All references to `self` below this line
        self._backbone = backbone
        self._preprocessor = preprocessor
        self.num_classes = num_classes
        self.hidden_dim = hidden_dim
        self.dropout = dropout

    def preprocess_samples(self, x, y=None, sample_weight=None):
        """Apply the attached `preprocessor` to a batch of samples."""
        return self.preprocessor(x, y=y, sample_weight=sample_weight)

    @property
    def backbone(self):
        """A `keras_nlp.models.DebertaBackbone` submodel."""
        return self._backbone

    @property
    def preprocessor(self):
        """A `keras_nlp.models.DebertaPreprocessor` preprocessing layer."""
        return self._preprocessor

    def get_config(self):
        """Return the constructor arguments as a config dict.

        The backbone and preprocessor are stored in serialized form.
        """
        return {
            "backbone": keras.layers.serialize(self.backbone),
            "preprocessor": keras.layers.serialize(self.preprocessor),
            "num_classes": self.num_classes,
            "hidden_dim": self.hidden_dim,
            "dropout": self.dropout,
            "name": self.name,
            "trainable": self.trainable,
        }

    @classmethod
    def from_config(cls, config):
        """Rebuild a classifier from the output of `get_config()`.

        Serialized (dict) `backbone` and `preprocessor` entries are
        deserialized first; entries that are already objects are passed
        through unchanged.
        """
        # Work on a shallow copy so the caller's dict is left untouched.
        # Otherwise a second `from_config()` call with the same dict would
        # reuse the already-deserialized backbone instance.
        config = dict(config)
        if isinstance(config.get("backbone"), dict):
            config["backbone"] = keras.layers.deserialize(config["backbone"])
        if isinstance(config.get("preprocessor"), dict):
            config["preprocessor"] = keras.layers.deserialize(
                config["preprocessor"]
            )
        return cls(**config)

    @classproperty
    def presets(cls):
        """Return a deep copy of the available presets."""
        return copy.deepcopy(backbone_presets)

    @classmethod
    @format_docstring(names=", ".join(backbone_presets))
    def from_preset(
        cls,
        preset,
        load_weights=True,
        **kwargs,
    ):
        """Create a classification model from a preset architecture and weights.

        By default, this method will automatically create a `preprocessor`
        layer to preprocess raw inputs during `fit()`, `predict()`, and
        `evaluate()`. If you would like to disable this behavior, pass
        `preprocessor=None`.

        Args:
            preset: string. Must be one of {{names}}.
            load_weights: Whether to load pre-trained weights into model.
                Defaults to `True`.

        Raises:
            ValueError: If `preset` is not one of the available presets.

        Examples:

        Raw string inputs.
        ```python
        # Create a dataset with raw string features in an `(x, y)` format.
        features = ["The quick brown fox jumped.", "I forgot my homework."]
        labels = [0, 3]

        # Create a DebertaClassifier and fit your data.
        classifier = keras_nlp.models.DebertaClassifier.from_preset(
            "deberta_base",
            num_classes=4,
        )
        classifier.compile(
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        )
        classifier.fit(x=features, y=labels, batch_size=2)
        ```

        Raw string inputs with customized preprocessing.
        ```python
        # Create a dataset with raw string features in an `(x, y)` format.
        features = ["The quick brown fox jumped.", "I forgot my homework."]
        labels = [0, 3]

        # Use a shorter sequence length.
        preprocessor = keras_nlp.models.DebertaPreprocessor.from_preset(
            "deberta_base",
            sequence_length=128,
        )

        # Create a DebertaClassifier and fit your data.
        classifier = keras_nlp.models.DebertaClassifier.from_preset(
            "deberta_base",
            num_classes=4,
            preprocessor=preprocessor,
        )
        classifier.compile(
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        )
        classifier.fit(x=features, y=labels, batch_size=2)
        ```

        Preprocessed inputs.
        ```python
        # Create a dataset with preprocessed features in an `(x, y)` format.
        preprocessed_features = {
            "token_ids": tf.ones(shape=(2, 12), dtype=tf.int64),
            "padding_mask": tf.constant(
                [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0]] * 2, shape=(2, 12)
            ),
        }
        labels = [0, 3]

        # Create a DebertaClassifier and fit your data.
        classifier = keras_nlp.models.DebertaClassifier.from_preset(
            "deberta_base",
            num_classes=4,
            preprocessor=None,
        )
        classifier.compile(
            loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        )
        classifier.fit(x=preprocessed_features, y=labels, batch_size=2)
        ```
        """
        # Validate up front so an unknown preset fails with this class's
        # error message before any preprocessor or backbone is built.
        if preset not in cls.presets:
            raise ValueError(
                "`preset` must be one of "
                f"""{", ".join(cls.presets)}. Received: {preset}."""
            )

        if "preprocessor" not in kwargs:
            kwargs["preprocessor"] = DebertaPreprocessor.from_preset(preset)

        # Currently there are no classifier-level presets, so every valid
        # preset is a backbone-only preset: build the backbone and attach a
        # freshly initialized classification head.
        backbone = DebertaBackbone.from_preset(
            preset, load_weights=load_weights
        )
        return cls(backbone, **kwargs)

0 commit comments

Comments
 (0)