 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from paddle.nn.layer.transformer import _convert_param_attr_to_list

+from ...utils.log import logger
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -425,7 +426,7 @@ def __init__(
     def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         if input_ids is not None:
             input_shape = paddle.shape(input_ids)
-            input_embeddings = self.word_embeddings(input_ids)
+            inputs_embeddings = self.word_embeddings(input_ids)
         else:
             input_shape = paddle.shape(inputs_embeddings)[:-1]

@@ -435,7 +436,7 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
             position_ids = seq_length - ones

         position_embeddings = self.position_embeddings(position_ids)
-        embeddings = input_embeddings + position_embeddings
+        embeddings = inputs_embeddings + position_embeddings
         embeddings = self.dropout(embeddings)
         return embeddings

@@ -851,7 +852,7 @@ def forward(
             past_length = 0
             if cache is not None:
                 past_length = paddle.shape(cache[0].k)[-2]
-            position_ids = paddle.arange(past_length, input_shape[-1] + past_length, dtype=input_ids.dtype)
+            position_ids = paddle.arange(past_length, input_shape[-1] + past_length, dtype="int64")
             position_ids = position_ids.unsqueeze(0)
             # .expand_as(input_ids)
             position_ids = paddle.expand(position_ids, input_shape)
@@ -860,7 +861,7 @@ def forward(
         )

         # TODO, use registered buffer
-        length = paddle.shape(input_ids)[-1]
+        length = input_shape[-1]
         if cache is not None:
             cache_length = paddle.shape(cache[0].k)[2]
             length = length + cache_length
@@ -1177,6 +1178,7 @@ def forward(
             Especialy, when `return_dict=use_cache=output_attentions=output_hidden_states=False`,
             returns a tensor `logits` which is the output of the gpt model.
         """
+        input_type = type(input_ids) if input_ids is not None else type(inputs_embeds)
         outputs = self.gpt(
             input_ids,
             position_ids=position_ids,
@@ -1188,7 +1190,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        if isinstance(outputs, type(input_ids)):
+        if isinstance(outputs, input_type):
             hidden_states = outputs
         else:
             hidden_states = outputs[0]
@@ -1206,7 +1208,7 @@ def forward(

         # outputs = [output, all_hidden_states, new_caches, all_self_attentions]
         if not return_dict:
-            if isinstance(outputs, type(input_ids)):
+            if isinstance(outputs, input_type):
                 return (loss, logits) if loss is not None else logits

             outputs = (logits,) + outputs[1:]
@@ -1370,6 +1372,7 @@ def forward(
                 logits = model(**inputs)

         """
+        input_type = type(input_ids) if input_ids is not None else type(inputs_embeds)
         sequence_output = self.gpt(
             input_ids,
             position_ids=position_ids,
@@ -1379,7 +1382,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        if isinstance(sequence_output, type(input_ids)):
+        if isinstance(sequence_output, input_type):
             hidden_states = sequence_output
         else:
             hidden_states = sequence_output[0]
@@ -1392,7 +1395,7 @@ def forward(
             loss = loss_fct(logits.reshape((-1, self.num_classes)), labels.reshape((-1,)))

         if not return_dict:
-            if isinstance(sequence_output, type(input_ids)):
+            if isinstance(sequence_output, input_type):
                 return (loss, logits) if loss is not None else logits

             outputs = (logits,) + sequence_output[1:]
@@ -1488,7 +1491,7 @@ def forward(
                 logits = model(**inputs)

         """
-
+        input_type = type(input_ids) if input_ids is not None else type(inputs_embeds)
         # sequence_output shape [bs, seq_len, hidden_size]
         sequence_output = self.gpt(
             input_ids,
@@ -1500,7 +1503,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        if isinstance(sequence_output, type(input_ids)):
+        if isinstance(sequence_output, input_type):
             hidden_states = sequence_output
         else:
             hidden_states = sequence_output[0]
@@ -1509,7 +1512,15 @@ def forward(
         # padding index maybe 0
         eos_token_id = self.gpt.config.get("eos_token_id", 0)
         # sequence_lengths shape [bs,]
-        sequence_lengths = (input_ids != eos_token_id).astype("int64").sum(axis=-1) - 1
+        if input_ids is not None:
+            sequence_lengths = (input_ids != eos_token_id).astype("int64").sum(axis=-1) - 1
+        else:
+            inputs_shape = paddle.shape(inputs_embeds)[:-1]
+            sequence_lengths = paddle.ones(inputs_shape[:-1], dtype="int64") * (inputs_shape[1] - 1)
+            logger.warning(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )

         pooled_logits = logits.gather_nd(paddle.stack([paddle.arange(logits.shape[0]), sequence_lengths], axis=-1))

@@ -1526,7 +1537,7 @@ def forward(
             loss = loss_fct(pooled_logits, labels)

         if not return_dict:
-            if isinstance(sequence_output, type(input_ids)):
+            if isinstance(sequence_output, input_type):
                 return (loss, pooled_logits) if loss is not None else pooled_logits

             outputs = (pooled_logits,) + sequence_output[1:]
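
For context on the recurring `input_type` change: a minimal standalone sketch (not part of the diff) of why the old `isinstance(outputs, type(input_ids))` check breaks when the model is called with `inputs_embeds` only. With `input_ids=None`, `type(input_ids)` is `NoneType`, so a bare output tensor is never matched and gets indexed as if it were a tuple; falling back to `type(inputs_embeds)` restores the intended branch. The tensors below are stand-ins, not real model outputs.

# Standalone sketch, assuming only `paddle` is installed.
import paddle

input_ids = None                          # caller passed embeddings instead of token ids
inputs_embeds = paddle.randn([1, 4, 8])   # [batch, seq_len, hidden]
outputs = paddle.randn([1, 4, 8])         # stand-in for a bare Tensor returned by self.gpt

# Old check: isinstance(<Tensor>, NoneType) is always False, so the code would
# wrongly fall into the tuple branch and slice outputs[0].
print(isinstance(outputs, type(input_ids)))   # False

# New check: fall back to the type of inputs_embeds (a paddle.Tensor), so a bare
# Tensor output is recognised and used directly as hidden_states.
input_type = type(input_ids) if input_ids is not None else type(inputs_embeds)
print(isinstance(outputs, input_type))        # True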