diff --git a/sms_wsj/database/wsj/create_json.py b/sms_wsj/database/wsj/create_json.py index 4b89a47..7912526 100644 --- a/sms_wsj/database/wsj/create_json.py +++ b/sms_wsj/database/wsj/create_json.py @@ -8,6 +8,7 @@ import os import re import tempfile +import subprocess from pathlib import Path import sacred @@ -182,18 +183,20 @@ def normalize_transcription(transcriptions, wsj_root: Path): :param wsj_root: Path to WSJ database :return result: Clean transcription dictionary + + >>> transcriptions = {'ID1': 'Hello World, and bye!?', 'ID2': 'What? Yes.'} + >>> normalize_transcription(transcriptions, '') + {'ID1': 'HELLO WORLD, AND BYE!?', 'ID2': 'WHAT? YES.'} + """ assert len(transcriptions) > 0, 'No transcriptions to clean up.' - with tempfile.TemporaryDirectory() as temporary_directory: - temporary_directory = Path(temporary_directory).absolute() - with open(temporary_directory / 'dirty.txt', 'w') as f: - for key, value in transcriptions.items(): - f.write('{} {}\n'.format(key, value)) - result = sh.perl( - sh.cat(str(temporary_directory / 'dirty.txt')), - kaldi_wsj_tools / 'normalize_transcript.pl', - '' - ) + + text = ''.join([f'{key} {value}\n' for key, value in transcriptions.items()]) + cp = subprocess.run( + ['perl', kaldi_wsj_tools / 'normalize_transcript.pl', ''], + input=text, stdout=subprocess.PIPE, check=True, universal_newlines=True) + result = cp.stdout + result = [line.split(maxsplit=1) for line in result.strip().split('\n')] result = {k: v for k, v in result} return result