Skip to content

Commit ff70faa

Browse files
authored
[prompt] fix missing args in autotemplate and labels in trainer (#4293)
1 parent ff3c0c9 commit ff70faa

File tree

2 files changed

+11
-5
lines changed

2 files changed

+11
-5
lines changed

paddlenlp/prompt/prompt_trainer.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def compute_loss(self, model, inputs, return_outputs=False):
235235
loss = self.criterion(logits, labels)
236236

237237
if self.args.use_rdrop:
238-
loss = self._compute_rdrop_loss(model, input_dict, logits, loss)
238+
loss = self._compute_rdrop_loss(model, input_dict, labels, logits, loss)
239239

240240
if self.args.use_rgl:
241241
loss += self._compute_rgl_loss(hidden_states, labels)
@@ -246,9 +246,8 @@ def compute_loss(self, model, inputs, return_outputs=False):
246246

247247
return (loss, outputs) if return_outputs else loss
248248

249-
def _compute_rdrop_loss(self, model, input_dict, outputs, loss):
249+
def _compute_rdrop_loss(self, model, input_dict, labels, outputs, loss):
250250
re_outputs, _ = model(**input_dict)
251-
labels = input_dict["labels"]
252251
ce_loss = (self.criterion(re_outputs, labels) + loss) * 0.5
253252
kl_loss = self.rdrop_criterion(outputs, re_outputs)
254253
loss = ce_loss + self.args.alpha_rdrop * kl_loss

paddlenlp/prompt/template.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -778,6 +778,7 @@ def create_from(
778778
max_length: int = 512,
779779
model: PretrainedModel = None,
780780
soft_embeddings: Tensor = None,
781+
prefix_dropout: float = 0.1,
781782
):
782783
# Default template if not defined.
783784
if prompt is None:
@@ -795,11 +796,17 @@ def create_from(
795796

796797
# Choose Template according to template keywords.
797798
if "prefix" in template_keywords:
798-
return PrefixTemplate(prompt=prompt, tokenizer=tokenizer, max_length=max_length, model=model)
799+
return PrefixTemplate(
800+
prompt=prompt, tokenizer=tokenizer, max_length=max_length, model=model, prefix_dropout=prefix_dropout
801+
)
799802
elif "soft" in template_keywords or "soft_id" in template_keywords:
800803
word_embeddings = model.get_input_embeddings()
801804
return SoftTemplate(
802-
prompt=prompt, tokenizer=tokenizer, max_length=max_length, word_embeddings=word_embeddings
805+
prompt=prompt,
806+
tokenizer=tokenizer,
807+
max_length=max_length,
808+
word_embeddings=word_embeddings,
809+
soft_embeddings=soft_embeddings,
803810
)
804811
else:
805812
return ManualTemplate(prompt=prompt, tokenizer=tokenizer, max_length=max_length)

0 commit comments

Comments (0)