
Commit 1a9a48e

tanhanzhuo, joey12300, and yingyibiao authored
Add cls_token_id (#1777)
cls_token_id = input_ids[0][0] is not safe; add cls_token_id as an initialization parameter for the embedding class.

Co-authored-by: Jack Zhou <[email protected]>
Co-authored-by: yingyibiao <[email protected]>
1 parent e93f899 commit 1a9a48e
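
For context, both the removed and the added code feed the position-id computation shown in the diff below. The following is a minimal, self-contained sketch (assuming only that paddle is installed; the tensor values and the pad id are made up for illustration) of the two position-id schemes that the cls_token_id check selects between:

import paddle

# Hypothetical batch; 0 plays the role of RoBERTa-BPE's <s>/CLS token here.
input_ids = paddle.to_tensor([[0, 31414, 232, 2]], dtype="int64")
padding_idx = 1  # illustrative pad id; the layer uses whatever pad_token_id it was built with

ones = paddle.ones_like(input_ids, dtype="int64")
seq_length = paddle.cumsum(ones, axis=-1)               # [[1, 2, 3, 4]]

# cls_token_id == 0 branch (RobertaBPETokenizer): positions start after the pad index.
bpe_position_ids = seq_length + padding_idx + 1 - ones  # [[2, 3, 4, 5]]

# default branch (RobertaTokenizer): positions simply count from zero.
wordpiece_position_ids = seq_length - ones              # [[0, 1, 2, 3]]

print(bpe_position_ids.numpy(), wordpiece_position_ids.numpy())

Reading cls_token_id from input_ids[0][0] makes the choice of branch depend on runtime data and on the first token actually being the CLS token, which is what the commit message flags as unsafe; taking cls_token_id as a constructor argument fixes the scheme when the layer is built.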

File tree

1 file changed: +8 -6 lines

paddlenlp/transformers/roberta/modeling.py

Lines changed: 8 additions & 6 deletions
@@ -43,7 +43,8 @@ def __init__(self,
                  hidden_dropout_prob=0.1,
                  max_position_embeddings=512,
                  type_vocab_size=16,
-                 pad_token_id=0):
+                 pad_token_id=0,
+                 cls_token_id=101):
         super(RobertaEmbeddings, self).__init__()
         self.word_embeddings = nn.Embedding(
             vocab_size, hidden_size, padding_idx=pad_token_id)
@@ -53,14 +54,14 @@
         self.layer_norm = nn.LayerNorm(hidden_size)
         self.dropout = nn.Dropout(hidden_dropout_prob)
         self.padding_idx = pad_token_id
-
+        self.cls_token_id = cls_token_id
+
     def forward(self, input_ids, token_type_ids=None, position_ids=None):
         if position_ids is None:
             # maybe need use shape op to unify static graph and dynamic graph
             ones = paddle.ones_like(input_ids, dtype="int64")
             seq_length = paddle.cumsum(ones, axis=-1)
-            cls_token_id = input_ids[0][0]
-            if cls_token_id == 0:  # postion_ids for RobertaBPETokenizer
+            if self.cls_token_id == 0 or input_ids[0][0] == 0:  # postion_ids for RobertaBPETokenizer
                 position_ids = seq_length + self.padding_idx + 1 - ones
             else:  # postion_ids for RobertaTokenizer
                 position_ids = seq_length - ones
@@ -264,14 +265,15 @@ def __init__(self,
                  type_vocab_size=16,
                  initializer_range=0.02,
                  pad_token_id=0,
-                 layer_norm_eps=1e-12):
+                 layer_norm_eps=1e-12,
+                 cls_token_id=101):
         super(RobertaModel, self).__init__()
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.embeddings = RobertaEmbeddings(
             vocab_size, hidden_size, hidden_dropout_prob,
-            max_position_embeddings, type_vocab_size, pad_token_id)
+            max_position_embeddings, type_vocab_size, pad_token_id, cls_token_id)
         encoder_layer = nn.TransformerEncoderLayer(
             hidden_size,
             num_attention_heads,
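
As a rough usage sketch of the new keyword (a hypothetical, scaled-down configuration; parameter names other than those visible in the diff are assumptions about the rest of the RobertaModel signature), cls_token_id is simply forwarded from RobertaModel to RobertaEmbeddings:

from paddlenlp.transformers import RobertaModel

# Hypothetical toy configuration, not a real checkpoint's config.
# cls_token_id=0 selects the RobertaBPETokenizer position-id scheme;
# the default of 101 keeps the RobertaTokenizer behaviour.
model = RobertaModel(
    vocab_size=50265,
    hidden_size=768,
    num_hidden_layers=2,      # assumed keyword, not shown in this diff
    num_attention_heads=12,
    pad_token_id=1,
    cls_token_id=0,
)

A model trained with RobertaBPETokenizer would presumably ship cls_token_id=0 in its pretrained configuration, so the embedding layer no longer has to peek at input_ids to pick the position-id scheme.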
