
Commit e168554

Added pnc restoration
Signed-off-by: Sasha Meister <sasha.meister.work@gmail.com>
1 parent 2c12847 commit e168554

File tree

2 files changed: +129 -0 lines changed
  • dataset_configs/multilingual/yodas2/prompts/pnc_restoration
  • sdp/processors/huggingface/transformers

dataset_configs/multilingual/yodas2/prompts/pnc_restoration: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
system: |
  Description:
  You have a transcript that may contain punctuation and capitalization, may not contain them, or may contain incorrect punctuation. The task is to bring the text to the correct form by restoring punctuation and capitalization, ensuring the following rules:

  Rules:
  - "Do not change, add, or remove any words in the text. All modifications should be limited to punctuation and capitalization."
  - "Restore the correct punctuation using only periods, commas, and question marks. All other symbols (including exclamation marks, colons, semicolons, quotes, parentheses, emojis, etc.) must be removed or replaced with allowed punctuation marks."
  - "If the text already contains sufficient punctuation (periods, commas, and question marks), it should remain unchanged."
  - "If punctuation is incomplete, incorrect, or contains invalid symbols (e.g., exclamation marks, ellipses, or other unnecessary symbols), it should be corrected to the proper form using only periods, commas, and question marks."
  - "Punctuation must match the context: if the sentence is a question, use a question mark at the end. In other cases, use a period or comma if needed to separate parts of the sentence."
  - "All alphanumeric characters (including digits, e.g., 3:30pm) should remain unchanged."
  - "Capitalize the first letter of each sentence."
  - "Capitalize proper nouns and abbreviations."
  - "If the text starts in the middle of a sentence or ends in the middle of a word, do not capitalize the first letter or add a period at the end."
  - "If punctuation is missing or incorrect, replace invalid symbols with valid punctuation (period, comma, or question mark) without changing the meaning of the text."

  Examples:
  - input: "the quick brown fox jumped over the lazy dog"
    output: "The quick brown fox jumped over the lazy dog."

  - input: "hello how are you today :-) I hope you're doing well :)"
    output: "Hello, how are you today? I hope you're doing well."

  - input: "She went to the store; then she bought some bread."
    output: "She went to the store, then she bought some bread."

  - input: "I can't believe this...!!! This is so exciting!!!"
    output: "I can't believe this. This is so exciting."

  - input: "Do you know where the keys are I can't find them anywhere"
    output: "Do you know where the keys are? I can't find them anywhere."

  - input: "the meeting is at 3:30pm, we should prepare by 3:00."
    output: "The meeting is at 3:30pm, we should prepare by 3:00."

  - input: "this is a great idea, but we need more details."
    output: "This is a great idea, but we need more details."

  - input: "my friend, john, is visiting new york next week."
    output: "My friend, John, is visiting New York next week."

  - input: "we need to finish the project by friday, but I am not sure about the deadline yet."
    output: "We need to finish the project by Friday, but I am not sure about the deadline yet."

  - input: "the report was almost done, but"
    output: "The report was almost done, but"

user: |
  Input transcript: {pred_text}

generation: |
  Output transcript:
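
For context, a minimal sketch of how a prompt file like this is consumed by the processor added in the second file: each top-level key (system, user, generation) becomes a chat role, and the {pred_text} placeholder is filled from a manifest entry. The local file path and the sample entry below are illustrative, not values from this commit.

import yaml

# Hypothetical local path to this prompt file; in the repository it lives under
# dataset_configs/multilingual/yodas2/prompts/.
with open("pnc_restoration.yaml", "r") as f:
    prompt = yaml.safe_load(f)

# Illustrative manifest entry holding the transcript to be restored.
data_entry = {"pred_text": "do you know where the keys are i can't find them anywhere"}

# Each top-level key becomes a chat role; its template is filled from the entry.
messages = [
    {"role": role, "content": template.format(**data_entry)}
    for role, template in prompt.items()
]
# messages now holds one dict each for the system, user, and generation roles.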

sdp/processors/huggingface/transformers: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
import yaml
import json
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoConfig, AutoTokenizer

from sdp.logging import logger
from sdp.processors.base_processor import BaseProcessor


class AutoModelForCausalLMProcessor(BaseProcessor):
    def __init__(self,
                 input_manifest_file: str,
                 output_manifest_file: str,
                 prompt_file: str,
                 output_field: str = 'generation',
                 **kwargs):
        super().__init__(
            input_manifest_file=input_manifest_file,
            output_manifest_file=output_manifest_file,
        )

        # Path to the YAML prompt template (e.g. the pnc_restoration prompt above).
        self.prompt_file = prompt_file
        self.prompt = None

        # Hugging Face model settings are expected under the `model` keyword argument;
        # resolving the config up front also validates the model name or path.
        self.cfg = kwargs['model']
        self.model_cfg = AutoConfig.from_pretrained(**self.cfg)

        self.output_field = output_field

    def read_prompt_file(self):
        with open(self.prompt_file, 'r') as prompt:
            self.prompt = yaml.safe_load(prompt)

    def build_entry_prompt(self, data_entry):
        # Turn each top-level key of the prompt YAML into a chat message,
        # filling placeholders such as {pred_text} from the manifest entry.
        entry_prompt = []
        for role in self.prompt:
            entry_prompt.append(dict(
                role=role,
                content=self.prompt[role].format(**data_entry)
            ))
        return entry_prompt

    def process(self):
        logger.info('Reading prompt:')
        self.read_prompt_file()
        logger.info(f'Prompt:\n{yaml.dump(self.prompt, default_flow_style=False)}\n')

        logger.info('Loading model:')
        # Load the pretrained weights and tokenizer named in the model config.
        model = AutoModelForCausalLM.from_pretrained(**self.cfg)
        tokenizer = AutoTokenizer.from_pretrained(self.cfg['pretrained_model_name_or_path'])

        with open(self.input_manifest_file, 'r', encoding='utf8') as fin, \
                open(self.output_manifest_file, 'w', encoding='utf8') as fout:
            for line in tqdm(fin, desc="Generation: "):
                data_entry = json.loads(line)
                entry_prompt = self.build_entry_prompt(data_entry)
                text = tokenizer.apply_chat_template(
                    entry_prompt,
                    tokenize=False,
                    add_generation_prompt=True
                )

                model_inputs = tokenizer([text], return_tensors="pt").to(model.device)

                generated_ids = model.generate(
                    **model_inputs,
                    max_new_tokens=512
                )

                # Keep only the newly generated tokens, dropping the echoed prompt.
                generated_ids = [
                    output_ids[len(input_ids):]
                    for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
                ]

                # One prompt is passed per entry, so take the single decoded string.
                response = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]

                data_entry[self.output_field] = response
                fout.write(f'{json.dumps(data_entry)}\n')
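
A hypothetical direct invocation of this processor is sketched below; the manifest paths, prompt path, and model identifier are placeholders rather than values from this commit, and each input manifest line is assumed to be a JSON object containing the pred_text field referenced by the prompt.

# All paths and the model identifier below are illustrative placeholders.
processor = AutoModelForCausalLMProcessor(
    input_manifest_file="manifest_with_pred_text.json",
    output_manifest_file="manifest_with_pnc_restored.json",
    prompt_file="pnc_restoration.yaml",
    output_field="generation",
    model={"pretrained_model_name_or_path": "some-org/some-instruct-model"},
)
processor.process()

The generated text is written back to each manifest entry under the output_field key, so downstream processors can consume the restored transcript alongside the original fields.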

0 commit comments
