@@ -43,7 +43,8 @@ def __init__(self,
                  hidden_dropout_prob=0.1,
                  max_position_embeddings=512,
                  type_vocab_size=16,
-                 pad_token_id=0):
+                 pad_token_id=0,
+                 cls_token_id=101):
         super(RobertaEmbeddings, self).__init__()
         self.word_embeddings = nn.Embedding(
             vocab_size, hidden_size, padding_idx=pad_token_id)
@@ -53,14 +54,14 @@ def __init__(self,
         self.layer_norm = nn.LayerNorm(hidden_size)
         self.dropout = nn.Dropout(hidden_dropout_prob)
         self.padding_idx = pad_token_id
-
+        self.cls_token_id = cls_token_id
+
     def forward(self, input_ids, token_type_ids=None, position_ids=None):
         if position_ids is None:
             # maybe need to use shape op to unify static graph and dynamic graph
             ones = paddle.ones_like(input_ids, dtype="int64")
             seq_length = paddle.cumsum(ones, axis=-1)
-            cls_token_id = input_ids[0][0]
-            if cls_token_id == 0:  # position_ids for RobertaBPETokenizer
+            if self.cls_token_id == 0 or input_ids[0][0] == 0:  # position_ids for RobertaBPETokenizer
                 position_ids = seq_length + self.padding_idx + 1 - ones
             else:  # position_ids for RobertaTokenizer
                 position_ids = seq_length - ones
@@ -264,14 +265,15 @@ def __init__(self,
                  type_vocab_size=16,
                  initializer_range=0.02,
                  pad_token_id=0,
-                 layer_norm_eps=1e-12):
+                 layer_norm_eps=1e-12,
+                 cls_token_id=101):
         super(RobertaModel, self).__init__()
         self.pad_token_id = pad_token_id
         self.initializer_range = initializer_range
         self.layer_norm_eps = layer_norm_eps
         self.embeddings = RobertaEmbeddings(
             vocab_size, hidden_size, hidden_dropout_prob,
-            max_position_embeddings, type_vocab_size, pad_token_id)
+            max_position_embeddings, type_vocab_size, pad_token_id, cls_token_id)
         encoder_layer = nn.TransformerEncoderLayer(
             hidden_size,
             num_attention_heads,
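Note: the position-id branch in the first hunk is the behavioral core of this change. Below is a minimal standalone sketch of that arithmetic; the toy input ids and the pad_token_id=0 value (taken from the constructor default above) are assumptions for illustration, and only the two branches mirror the diff.

import paddle

# Assumed values for illustration; pad_token_id=0 matches the constructor default in the diff.
pad_token_id = 0
cls_token_id = 0          # RobertaBPETokenizer's vocab puts <s> at id 0
input_ids = paddle.to_tensor([[0, 31414, 232, 2]], dtype="int64")  # toy batch of one sequence

ones = paddle.ones_like(input_ids, dtype="int64")
seq_length = paddle.cumsum(ones, axis=-1)                  # [[1, 2, 3, 4]]

if cls_token_id == 0 or int(input_ids[0][0]) == 0:
    # RobertaBPETokenizer branch: positions are offset past the padding index
    position_ids = seq_length + pad_token_id + 1 - ones    # [[1, 2, 3, 4]]
else:
    # RobertaTokenizer branch: plain zero-based positions
    position_ids = seq_length - ones                       # [[0, 1, 2, 3]]

print(position_ids.numpy())

With the new default cls_token_id=101 (the [CLS] id in the BERT-style vocab that RobertaTokenizer uses) and a first token other than 0, the else branch is taken, so existing checkpoints keep the position ids they were trained with.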