diff --git a/asr-python-advanced-tao-ngram-pretrain.ipynb b/asr-python-advanced-tao-ngram-pretrain.ipynb index ded74464..66af9ae3 100644 --- a/asr-python-advanced-tao-ngram-pretrain.ipynb +++ b/asr-python-advanced-tao-ngram-pretrain.ipynb @@ -26,7 +26,9 @@ "In this tutorial, we will pretrain Riva ASR language modeling (n-gram) with TAO Toolkit.
\n", "To understand the basics of Riva ASR APIs, refer to [Getting started with Riva ASR in Python](https://github.com/nvidia-riva/tutorials/blob/dev/22.04/asr-python-basics.ipynb).
\n", "\n", - "For more information about Riva, refer to the [Riva developer documentation](https://developer.nvidia.com/riva)." + "For more information about Riva, refer to the [Riva developer documentation](https://developer.nvidia.com/riva).\n", + "\n", + "**Prerequisite**: You have access and are logged into NVIDIA NGC. For step-by-step instructions, refer to the [NGC Getting Started Guide.](https://docs.nvidia.com/ngc/ngc-overview/index.html#registering-activating-ngc-account)" ] }, { @@ -168,9 +170,8 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "### Downloading the dataset\n", "#### LibriSpeech LM Normalized dataset\n", - "The training data is publicly available [here](https://www.openslr.org/resources/11/librispeech-lm-corpus.tgz) and can be downloaded directly." + "The training data is publicly available [here](https://www.openslr.org/resources/11/librispeech-lm-corpus.tgz) and can be downloaded directly.#### Downloading the dataset" ] }, { @@ -217,9 +218,12 @@ "Scripts to download and preprocess LibriSpeech dev-clean\n", "\"\"\"\n", "from multiprocessing import Pool\n", - "\n", - "import numpy\n", - "\n", + "try:\n", + " import numpy\n", + "except ImportError:\n", + " import subprocess\n", + " subprocess.check_output('pip3 install numpy', shell=True)\n", + " \n", "LOG_STR = \" To regenerate this file, please, remove it.\"\n", "\n", "def find_transcript_files(dir):\n", @@ -327,8 +331,8 @@ "metadata": {}, "outputs": [], "source": [ - "# Use a random 10,000 lines for training\n", - "!shuf -n 10000 $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt > $DATA_DOWNLOAD_DIR/reduced_training.txt" + "# Use a random 100,000 lines for training\n", + "!shuf -n 100000 $DATA_DOWNLOAD_DIR/librispeech-lm-norm.txt > $DATA_DOWNLOAD_DIR/reduced_training.txt" ] }, { @@ -337,7 +341,7 @@ "source": [ "---\n", "## TAO Toolkit workflow\n", - "The rest of the tutorial demonstrates what a sample TAO Toolkit workflow looks like." + "The rest of the tutorial demonstrates what a sample TAO Toolkit workflow looks like" ] }, { @@ -393,10 +397,13 @@ "metadata": {}, "outputs": [], "source": [ - "# Make sure the source directories exist, if not, create them\n", - "# ! mkdir \n", - "# ! mkdir \n", - "# ! mkdir " + "# Make sure the source directories mentioned above exist, if not, create them. Provide aboslute Paths\n", + "SPECS_DIR_LOCAL = \"\"\n", + "RESULT_DIR_LOCAL = \"\"\n", + "CACHE_DIR_LOCAL = \"\"\n", + "! mkdir $SPECS_DIR_LOCAL\n", + "! mkdir $RESULT_DIR_LOCAL\n", + "! mkdir $CACHE_DIR_LOCAL" ] }, { @@ -440,36 +447,6 @@ "```\n" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "---\n", - "### Set Relevant Paths\n", - "Please set these paths according to your environment." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [ - "# NOTE: The following paths are set from the perspective of the TAO Toolkit Docker. 
\n", - "\n", - "# The data is saved here\n", - "DATA_DIR='/data'\n", - "\n", - "# The configuration files are stored here\n", - "SPECS_DIR='/specs/n_gram'\n", - "\n", - "# The results are saved at this path\n", - "RESULTS_DIR='/results/n_gram'\n", - "\n", - "# Set your encryption key, and use the same key for all commands\n", - "KEY='tlt_encode'" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -488,8 +465,8 @@ "outputs": [], "source": [ "!tao n_gram download_specs \\\n", - " -r $RESULTS_DIR \\\n", - " -o $SPECS_DIR" + " -r /results \\\n", + " -o /specs" ] }, { @@ -531,19 +508,19 @@ "source": [ "# Preprocess training data (LibriSpeech LM Normalized)\n", "!tao n_gram dataset_convert \\\n", - " -e $SPECS_DIR/dataset_convert.yaml \\\n", - " -r $RESULTS_DIR/dataset_convert \\\n", + " -e /specs/dataset_convert.yaml \\\n", + " -r /results/dataset_convert \\\n", " extension=*.txt \\\n", - " source_data_dir=$DATA_DIR/reduced_training.txt \\\n", - " target_data_file=$DATA_DIR/preprocessed.txt\n", + " source_data_dir=/data/reduced_training.txt \\\n", + " target_data_file=/data/preprocessed.txt\n", "\n", "# Preprocess evaluation data (LibriSpeech dev-clean)\n", "!tao n_gram dataset_convert \\\n", - " -e $SPECS_DIR/dataset_convert.yaml \\\n", - " -r $RESULTS_DIR/dataset_convert \\\n", + " -e /specs/dataset_convert.yaml \\\n", + " -r /results/dataset_convert \\\n", " extension=*.txt \\\n", - " source_data_dir=$DATA_DIR/text/dev-clean.txt \\\n", - " target_data_file=$DATA_DIR/preprocessed_dev_clean.txt" + " source_data_dir=/data/text/dev-clean.txt \\\n", + " target_data_file=/data/preprocessed_dev_clean.txt" ] }, { @@ -587,18 +564,17 @@ "outputs": [], "source": [ "!tao n_gram train \\\n", - " -e $SPECS_DIR/train.yaml \\\n", - " -r $RESULTS_DIR/train \\\n", - " training_ds.data_file=$DATA_DIR/preprocessed.txt \\\n", - " model.order=3 \\\n", - " model.pruning=[0,0,1]" + " -e /specs/train.yaml \\\n", + " -r /results/base \\\n", + " training_ds.data_file=/data/preprocessed.txt \\\n", + " model.order=4 " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The train command produces three files called `train_n_gram.arpa`, `train_n_gram.vocab` and `train_n_gram.kenlm_intermediate` saved at `$RESULTS_DIR/train/checkpoints`." + "The train command produces three files called `train_n_gram.arpa`, `train_n_gram.vocab` and `train_n_gram.kenlm_intermediate` saved at `$RESULTS_DIR_LOCAL/train/checkpoints`." 
] }, { @@ -627,10 +603,10 @@ "outputs": [], "source": [ "!tao n_gram evaluate \\\n", - " -e $SPECS_DIR/evaluate.yaml \\\n", - " -r $RESULTS_DIR/evaluate \\\n", - " restore_from=$RESULTS_DIR/train/checkpoints/train_n_gram.arpa \\\n", - " test_ds.data_file=$DATA_DIR/preprocessed_dev_clean.txt" + " -e /specs/evaluate.yaml \\\n", + " -r /results/evaluate \\\n", + " restore_from=/results/base/checkpoints/train_n_gram.arpa \\\n", + " test_ds.data_file=/data/preprocessed_dev_clean.txt" ] }, { @@ -670,9 +646,9 @@ "outputs": [], "source": [ "!tao n_gram infer \\\n", - " -e $SPECS_DIR/infer.yaml \\\n", - " -r $RESULTS_DIR/infer \\\n", - " restore_from=$RESULTS_DIR/train/checkpoints/train_n_gram.arpa" + " -e /specs/infer.yaml \\\n", + " -r /results/infer \\\n", + " restore_from=/results/base/checkpoints/train_n_gram.arpa" ] }, { @@ -705,23 +681,398 @@ "outputs": [], "source": [ "!tao n_gram export \\\n", - " -e $SPECS_DIR/export.yaml \\\n", - " -r $RESULTS_DIR/export \\\n", + " -e /specs/export.yaml \\\n", + " -r /results/base \\\n", " export_format=RIVA \\\n", - " export_to=exported-model.riva \\\n", - " restore_from=$RESULTS_DIR/train/checkpoints/train_n_gram.arpa \\\n", + " export_to=exported-base.riva \\\n", + " restore_from=/results/base/checkpoints/train_n_gram.arpa \\\n", " binary_type=trie \\\n", " binary_q_bits=8 \\\n", " binary_b_bits=7 \\\n", - " binary_a_bits=256\n", - " " + " binary_a_bits=256 " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The model is exported as `exported-base.binary`, which is in a format suited for deployment in Riva." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "\n", "## RIVA deployment with ASR\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Riva ServiceMaker\n", "ServiceMaker is the set of tools that aggregates all the necessary artifacts (models, files, configurations, and user settings) for Riva deployment to a target environment. It has two main components, as shown below:\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### 1. Riva-build\n", "\n", "This step helps build a Riva-ready version of the model. Its only output is an intermediate format (called an RMIR) of an end-to-end pipeline for the supported services within Riva. We use an ASR Citrinet model here, although the same setup can be used for Conformer models too.
\n", + "\n", + "`riva-build` is responsible for the combination of one or more exported models (.riva files) into a single file containing an intermediate format called Riva Model Intermediate Representation (.rmir). This file contains a deployment-agnostic specification of the whole end-to-end pipeline along with all the assets required for the final deployment and inference. Please checkout the [documentation](https://docs.nvidia.com/deeplearning/riva/user-guide/docs/service-asr.html#pipeline-configuration) to find out more." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# # Riva skills version\n", + "# RIVA_VERSION=\"2.7.0\"\n", + "\n", + "# # ServiceMaker Docker\n", + "# RIVA_SM_CONTAINER = f\"nvcr.io/nvidia/riva/riva-speech:{RIVA_VERSION}-servicemaker\"\n", + "\n", + "# # Riva API Docker\n", + "# RIVA_API_CONTAINER =f\"nvcr.io/nvidia/riva/riva-speech:{RIVA_VERSION}\"\n", + "\n", + "# # Directory where the create model repo\n", + "# MODEL_LOC = \"\"\n", + "\n", + "# # Name of the .riva file\n", + "# MODEL_NAME = \"nvidia/tao/speechtotext_en_us_citrinet:deployable_v3.0\"\n", + "\n", + "# # Key that model is encrypted with, while exporting with TAO\n", + "# KEY = \"tlt_encode\"\n", + "\n", + "# # NGC API KEY, can be generated from ngc.nvidia.com/setup\n", + "# NGC_API_KEY=\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get the ServiceMaker docker and latest riva ASR model\n", + "! mkdir $MODEL_LOC\n", + "! docker pull $RIVA_SM_CONTAINER\n", + "! ngc registry model download-version $MODEL_NAME\n", + "! mv speechtotext_en_us_citrinet_vdeployable_v3.0/citrinet-1024-Jarvis-asrset-3_0-encrypted.riva $MODEL_LOC/Citrinet.riva" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 2. Setup Flashligh decoder\n", + "The Flashlight decoder, deployed by default in Riva, is a lexicon-based decoder and only emits words that are present in the provided lexicon file.\n", + "Vocabulary file: The vocabulary file is a flat text file containing a list of vocabulary words, each on its own line. For example:\n", + "```\n", + "the\n", + "i\n", + "to\n", + "and\n", + "a\n", + "you\n", + "of\n", + "that\n", + "```\n", + "This file is used by the riva-build process to generate the lexicon file.\n", + "\n", + "Lexicon file: The lexicon file is a flat text file that contains the mapping of each vocabulary word to its tokenized form, e.g, sentencepiece tokens, separated by a tab. Below is an example:\n", + "```\n", + "with ▁with\n", + "not ▁not\n", + "this ▁this\n", + "just ▁just\n", + "my ▁my\n", + "as ▁as\n", + "don't ▁don ' t\n", + "```\n", + "Note: Ultimately, the Riva decoder makes use only of the lexicon file directly at run time (but not the vocabulary file).\n", + "\n", + "Riva ServiceMaker automatically tokenizes the words in the vocabulary file to generate the lexicon file. It uses the correct tokenizer model that is packaged together with the acoustic model in the .riva file. By default, Riva generates 1 tokenized form for each word in the vocabulary file." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate vocabulary using base LM training data\n", + "! 
cat $DATA_DOWNLOAD_DIR/preprocessed.txt | sed \"s/ /\\n/g\" | sort -u > $RESULT_DIR_LOCAL/base/dict_vocab.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate the RMIR file with trained Base Language Model\n", + "! docker run -it --rm --gpus 0 -v $MODEL_LOC:/data \\\n", + " -v $RESULT_DIR_LOCAL/base:/lm \\\n", + " --name riva-service-maker-lm \\\n", + " $RIVA_SM_CONTAINER -- \\\n", + " riva-build speech_recognition /data/base_asr.rmir:$KEY \\\n", + " /data/Citrinet.riva:$KEY \\\n", + " --ms_per_timestep=80 \\\n", + " --chunk_size=0.16 \\\n", + " --left_padding_size=1.92 \\\n", + " --right_padding_size=1.92 \\\n", + " --decoder_type=flashlight \\\n", + " --decoding_language_model_binary=/lm/exported-base.binary \\\n", + " --decoding_vocab=/lm/dict_vocab.txt \\\n", + " --flashlight_decoder.lm_weight=0.2 \\\n", + " --flashlight_decoder.word_insertion_score=0.2 \\\n", + " --flashlight_decoder.beam_threshold=20. \\\n", + " --featurizer.dither=0.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### 3. Riva-deploy\n", + "\n", + "The deployment tool takes as input one or more Riva Model Intermediate Representation (RMIR) files and a target model repository directory. It creates an ensemble configuration specifying the pipeline for the execution and finally writes all those assets to the output model repository directory." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Syntax: riva-deploy -f dir-for-rmir/model.rmir:key output-dir-for-repository\n", + "! docker run --rm --gpus 0 -v $MODEL_LOC:/data $RIVA_SM_CONTAINER -- \\\n", + " riva-deploy -f /data/base_asr.rmir:$KEY /data/models/" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "---\n", + "## Start Riva Server\n", + "Once the model repository is generated, we are ready to start the Riva server. From this step onwards you need to download the Riva QuickStart Resource from NGC. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Download Riva Quickstart\n", + "! ngc registry resource download-version nvidia/riva/riva_quickstart:$RIVA_VERSION" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### config.sh snippet\n", + "```\n", + "service_enabled_asr=true ## MAKE CHANGES HERE\n", + "service_enabled_nlp=false ## MAKE CHANGES HERE\n", + "service_enabled_tts=false ## MAKE CHANGES HERE\n", + "\n", + "# Enable Riva Enterprise\n", + "# If enrolled in Enterprise, enable Riva Enterprise by setting configuration\n", + "# here. 
You must explicitly acknowledge you have read and agree to the EULA.\n", + "# RIVA_API_KEY= \n", + "# RIVA_API_NGC_ORG= \n", + "# RIVA_EULA=accept\n", + "\n", + "# Language code to fetch models of a specify language\n", + "# Currently only ASR supports languages other than English\n", + "# Supported language codes: en-US, de-DE, es-US, ru-RU, zh-CN, hi-IN, fr-FR\n", + "# for any language other than English, set service_enabled_nlp and service_enabled_tts to False\n", + "# for multiple languages enter space separated language codes.\n", + "language_code=(\"en-US\")\n", + "\n", + "# ASR acoustic model architecture\n", + "# Supported values are: conformer, citrinet_1024, citrinet_256 (en-US + arm64 only), jasper (en-US + amd64 only), quartznet (en-US + amd64 only)\n", + "asr_acoustic_model=(\"conformer\")\n", + "\n", + "# Specify one or more GPUs to use\n", + "# specifying more than one GPU is currently an experimental feature, and may result in undefined behaviours.\n", + "gpus_to_use=\"device=0\"\n", + "\n", + "# Specify the encryption key to use to deploy models\n", + "MODEL_DEPLOY_KEY=\"tlt_encode\" ## MAKE CHANGES HERE\n", + "\n", + "# Locations to use for storing models artifacts\n", + "#\n", + "# If an absolute path is specified, the data will be written to that location\n", + "# Otherwise, a docker volume will be used (default).\n", + "#\n", + "# riva_init.sh will create a `rmir` and `models` directory in the volume or\n", + "# path specified.\n", + "#\n", + "# RMIR ($riva_model_loc/rmir)\n", + "# Riva uses an intermediate representation (RMIR) for models\n", + "# that are ready to deploy but not yet fully optimized for deployment. Pretrained\n", + "# versions can be obtained from NGC (by specifying NGC models below) and will be\n", + "# downloaded to $riva_model_loc/rmir by `riva_init.sh`\n", + "#\n", + "# Custom models produced by NeMo or TLT and prepared using riva-build\n", + "# may also be copied manually to this location $(riva_model_loc/rmir).\n", + "#\n", + "# Models ($riva_model_loc/models)\n", + "# During the riva_init process, the RMIR files in $riva_model_loc/rmir\n", + "# are inspected and optimized for deployment. The optimized versions are\n", + "# stored in $riva_model_loc/models. The riva server exclusively uses these\n", + "# optimized versions.\n", + "riva_model_loc=\"riva-model-repo\" ## MAKE CHANGES HERE (Replace with MODEL_LOC) \n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Ensure you have permission to execute these scripts\n", + "! cd riva_quickstart_v$RIVA_VERSION && chmod +x ./riva_init.sh && chmod +x ./riva_start.sh && chmod +x ./riva_stop.sh" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Run Riva Start. This will deploy your model(s).\n", + "! cd riva_quickstart_v$RIVA_VERSION && ./riva_start.sh config.sh" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Download Evaluation dataset\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Note: This data can be used only with NVIDIA’s products or services for evaluation and benchmarking purposes.\n", + "! 
ngc registry resource download-version --dest $DATA_DOWNLOAD_DIR nvstaging/tao/healthcare_eval_dataset:1.0" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "---\n", "## Run Inference\n", "Once the Riva server is up and running with your models, you can send inference requests querying the server. " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! docker run --rm -v $DATA_DOWNLOAD_DIR/healthcare_eval_dataset_v1.0:/data \\\n", " --net=host $RIVA_API_CONTAINER -- \\\n", " riva_streaming_asr_client \\\n", " --automatic_punctuation=false \\\n", " --interim_results=false \\\n", " --word_time_offsets=false \\\n", " --audio_file /data/general.json \\\n", " --output_filename=/data/base_asr_on_base_output.json" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! docker run --rm -v $DATA_DOWNLOAD_DIR/healthcare_eval_dataset_v1.0/:/data \\\n", " --net=host $RIVA_API_CONTAINER -- \\\n", " riva_streaming_asr_client \\\n", " --automatic_punctuation=false \\\n", " --interim_results=false \\\n", " --word_time_offsets=false \\\n", " --audio_file /data/healthcare.json \\\n", " --output_filename=/data/base_asr_on_domain_output.json" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Calculate word error rate\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "! pip install jiwer\n", "from jiwer import wer\n", "import json\n", "\n", "def calculate_wer(ground_truth_manifest, asr_transcript):\n", " data = {}\n", " ground_truths = []\n", " predictions = []\n", " with open(ground_truth_manifest) as file:\n", " for line in file:\n", " dt = json.loads(line)\n", " data[dt['audio_filepath']] = dt['text']\n", " with open(asr_transcript) as file:\n", " for line in file:\n", " dt = json.loads(line)\n", " if dt['audio_filepath'] in data:\n", " ground_truths.append(data[dt['audio_filepath']])\n", " predictions.append(dt['text'])\n", " return round(100*wer(ground_truths, predictions), 2)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"WER of base model on generic domain data\", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/general.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/base_asr_on_base_output.json\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "print(\"WER of base model on Healthcare domain data\", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/healthcare.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/base_asr_on_domain_output.json\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The model is exported as `exported-model.riva` which is in a format suited for deployment in Riva." + "The results above show that the model performs well on generic data but not on healthcare-specific domain data. We can fine-tune the language model on healthcare domain data to boost ASR performance." ] }, { @@ -729,20 +1080,421 @@ "metadata": {}, "source": [ "---\n", - "### What's Next?" 
+ "\n", + "### Finetuning/Interpolation\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The fine-tuning process continues training from a previously trained model: a second model is trained on new domain data and then interpolated with the original model. Fine-tuning requires that the original model was trained with the intermediate file enabled (the `.kenlm_intermediate` file produced by `train`). A fine-tuned model cannot be used for fine-tuning again.
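\n",
"\n",
"To make the interpolation step concrete, here is a toy sketch (illustrative numbers only; the real KenLM fine-tuning also handles back-off weights) of how a probability from the domain model and the base model is mixed using the `weight` parameter that appears in the `finetune` command later in this tutorial:\n",
"```python\n",
"# Illustrative only: linear interpolation of two n-gram probabilities.\n",
"weight = 0.6      # weight given to the domain (healthcare) model, as used below\n",
"p_domain = 0.012  # P(word | history) under the domain LM (made-up value)\n",
"p_base = 0.001    # P(word | history) under the base LM (made-up value)\n",
"\n",
"p_interpolated = weight * p_domain + (1 - weight) * p_base\n",
"print(round(p_interpolated, 4))  # 0.0076\n",
"```\n",
"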
\n", + "\n", + "\n", + "### Downloading and procesing domain data (healthcare) for LM finetuning\n", + "For the purpose of finetuning on healthcare domain we can make use of Kaggle dataset PubMed 200k RCT: a Dataset for Sequential Sentence Classification in Medical Abstracts. [https://arxiv.org/abs/1710.06071]
\n", + "This dataset is available at https://www.kaggle.com/datasets/anshulmehtakaggl/200000-abstracts-for-seq-sentence-classification
\n", + "Please follow the instructions to install and authenticate Kaggle API https://www.kaggle.com/docs/api \n", + "
\n", + "**Note**: *Each user is responsible for checking the content of datasets and the applicable licenses and determining if suitable for the intended use*\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!kaggle datasets download -d anshulmehtakaggl/200000-abstracts-for-seq-sentence-classification" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!unzip -d $DATA_DOWNLOAD_DIR 200000-abstracts-for-seq-sentence-classification.zip" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Perform basic text cleaning and generate domain data\n", + "import string,re\n", + "def clean_text(text):\n", + " text = re.sub(r\"[^a-z' ]+\", \"\", text.lower().strip())\n", + " text = ' '.join(text.split())\n", + " if len(text.split())> 5:\n", + " return text.strip()\n", + " \n", + "# Using dev file since we want a small amount of finetuning data. For better text Normalization use NeMo [https://github.com/NVIDIA/NeMo/tree/main/nemo_text_processing]\n", + "with open(f'{DATA_DOWNLOAD_DIR}/20k_abstracts_numbers_with_@/dev.txt') as file, open(f'{DATA_DOWNLOAD_DIR}/domain_data_all.txt', 'w') as outfile:\n", + " for line in file:\n", + " if line.startswith(\"###\") or not line.strip():\n", + " continue\n", + " _, text = line.strip().split('\\t')\n", + " text = clean_text(text)\n", + " if text:\n", + " outfile.write(text+'\\n')\n", + " \n", + "# Picking top 10000 lines from dataset\n", + "! head -10000 $DATA_DOWNLOAD_DIR/domain_data_all.txt > $DATA_DOWNLOAD_DIR/domain_data.txt" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The fine-tuning process will continue training using a previously trained model by training a second model on new domain data and interpolating it with the original model. Finetuning requires the original model have intermediate enabled during training. A finetuned model cannot be used for finetuning again.
\n", + "\n", + "\n", + "For Finetuning a N-gram language model in TAO Toolkit, we use the `tao n_gram finetune` command with the following args:\n", + "- `-e`: Path to the spec file\n", + "- `-k`: User specified encryption key to use while saving/loading the model\n", + "- `-r`: Path to a folder where the outputs should be written. Make sure this is mapped in tlt_mounts.json\n", + "- Any overrides to the spec file eg. `model.order`, `weight` etc\n", + "
\n", + "\n", + "\n", + "More details about these arguments are present in the [TAO Toolkit Getting Started Guide](https://docs.nvidia.com/tao/tao-toolkit/text/overview.html)
\n", + "`Note:` All file paths correspond to the destination mounted directory that is visible in the TAO Toolkit docker container used in backend.
" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Interpolate the domain LM with base LM\n", + "!tao n_gram finetune \\\n", + " -e /specs/finetune.yaml \\\n", + " -r /results \\\n", + " restore_from=/results/base/checkpoints/train_n_gram.kenlm_intermediate \\\n", + " tuning_ds.data_file=/data/domain_data.txt \\\n", + " model.order=4 \\\n", + " weight=0.6 # weight of domain specific model \\\n", + " -k $KEY" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Export interpolated LM to Riva compatible format\n", + "!tao n_gram export \\\n", + " -e /specs/export.yaml \\\n", + " -r /results/interpolated \\\n", + " export_format=RIVA \\\n", + " export_to=exported-model.riva \\\n", + " restore_from=/results/checkpoints/finetune_n_gram.arpa \\\n", + " binary_type=trie \\\n", + " binary_q_bits=8 \\\n", + " binary_b_bits=7 \\\n", + " binary_a_bits=256" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Interpolated LM is not generated at /results/interpolated/exported-model.binary.
\n", + "We can now use this LM along with new vocabulary file to generate model repo for Domain specific ASR" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Add domain specific words to vocabulary file\n", + "! cat $DATA_DOWNLOAD_DIR/domain_data.txt | sed \"s/ /\\n/g\" | sort -u > $RESULT_DIR_LOCAL/interpolated/dict_vocab_domain.txt\n", + "! cat $RESULT_DIR_LOCAL/base/dict_vocab.txt $RESULT_DIR_LOCAL/interpolated/dict_vocab_domain.txt | sort -u > $RESULT_DIR_LOCAL/interpolated/dict_vocab.txt" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Generate new model repo with interpolated LM. Set absolute path to create MODEL_LOC_DOMAIN\n", + "MODEL_LOC_DOMAIN = \"\"\n", + "! mkdir $MODEL_LOC_DOMAIN\n", + "! cp $MODEL_LOC/Citrinet.riva $MODEL_LOC_DOMAIN/\n", + "! docker run -it --rm --gpus 0 -v $MODEL_LOC_DOMAIN:/data \\\n", + " -v $RESULT_DIR_LOCAL/interpolated:/lm \\\n", + " --name riva-service-maker-lm \\\n", + " $RIVA_SM_CONTAINER -- \\\n", + " riva-build speech_recognition /data/interpolated_asr.rmir:$KEY \\\n", + " /data/Citrinet.riva:$KEY \\\n", + " --ms_per_timestep=80 \\\n", + " --chunk_size=0.16 \\\n", + " --left_padding_size=1.92 \\\n", + " --right_padding_size=1.92 \\\n", + " --decoder_type=flashlight \\\n", + " --decoding_language_model_binary=/lm/exported-model.binary \\\n", + " --decoding_vocab=/lm/dict_vocab.txt \\\n", + " --flashlight_decoder.lm_weight=0.2 \\\n", + " --flashlight_decoder.word_insertion_score=0.2 \\\n", + " --flashlight_decoder.beam_threshold=20. \\\n", + " --force --featurizer.dither=0.0\n", + "! docker run --rm --gpus 0 -v $MODEL_LOC_DOMAIN:/data $RIVA_SM_CONTAINER -- \\\n", + " riva-deploy -f /data/interpolated_asr.rmir:$KEY /data/models/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Update riva_model_loc in riva_quickstart config file to MODEL_LOC_DOMAIN\n", + "! cd riva_quickstart_v$RIVA_VERSION && ./riva_stop.sh && ./riva_start.sh config.sh" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get model transcripts on base data\n", + "! docker run --rm -v $DATA_DOWNLOAD_DIR/healthcare_eval_dataset_v1.0:/data \\\n", + " --net=host $RIVA_API_CONTAINER -- \\\n", + " riva_streaming_asr_client \\\n", + " --automatic_punctuation=false \\\n", + " --interim_results=false \\\n", + " --word_time_offsets=false \\\n", + " --audio_file /data/general.json \\\n", + " --output_filename=/data/interpolated_asr_on_base_output.json" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Get model transcripts on Healthcare domain data\n", + "! 
docker run --rm -v $DATA_DOWNLOAD_DIR/healthcare_eval_dataset_v1.0:/data \\\n", " --net=host $RIVA_API_CONTAINER -- \\\n", " riva_streaming_asr_client \\\n", " --automatic_punctuation=false \\\n", " --interim_results=false \\\n", " --word_time_offsets=false \\\n", " --audio_file /data/healthcare.json \\\n", " --output_filename=/data/interpolated_asr_on_domain_output.json" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check WER on base data\n", "print(\"WER of base model on generic data: \", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/general.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/base_asr_on_base_output.json\"))\n", "print(\"WER of Domain model on generic data: \", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/general.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/interpolated_asr_on_base_output.json\"))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check WER on Healthcare domain data\n", "print(\"WER of base model on Healthcare domain data: \", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/healthcare.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/base_asr_on_domain_output.json\"))\n", "print(\"WER of Domain model on Healthcare domain data: \", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/healthcare.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/interpolated_asr_on_domain_output.json\"))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "With the help of interpolation, we were able to improve the performance of our ASR model on the healthcare domain as well as the generic domain." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Pruning" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The LM generated by simply passing the text corpus to the TAO Toolkit contains some n-grams that are infrequent in the corpus and thus have very low probabilities. Such n-grams can be removed by `pruning`.
\n", "Pruning requires some thresholds, which can be defined in the `train.yaml` file as follows (for a 4-gram model):\n", "```\n", "pruning:\n", " - 0\n", " - 1\n", " - 7\n", " - 9\n", "```\n", "or can be passed as a command-line argument as follows:
\n", "`model.pruning=[0,1,7,9]`\n", "\n", "All n-grams with frequency less than or equal to the specified threshold will be eliminated.
\n", "Here, 2-grams with freq. <= 1, 3-grams with freq. <= 7, and 4-grams with freq. <= 9 will be eliminated.
\n", + "There's a tradeoff between degree of pruning and accuracy. High pruning parameters will reduce the size of language model but at the cost of model accuracy! \n", + "\n", + "#### *Note:\n", + "Pruning of 1-gram is not supported, threshold for 1-gram should always be 0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!tao n_gram train \\\n", + " -e /specs/train.yaml \\\n", + " -r /results/pruned \\\n", + " training_ds.data_file=/data/preprocessed.txt \\\n", + " model.order=4 \\\n", + " model.pruning=[0,1,7,9]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Lets check the size of original LM and Pruned LM\n", + "!echo \"Size of unpruned ARPA: $(du -h $RESULT_DIR_LOCAL/base/checkpoints/train_n_gram.arpa | cut -f 1)\"\n", + "!echo \"Size of pruned ARPA: $(du -h $RESULT_DIR_LOCAL/pruned/checkpoints/train_n_gram.arpa | cut -f 1)\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We were able to significantly reduce the size of LM. Lets see the impact of pruned LM on ASR's accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#Lets deploy ASR server with Pruned LM, Set absolute path to create MODEL_LOC_PRUNED\n", + "MODEL_LOC_PRUNED = \"PATH_TO_PRUNED_MODEL_REPO_LOCATION\"\n", + "\n", + "#export to Riva format \n", + "!tao n_gram export \\\n", + " -e /specs/export.yaml \\\n", + " -r /results/base \\\n", + " export_format=RIVA \\\n", + " export_to=pruned-base.riva \\\n", + " restore_from=/results/pruned/checkpoints/train_n_gram.arpa \\\n", + " binary_type=trie \\\n", + " binary_q_bits=8 \\\n", + " binary_b_bits=7 \\\n", + " binary_a_bits=256\n", + "\n", + "# Generate RMIR\n", + "! mkdir $MODEL_LOC_PRUNED\n", + "! cp $MODEL_LOC/Citrinet.riva $MODEL_LOC_PRUNED/\n", + "! docker run -it --rm --gpus 0 -v $MODEL_LOC_PRUNED:/data \\\n", + " -v $RESULT_DIR_LOCAL/base:/lm \\\n", + " --name riva-service-maker-lm \\\n", + " $RIVA_SM_CONTAINER -- \\\n", + " riva-build speech_recognition /data/pruned_asr.rmir:$KEY \\\n", + " /data/Citrinet.riva:$KEY \\\n", + " --ms_per_timestep=80 \\\n", + " --chunk_size=0.16 \\\n", + " --left_padding_size=1.92 \\\n", + " --right_padding_size=1.92 \\\n", + " --decoder_type=flashlight \\\n", + " --decoding_language_model_binary=/lm/pruned-base.binary \\\n", + " --decoding_vocab=/lm/dict_vocab.txt \\\n", + " --flashlight_decoder.lm_weight=0.2 \\\n", + " --flashlight_decoder.word_insertion_score=0.2 \\\n", + " --flashlight_decoder.beam_threshold=20. \\\n", + " --force --featurizer.dither=0.0\n", + " \n", + "# Deploy RMIR with Pruned LM\n", + "! docker run --rm --gpus 0 -v $MODEL_LOC_PRUNED:/data $RIVA_SM_CONTAINER -- \\\n", + " riva-deploy -f /data/pruned_asr.rmir:$KEY /data/models/" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Update riva_model_loc in riva_quickstart config file to MODEL_LOC_PRUNED and then start server\n", + "! cd riva_quickstart_v$RIVA_VERSION && ./riva_stop.sh && ./riva_start.sh config.sh" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Evaluate model and calculate WERs\n", + "! 
docker run --rm -v $DATA_DOWNLOAD_DIR/healthcare_eval_dataset_v1.0/:/data \\\n", " --net=host $RIVA_API_CONTAINER -- \\\n", " riva_streaming_asr_client \\\n", " --automatic_punctuation=false \\\n", " --interim_results=false \\\n", " --word_time_offsets=false \\\n", " --audio_file /data/general.json \\\n", " --output_filename=/data/pruned_asr_on_base_output.json\n", "\n", "! docker run --rm -v $DATA_DOWNLOAD_DIR/healthcare_eval_dataset_v1.0/:/data \\\n", " --net=host $RIVA_API_CONTAINER -- \\\n", " riva_streaming_asr_client \\\n", " --automatic_punctuation=false \\\n", " --interim_results=false \\\n", " --word_time_offsets=false \\\n", " --audio_file /data/healthcare.json \\\n", " --output_filename=/data/pruned_asr_on_domain_output.json" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Check WER on base data\n", "print(\"WER of base model on generic data: \", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/general.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/base_asr_on_base_output.json\"))\n", "print(\"WER of Pruned base model on generic data: \", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/general.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/pruned_asr_on_base_output.json\"))\n", "\n", "# Check WER on Healthcare domain data\n", "print(\"WER of base model on Healthcare domain data: \", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/healthcare.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/base_asr_on_domain_output.json\"))\n", "print(\"WER of Pruned base model on Healthcare domain data: \", calculate_wer(f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/healthcare.json\", f\"{DATA_DOWNLOAD_DIR}/healthcare_eval_dataset_v1.0/pruned_asr_on_domain_output.json\"))\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "You could use TAO Toolkit to build custom models for your own applications, or you could deploy the custom model to NVIDIA Riva." + "Pruning drops some low-probability n-grams from the language model. This significantly reduces the size of the language model and can sometimes even help the ASR model's accuracy." ] } ], "metadata": { "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "Python 3", "language": "python", "name": "python3" }, @@ -756,7 +1508,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.13" + "version": "3.8.5" } }, "nbformat": 4,