Commit 9afa326

docs + hf
Signed-off-by: Alexan <[email protected]>
1 parent: b3362a4 · commit: 9afa326

File tree

7 files changed: +128 −73 lines changed

dataset_configs/armenian/toloka/pipeline_get_final_res.yaml

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ documentation: |
 It processes all accepted results from the Toloka pool and prepares the data for training by refining and resampling audio files and ensuring text formatting consistency.

 **Stage Overview**:
+
 This stage includes the following steps:
 1. Downloading all the ACCEPTED results from the Toloka platform.
 2. Filtering out damaged audio files.

dataset_configs/armenian/toloka/pipeline_start.yaml

Lines changed: 1 addition & 0 deletions

@@ -6,6 +6,7 @@ documentation: |
 It sets up the foundation for creating structured tasks by initializing a new Toloka project, preparing pools, and processing textual data to generate a clean and organized corpus.

 **Stage Overview**:
+
 This stage focuses on preparing and refining the dataset through the following steps:
 1. Creating a new Toloka project.
 2. Creating a new pool for the project.

dataset_configs/armenian/toloka/pipeline_validate_answers.yaml

Lines changed: 1 addition & 0 deletions

@@ -7,6 +7,7 @@ documentation: |

 **Stage Overview**:
 This stage includes the following steps:
+
 1. Downloading results of completed tasks from Toloka.
 2. Validating the audio files and filtering out corrupted files.
 3. Transcribing Armenian audio to text using a HuggingFace model.

sdp/processors/datasets/coraa/create_initial_manifest.py

Lines changed: 5 additions & 2 deletions

@@ -2,7 +2,6 @@
 import os
 from pathlib import Path
 from typing import List
-from huggingface_hub import snapshot_download
 import pandas as pd

 import rarfile #Needs to be installed
@@ -64,7 +63,11 @@ def prepare(self):
         os.makedirs(self.resampled_audio_dir, exist_ok=True)
         os.makedirs(self.extract_archive_dir, exist_ok=True)
         if not self.already_downloaded:
-            snapshot_download(repo_id="gabrielrstan/CORAA-v1.1", repo_type='dataset', local_dir=self.raw_data_dir)
+            try:
+                from huggingface_hub import snapshot_download
+                snapshot_download(repo_id="gabrielrstan/CORAA-v1.1", repo_type='dataset', local_dir=self.raw_data_dir)
+            except ImportError:
+                raise ImportError("huggingface_hub is required to download the dataset. Please install it with pip install huggingface_hub")
         if not self.already_extracted:

         if self.data_split == 'train':
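
The hunk above defers the huggingface_hub import so the module stays importable when the optional dependency is absent. A minimal standalone sketch of the same pattern (download_coraa is a hypothetical helper for illustration, not part of the commit):

    # Optional-dependency pattern: import at call time, fail with a clear message.
    def download_coraa(raw_data_dir: str) -> None:
        try:
            from huggingface_hub import snapshot_download
        except ImportError:
            raise ImportError(
                "huggingface_hub is required to download the dataset. "
                "Please install it with pip install huggingface_hub"
            )
        # Mirror of the call the commit wraps: fetch the CORAA v1.1 dataset snapshot.
        snapshot_download(repo_id="gabrielrstan/CORAA-v1.1", repo_type='dataset', local_dir=raw_data_dir)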

sdp/processors/huggingface/speech_recognition.py

Lines changed: 19 additions & 13 deletions

@@ -23,21 +23,27 @@
 from typing import Optional

 class ASRTransformers(BaseProcessor):
-    """
-    Processor to transcribe using ASR Transformers model from HuggingFace.
+    """This processor transcribes audio files using HuggingFace ASR Transformer models.
+
+    It processes audio files from the manifest and adds transcriptions using the specified
+    pre-trained model from HuggingFace.

     Args:
-        pretrained_model (str): name of pretrained model on HuggingFace.
-        output_text_key (str): Key to save transcription result.
-        input_audio_key (str): Key to read audio file. Defaults to "audio_filepath".
-        input_duration_key (str): Audio duration key. Defaults to "duration".
-        device (str): Inference device.
-        batch_size (int): Inference batch size. Defaults to 1.
-        chunk_length_s (int): Length of the chunks (in seconds) into which the input audio should be divided.
-            Note: Some models perform the chunking on their own (for instance, Whisper chunks into 30s segments also by maintaining the context of the previous chunks).
-        torch_dtype (str): Tensor data type. Default to "float32"
-        max_new_tokens (Optional[int]): The maximum number of new tokens to generate.
-            If not specified, there is no hard limit on the number of tokens generated, other than model-specific constraints.
+        pretrained_model (str): Name of pretrained model on HuggingFace.
+        output_text_key (str): Key to save transcription result in the manifest.
+        input_audio_key (str): Key to read audio file paths from the manifest. Default: "audio_filepath".
+        input_duration_key (str): Key for audio duration in the manifest. Default: "duration".
+        device (str, optional): Inference device (e.g., "cuda", "cpu"). Default: None.
+        batch_size (int): Inference batch size. Default: 1.
+        chunk_length_s (int): Length of audio chunks in seconds. Default: 0.
+        torch_dtype (str): Tensor data type for model inference. Default: "float32".
+        generate_task (str): Task type for generation. Default: "transcribe".
+        generate_language (str): Language for generation. Default: "english".
+        max_new_tokens (int, optional): Maximum number of new tokens to generate. Default: None.
+
+    Returns:
+        A manifest with transcribed text added to each entry under the specified output_text_key.
+
     """

     def __init__(
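
For orientation, a hedged sketch of how these documented arguments typically map onto a HuggingFace ASR pipeline; the model name and audio path are placeholders, and this is not the processor's actual implementation:

    import torch
    from transformers import pipeline

    # pretrained_model, device, torch_dtype, chunk_length_s and batch_size
    # correspond to the Args documented above.
    asr = pipeline(
        "automatic-speech-recognition",
        model="openai/whisper-small",
        device="cuda:0",
        torch_dtype=torch.float16,
        chunk_length_s=30,
        batch_size=4,
    )
    # generate_task, generate_language and max_new_tokens are passed at generation time.
    result = asr(
        "sample.wav",
        generate_kwargs={"task": "transcribe", "language": "english", "max_new_tokens": 128},
    )
    print(result["text"])  # stored under output_text_key in the manifest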

sdp/processors/modify_manifest/data_to_data.py

Lines changed: 75 additions & 27 deletions

@@ -739,8 +739,30 @@ def process_dataset_entry(self, data_entry):
         return [DataEntry(data=data_entry)]


-
 class CopyManifestData(BaseParallelProcessor):
+    """This processor copies files specified in the manifest to a new location.
+
+    It is useful for creating a consolidated dataset by gathering files from different sources
+    into a single directory.
+
+    Args:
+        copy_path (str): The destination directory where files will be copied.
+        source_filepath (str): The key in the manifest that contains the path to
+            the file to be copied. Default: "audio_path".
+
+    Returns:
+        The same data as in the input manifest, but the files referenced in the manifest
+        will have been copied to the specified destination directory.
+
+    Example:
+        .. code-block:: yaml
+
+            - _target_: sdp.processors.modify_manifest.data_to_data.CopyManifestData
+              input_manifest_file: ${workspace_dir}/dataset.json
+              output_manifest_file: ${workspace_dir}/dataset_copied.json
+              copy_path: ${workspace_dir}/consolidated_data
+              source_filepath: "audio_filepath"
+    """
     def __init__(
         self,
         copy_path: str,
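
As an aside, an illustrative sketch of the copy step the new docstring describes (copy_entry_file is a hypothetical helper, not the processor's code):

    import os
    import shutil

    def copy_entry_file(entry: dict, copy_path: str, source_filepath: str = "audio_path") -> dict:
        # Gather the referenced file into the consolidated destination directory.
        os.makedirs(copy_path, exist_ok=True)
        src = entry[source_filepath]
        shutil.copy(src, os.path.join(copy_path, os.path.basename(src)))
        # Per the docstring, the manifest data itself is returned unchanged.
        return entry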
@@ -906,6 +928,18 @@ def process_dataset_entry(self, data_entry) -> List:


 class GetWER(BaseParallelProcessor):
+    """This processor calculates Word Error Rate (WER) between predicted text and ground truth text.
+
+    It computes the WER for each entry in the manifest and adds the result as a new field.
+
+    Args:
+        text_key (str): Key for the ground truth text field in the manifest. Default: "text".
+        pred_text_key (str): Key for the predicted text field in the manifest. Default: "pred_text".
+
+    Returns:
+        The same data as in the input manifest with an additional 'wer' field containing
+        the calculated Word Error Rate between the specified text fields.
+    """
     def __init__(
         self,
         text_key: str = "text",
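
For reference, an illustrative word-level WER computation matching the definition the docstring implies (edit distance over words divided by reference length, as a percentage); a sketch, not necessarily the processor's exact implementation:

    def word_error_rate(text: str, pred_text: str) -> float:
        ref, hyp = text.split(), pred_text.split()
        # Dynamic-programming edit distance over word sequences.
        d = [[0] * (len(hyp) + 1) for _ in range(len(ref) + 1)]
        for i in range(len(ref) + 1):
            d[i][0] = i
        for j in range(len(hyp) + 1):
            d[0][j] = j
        for i in range(1, len(ref) + 1):
            for j in range(1, len(hyp) + 1):
                cost = 0 if ref[i - 1] == hyp[j - 1] else 1
                d[i][j] = min(d[i - 1][j] + 1, d[i][j - 1] + 1, d[i - 1][j - 1] + cost)
        return 100.0 * d[len(ref)][len(hyp)] / max(len(ref), 1)

    entry = {"text": "hello world", "pred_text": "hello word"}
    entry["wer"] = word_error_rate(entry["text"], entry["pred_text"])  # 50.0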
@@ -922,11 +956,33 @@ def process_dataset_entry(self, data_entry) -> List:


 class MakeSentence(BaseParallelProcessor):
-    """
-    Processes a text string by capitalizing its first character (if enabled) and appending
-    an end_symbol if the text does not already end with punctuation.
-    """
+    """This processor formats text strings into proper sentences.
+
+    It capitalizes the first character of the text (if enabled) and appends
+    an end symbol if the text does not already end with punctuation.
+
+    Args:
+        text_key (str): The key in the manifest containing the text to be processed.
+            Default: "text".
+        end_symbol (str): The punctuation symbol to add at the end of the text if it
+            doesn't already have one. Default: ":".
+        make_uppercase (bool): Whether to capitalize the first character of the text.
+            Default: True.

+    Returns:
+        The same data as in the input manifest with the text field modified to have
+        proper sentence formatting.
+
+    Example:
+        .. code-block:: yaml
+
+            - _target_: sdp.processors.modify_manifest.data_to_data.MakeSentence
+              input_manifest_file: ${workspace_dir}/dataset.json
+              output_manifest_file: ${workspace_dir}/dataset_formatted.json
+              text_key: "transcript"
+              end_symbol: "."
+              make_uppercase: true
+    """
     def __init__(
         self,
         text_key: str = "text",
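
An illustrative sketch of the formatting rule described above (a standalone function, not the processor's exact code):

    import string

    def make_sentence(text: str, end_symbol: str = ":", make_uppercase: bool = True) -> str:
        if make_uppercase and text:
            text = text[0].upper() + text[1:]   # capitalize the first character
        if text and text[-1] not in string.punctuation:
            text = text + end_symbol            # append end symbol if none present
        return text

    assert make_sentence("hello world", end_symbol=".") == "Hello world."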
@@ -949,30 +1005,22 @@ def process_dataset_entry(self, data_entry) -> List:
         return [DataEntry(data=data_entry)]


-
 class ASRFileCheck(BaseProcessor):
-    """
-    ASRFileCheck is a class for validating audio files listed in a manifest file.
-    This class checks if each audio file can be successfully loaded with the `torchaudio` library, marking
-    and moving corrupted files to a specified directory.
+    """This processor validates audio files in the manifest and identifies corrupted files.
+
+    It attempts to load each audio file using the torchaudio library and moves corrupted
+    files to a specified directory.
+
+    Args:
+        audio_filepath_key (str): The key in the manifest that contains the path to
+            the audio file. Default: "audio_filepath".
+        corrupted_audio_dir (str): The directory where corrupted audio files will be moved.
+        workspace_dir (str, optional): The base directory for resolving relative paths.
+            Default: None.
+
+    Returns:
+        A manifest with corrupted audio files removed.

-    Attributes:
-    ----------
-    audio_filepath_key : str, optional
-        The key in the manifest entries used to retrieve the path to the audio file. Defaults to 'audio_filepath'.
-    corrupted_audio_dir : str
-        The directory where corrupted audio files will be moved. This is a required parameter.
-    workspace_dir : str, optional
-        The base directory where audio files are stored. If provided, audio file paths will be resolved
-        relative to this directory. Defaults to None.
-    failed_files : list
-        A list of file paths for audio files that failed to load.
-
-    Methods:
-    -------
-    process()
-        Checks each file listed in the manifest to ensure it can be loaded with torchaudio.
-        Moves corrupted files and outputs a new manifest with only valid entries.
     """
     def __init__(self, audio_filepath_key: str = "audio_filepath", corrupted_audio_dir: str = None, workspace_dir: str = None, **kwargs):
         """

sdp/processors/toloka/accept_if.py

Lines changed: 26 additions & 31 deletions

@@ -32,37 +32,32 @@


 class AcceptIfWERLess(BaseParallelProcessor):
-    """
-    AcceptIfWERLess is a class for accepting Toloka assignments if the Word Error Rate (WER) is below a specified threshold.
-    This class uses Toloka's API to evaluate the WER of assignments and accept them if they meet the criteria.
-
-    Attributes:
-    ----------
-    input_data_file : str
-        The path to the input data file containing API configurations.
-    input_pool_file : str
-        The path to the input pool file containing pool configurations.
-    threshold : float, optional
-        The WER threshold below which assignments are accepted. Defaults to 75.
-    config_file : str, optional
-        The path to the configuration file. Defaults to None.
-    API_KEY : str, optional
-        The API key used to authenticate with Toloka's API. Defaults to None, in which case it tries to
-        load the key from environment variables or config file.
-    platform : str, optional
-        Specifies the Toloka environment (e.g., 'PRODUCTION', 'SANDBOX'). Defaults to None, meaning it will
-        try to load from environment variables or the config file.
-    pool_id : str, optional
-        The ID of the pool from which assignments will be retrieved. Defaults to None.
-
-    Methods:
-    -------
-    load_config()
-        Loads configuration data from a config file to populate API_KEY, platform, and pool_id attributes.
-    prepare()
-        Prepares the class by loading API configuration, pool configuration, and initializing Toloka client.
-    process()
-        Accepts Toloka assignments if their Word Error Rate (WER) is below the specified threshold.
+    """This processor accepts Toloka assignments if the Word Error Rate (WER) is below a threshold.
+
+    It evaluates the WER between ground truth and predicted text for each assignment
+    and accepts those that meet the specified threshold criteria.
+
+    Args:
+        input_data_file (str): Path to the input data file containing API configurations.
+        input_pool_file (str): Path to the input pool file containing pool configurations.
+        threshold (float): The WER threshold below which assignments are accepted. Default: 75.
+        config_file (str, optional): Path to the configuration file. Default: None.
+        API_KEY (str, optional): The API key for authenticating with Toloka's API. Default: None.
+        platform (str, optional): The Toloka platform to use. Default: None.
+        pool_id (str, optional): The ID of the Toloka pool. Default: None.
+
+    Returns:
+        A manifest with accepted assignments from Toloka based on the WER threshold.
+
+    Example:
+        .. code-block:: yaml
+
+            - _target_: sdp.processors.toloka.accept_if.AcceptIfWERLess
+              input_manifest_file: ${workspace_dir}/result_manifest_pred_clean.json
+              output_manifest_file: ${workspace_dir}/result_manifest_pred_review.json
+              input_data_file: ${workspace_dir}/data_file.json
+              input_pool_file: ${workspace_dir}/taskpool.json
+              threshold: 50
     """
     def __init__(
         self,
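
Finally, a hedged sketch of the acceptance decision itself; the toloka-kit call shown follows that library's documented client API, while the manifest fields (wer, assignment_id) are assumptions for illustration, not confirmed by this commit:

    import toloka.client as toloka

    def review_assignments(entries: list, client: toloka.TolokaClient, threshold: float = 75.0) -> None:
        for entry in entries:
            # Accept the crowd assignment only when its WER beats the threshold.
            if entry["wer"] < threshold:
                client.accept_assignment(entry["assignment_id"], "Well done!")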
