Commit 1c7ddf0

fix gpt inputs embeds (#4179)
1 parent d45864d commit 1c7ddf0

2 files changed: 66 additions, 12 deletions

2 files changed

+66
-12
lines changed

paddlenlp/transformers/gpt/modeling.py

Lines changed: 23 additions & 12 deletions
@@ -24,6 +24,7 @@
 from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from paddle.nn.layer.transformer import _convert_param_attr_to_list
 
+from ...utils.log import logger
 from .. import PretrainedModel, register_base_model
 from ..model_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
@@ -425,7 +426,7 @@ def __init__(
     def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
         if input_ids is not None:
             input_shape = paddle.shape(input_ids)
-            input_embeddings = self.word_embeddings(input_ids)
+            inputs_embeddings = self.word_embeddings(input_ids)
         else:
             input_shape = paddle.shape(inputs_embeddings)[:-1]
 
@@ -435,7 +436,7 @@ def forward(self, input_ids, position_ids=None, inputs_embeddings=None):
             position_ids = seq_length - ones
 
         position_embeddings = self.position_embeddings(position_ids)
-        embeddings = input_embeddings + position_embeddings
+        embeddings = inputs_embeddings + position_embeddings
         embeddings = self.dropout(embeddings)
         return embeddings
 
@@ -851,7 +852,7 @@ def forward(
             past_length = 0
             if cache is not None:
                 past_length = paddle.shape(cache[0].k)[-2]
-            position_ids = paddle.arange(past_length, input_shape[-1] + past_length, dtype=input_ids.dtype)
+            position_ids = paddle.arange(past_length, input_shape[-1] + past_length, dtype="int64")
             position_ids = position_ids.unsqueeze(0)
             # .expand_as(input_ids)
             position_ids = paddle.expand(position_ids, input_shape)
@@ -860,7 +861,7 @@
         )
 
         # TODO, use registered buffer
-        length = paddle.shape(input_ids)[-1]
+        length = input_shape[-1]
         if cache is not None:
             cache_length = paddle.shape(cache[0].k)[2]
             length = length + cache_length
@@ -1177,6 +1178,7 @@ def forward(
         Especialy, when `return_dict=use_cache=output_attentions=output_hidden_states=False`,
         returns a tensor `logits` which is the output of the gpt model.
         """
+        input_type = type(input_ids) if input_ids is not None else type(inputs_embeds)
         outputs = self.gpt(
             input_ids,
             position_ids=position_ids,
@@ -1188,7 +1190,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        if isinstance(outputs, type(input_ids)):
+        if isinstance(outputs, input_type):
             hidden_states = outputs
         else:
             hidden_states = outputs[0]
@@ -1206,7 +1208,7 @@ def forward(
 
         # outputs = [output, all_hidden_states, new_caches, all_self_attentions]
         if not return_dict:
-            if isinstance(outputs, type(input_ids)):
+            if isinstance(outputs, input_type):
                 return (loss, logits) if loss is not None else logits
 
             outputs = (logits,) + outputs[1:]
@@ -1370,6 +1372,7 @@ def forward(
                 logits = model(**inputs)
 
         """
+        input_type = type(input_ids) if input_ids is not None else type(inputs_embeds)
         sequence_output = self.gpt(
             input_ids,
             position_ids=position_ids,
@@ -1379,7 +1382,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        if isinstance(sequence_output, type(input_ids)):
+        if isinstance(sequence_output, input_type):
             hidden_states = sequence_output
         else:
             hidden_states = sequence_output[0]
@@ -1392,7 +1395,7 @@ def forward(
             loss = loss_fct(logits.reshape((-1, self.num_classes)), labels.reshape((-1,)))
 
         if not return_dict:
-            if isinstance(sequence_output, type(input_ids)):
+            if isinstance(sequence_output, input_type):
                 return (loss, logits) if loss is not None else logits
 
             outputs = (logits,) + sequence_output[1:]
@@ -1488,7 +1491,7 @@ def forward(
                 logits = model(**inputs)
 
         """
-
+        input_type = type(input_ids) if input_ids is not None else type(inputs_embeds)
         # sequence_output shape [bs, seq_len, hidden_size]
         sequence_output = self.gpt(
             input_ids,
@@ -1500,7 +1503,7 @@ def forward(
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
         )
-        if isinstance(sequence_output, type(input_ids)):
+        if isinstance(sequence_output, input_type):
             hidden_states = sequence_output
         else:
             hidden_states = sequence_output[0]
@@ -1509,7 +1512,15 @@ def forward(
         # padding index maybe 0
         eos_token_id = self.gpt.config.get("eos_token_id", 0)
         # sequence_lengths shape [bs,]
-        sequence_lengths = (input_ids != eos_token_id).astype("int64").sum(axis=-1) - 1
+        if input_ids is not None:
+            sequence_lengths = (input_ids != eos_token_id).astype("int64").sum(axis=-1) - 1
+        else:
+            inputs_shape = paddle.shape(inputs_embeds)[:-1]
+            sequence_lengths = paddle.ones(inputs_shape[:-1], dtype="int64") * (inputs_shape[1] - 1)
+            logger.warning(
+                f"{self.__class__.__name__} will not detect padding tokens in `inputs_embeds`. Results may be "
+                "unexpected if using padding tokens in conjunction with `inputs_embeds.`"
+            )
 
         pooled_logits = logits.gather_nd(paddle.stack([paddle.arange(logits.shape[0]), sequence_lengths], axis=-1))
 
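When only `inputs_embeds` is supplied, the new branch above cannot locate `eos_token_id` in the input, so it assumes every sequence ends at its last position and warns that padding goes undetected. A minimal standalone sketch of that fallback (example shapes are assumed; not part of the commit):

import paddle

# Assumed example batch: 2 sequences, 5 positions, hidden size 8.
inputs_embeds = paddle.randn([2, 5, 8])

# Mirrors the added fallback: drop the hidden dimension, then use
# seq_len - 1 as the "last token" index for every sequence in the batch.
inputs_shape = paddle.shape(inputs_embeds)[:-1]  # -> [2, 5]
sequence_lengths = paddle.ones(inputs_shape[:-1], dtype="int64") * (inputs_shape[1] - 1)
print(sequence_lengths)  # -> [4, 4]; padding tokens are not detected, hence the warning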
@@ -1526,7 +1537,7 @@ def forward(
             loss = loss_fct(pooled_logits, labels)
 
         if not return_dict:
-            if isinstance(sequence_output, type(input_ids)):
+            if isinstance(sequence_output, input_type):
                 return (loss, pooled_logits) if loss is not None else pooled_logits
 
             outputs = (pooled_logits,) + sequence_output[1:]
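
Taken together, the modeling changes let the GPT heads accept precomputed embeddings in place of token ids. The sketch below (not part of the commit) exercises the fixed path end to end and mirrors the new `test_inputs_embeds` test added further down; the `"gpt2-en"` checkpoint name is borrowed from the existing slow test and the token ids are arbitrary, so treat both as assumptions.

import paddle
from paddlenlp.transformers import GPTLMHeadModel

# Checkpoint name taken from the existing slow test in this file; adjust as needed.
model = GPTLMHeadModel.from_pretrained("gpt2-en")
model.eval()

# Arbitrary example token ids, shape [batch_size, seq_len].
input_ids = paddle.to_tensor([[464, 3290, 318, 13779]], dtype="int64")

with paddle.no_grad():
    # Path 1: the usual `input_ids` route.
    ids_logits = model(input_ids=input_ids)

    # Path 2: embed the ids manually and feed them in via `inputs_embeds`.
    wte = model.get_input_embeddings()
    embeds_logits = model(inputs_embeds=wte(input_ids))

# After this fix the two routes should agree up to numerical tolerance.
print(paddle.allclose(ids_logits, embeds_logits, rtol=1e-4, atol=1e-4))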

tests/transformers/gpt/test_modeling.py

Lines changed: 43 additions & 0 deletions
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import datetime
 import math
 import random
@@ -385,6 +386,13 @@ def prepare_config_and_inputs_for_common(self):
 
         return config, inputs_dict
 
+    def prepare_config_and_inputs_for_gpt(self):
+        config = self.get_config()
+        # excluding eos_token_id which is equal to vocab_size - 1
+        input_ids = ids_tensor([self.batch_size, self.seq_length], self.vocab_size - 1, dtype="int64")
+        inputs_dict = {"input_ids": input_ids}
+        return config, inputs_dict
+
 
 @parameterized_class(
     ("return_dict", "use_labels"),
@@ -450,6 +458,41 @@ def test_gpt_weight_initialization(self):
         config_and_inputs = self.model_tester.prepare_config_and_inputs()
         self.model_tester.create_and_check_gpt_weight_initialization(*config_and_inputs)
 
+    def test_inputs_embeds(self):
+        # NOTE: rewrite test inputs embeds for gpt model since couldn't detect eos token id from inputs_embeds
+        # get config for model and inputs_dict for model forward
+        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_gpt()
+        # test all model classes
+        for model_class in self.all_model_classes:
+            model = self._make_model_instance(config, model_class)
+            model.eval()
+
+            inputs = copy.deepcopy(self._prepare_for_class(inputs_dict, model_class))
+
+            with paddle.no_grad():
+                ids_output = model(**inputs)
+
+            if not self.is_encoder_decoder:
+                input_ids = inputs["input_ids"]
+                del inputs["input_ids"]
+            else:
+                encoder_input_ids = inputs["input_ids"]
+                decoder_input_ids = inputs.get("decoder_input_ids", encoder_input_ids)
+                del inputs["input_ids"]
+                inputs.pop("decoder_input_ids", None)
+
+            wte = model.get_input_embeddings()
+            if not self.is_encoder_decoder:
+                inputs["inputs_embeds"] = wte(input_ids)
+            else:
+                inputs["inputs_embeds"] = wte(encoder_input_ids)
+                inputs["decoder_inputs_embeds"] = wte(decoder_input_ids)
+
+            with paddle.no_grad():
+                embeds_output = model(**inputs)
+
+            self.assertTrue(paddle.allclose(ids_output, embeds_output, rtol=1e-4, atol=1e-4))
+
     @slow
     def test_batch_generation(self):
         model = GPTLMHeadModel.from_pretrained("gpt2-en")
