Skip to content

Commit a8a7011

Browse files
modified_for_test
1 parent aaae88f commit a8a7011

File tree

3 files changed

+13
-6
lines changed

3 files changed

+13
-6
lines changed

paddlenlp/prompt/prompt_model.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -263,13 +263,12 @@ def forward(
263263
hidden_states=model_outputs.logits,
264264
)
265265

def generate(self, model_kwargs, **kwargs):
    """Generate token sequences with the underlying pretrained language model.

    Temporarily routes the PLM's input preparation through this wrapper's
    ``prepare_inputs_for_generation`` so prompt-specific tensors (e.g. soft
    token ids, token type ids) are injected at every decoding step.

    Args:
        model_kwargs: Mapping of keyword arguments forwarded to
            ``self.plm.generate``.
        **kwargs: Additional generation options. On a key collision these
            override the entry in ``model_kwargs`` instead of raising.

    Returns:
        Whatever ``self.plm.generate`` returns (the generated token ids).
    """
    self.plm.prepare_inputs_for_generation = self.prepare_inputs_for_generation
    # Merge the two dicts explicitly: `plm.generate(**model_kwargs, **kwargs)`
    # raises TypeError on any duplicate keyword, whereas {**a, **b} lets the
    # caller's explicit kwargs override defaults carried in model_kwargs.
    generated_tokens = self.plm.generate(**{**model_kwargs, **kwargs})
    return generated_tokens
270270

271271
def prepare_inputs_for_generation(self, input_ids, use_cache=False, cache=None, **kwargs):
272-
273272
model_kwargs = self.base_model_prepare_inputs_for_generation(input_ids, cache=None, **kwargs)
274273
model_kwargs["soft_token_ids"] = kwargs.get("soft_token_ids", None)
275274
model_kwargs["token_type_ids"] = kwargs.get("token_type_ids", None)

paddlenlp/prompt/prompt_tokenizer.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def __call__(self, inputs: List[Dict[str, Any]]):
4343
# Create input_ids.
4444
soft_token_ids = part.get("soft_tokens", None)
4545
if soft_token_ids is None or len(soft_token_ids) == 1 and soft_token_ids[0] == 0:
46+
if "generator_labels" in part:
47+
# import pdb; pdb.set_trace()
48+
encoded_inputs["labels"].append(
49+
self.tokenizer.encode(
50+
part["generator_labels"], add_special_tokens=False, return_token_type_ids=False
51+
)["input_ids"]
52+
)
53+
inputs.remove(part)
54+
continue
4655
orig_input_ids.append(
4756
self.tokenizer.encode(part["text"], add_special_tokens=False, return_token_type_ids=False)[
4857
"input_ids"
@@ -61,8 +70,6 @@ def __call__(self, inputs: List[Dict[str, Any]]):
6170
else:
6271
input_ids = orig_input_ids[index][: max_lengths[index]]
6372
encoded_inputs["soft_token_ids"].append([0] * len(input_ids))
64-
if part["token_types"] == 1:
65-
encoded_inputs["labels"].append(input_ids)
6673
else:
6774
input_ids = soft_token_ids
6875
encoded_inputs["soft_token_ids"].append(soft_token_ids)

paddlenlp/prompt/template.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -251,7 +251,8 @@ def encode(self, example: Dict[str, Any]):
251251
inputs = []
252252
for value in list(zip(*input_values)):
253253
inputs.append(dict(zip(input_names, value)))
254-
254+
if "labels" in example and isinstance(example["labels"], str):
255+
inputs.append({"generator_labels": example["labels"], "do_truncate": False})
255256
input_dict = self.prompt_tokenizer(inputs)
256257
unused_example = {k: v for k, v in example.items() if k not in self.example_keys}
257258

0 commit comments

Comments
 (0)