Skip to content

Commit fa85aa1

Browse files
authored
Merge pull request #155 from TensorSpeech/dev/stft
Update tf.signal.stft to near librosa
2 parents 5fbd6a8 + 839e5fa commit fa85aa1

File tree

12 files changed

+89
-3897
lines changed

12 files changed

+89
-3897
lines changed

examples/conformer/config.yml

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,14 +24,14 @@ speech_config:
2424
normalize_per_feature: False
2525

2626
decoder_config:
27-
vocabulary: ./vocabularies/librispeech_train_4_4076.subwords
27+
vocabulary: ./vocabularies/librispeech/librispeech_train_4_1030.subwords
2828
target_vocab_size: 4096
2929
max_subword_length: 4
3030
blank_at_zero: True
3131
beam_width: 5
3232
norm_score: True
3333
corpus_files:
34-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
34+
- /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
3535

3636
model_config:
3737
name: conformer
@@ -74,7 +74,7 @@ learning_config:
7474
num_masks: 1
7575
mask_factor: 27
7676
data_paths:
77-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/train-clean-100/transcripts.tsv
77+
- /media/nlhuy/Data/ML/Datasets/ASR/Raw/LibriSpeech/train-clean-100/transcripts.tsv
7878
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
7979
shuffle: True
8080
cache: True
@@ -84,9 +84,7 @@ learning_config:
8484

8585
eval_dataset_config:
8686
use_tf: True
87-
data_paths:
88-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-clean/transcripts.tsv
89-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/dev-other/transcripts.tsv
87+
data_paths: null
9088
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
9189
shuffle: False
9290
cache: True
@@ -96,8 +94,7 @@ learning_config:
9694

9795
test_dataset_config:
9896
use_tf: True
99-
data_paths:
100-
- /mnt/Miscellanea/Datasets/Speech/LibriSpeech/test-clean/transcripts.tsv
97+
data_paths: null
10198
tfrecords_dir: /mnt/Miscellanea/Datasets/Speech/LibriSpeech/tfrecords
10299
shuffle: False
103100
cache: True

examples/conformer/train_tpu_keras_subword_conformer.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646

4747
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
4848

49+
parser.add_argument("--saved", type=str, default=None, help="Path to saved model")
50+
4951
args = parser.parse_args()
5052

5153
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
@@ -108,6 +110,9 @@
108110
conformer._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size)
109111
conformer.summary(line_length=120)
110112

113+
if args.saved:
114+
conformer.load_weights(args.saved, by_name=True, skip_mismatch=True)
115+
111116
optimizer = tf.keras.optimizers.Adam(
112117
TransformerSchedule(
113118
d_model=conformer.dmodel,

examples/contextnet/train_tpu_keras_subword_contextnet.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@
4646

4747
parser.add_argument("--subwords_corpus", nargs="*", type=str, default=[], help="Transcript files for generating subwords")
4848

49+
parser.add_argument("--saved", type=str, default=None, help="Path to saved model")
50+
4951
args = parser.parse_args()
5052

5153
tf.config.optimizer.set_experimental_options({"auto_mixed_precision": args.mxp})
@@ -108,6 +110,9 @@
108110
contextnet._build(speech_featurizer.shape, prediction_shape=text_featurizer.prepand_shape, batch_size=global_batch_size)
109111
contextnet.summary(line_length=120)
110112

113+
if args.saved:
114+
contextnet.load_weights(args.saved, by_name=True, skip_mismatch=True)
115+
111116
optimizer = tf.keras.optimizers.Adam(
112117
TransformerSchedule(
113118
d_model=contextnet.dmodel,

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222

2323
setuptools.setup(
2424
name="TensorFlowASR",
25-
version="0.7.8",
25+
version="0.8.0",
2626
author="Huy Le Nguyen",
2727
author_email="[email protected]",
2828
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/datasets/asr_dataset.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from .base_dataset import BaseDataset, BUFFER_SIZE, TFRECORD_SHARDS, AUTOTUNE
2323
from ..featurizers.speech_featurizers import load_and_convert_to_wav, read_raw_audio, tf_read_raw_audio, SpeechFeaturizer
2424
from ..featurizers.text_featurizers import TextFeaturizer
25-
from ..utils.utils import bytestring_feature, get_num_batches, preprocess_paths, get_nsamples_from_duration
25+
from ..utils.utils import bytestring_feature, get_num_batches, preprocess_paths
2626

2727

2828
class ASRDataset(BaseDataset):
@@ -54,9 +54,7 @@ def __init__(self,
5454
def compute_metadata(self):
5555
self.read_entries()
5656
for _, duration, indices in tqdm.tqdm(self.entries, desc=f"Computing metadata for entries in {self.stage} dataset"):
57-
nsamples = get_nsamples_from_duration(duration, sample_rate=self.speech_featurizer.sample_rate)
58-
# https://www.tensorflow.org/api_docs/python/tf/signal/frame
59-
input_length = -(-nsamples // self.speech_featurizer.frame_step)
57+
input_length = self.speech_featurizer.get_length_from_duration(duration)
6058
label = str(indices).split()
6159
label_length = len(label)
6260
self.speech_featurizer.update_length(input_length)

tensorflow_asr/featurizers/speech_featurizers.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,12 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
1415
import os
1516
import io
1617
import abc
1718
import six
19+
import math
1820
import numpy as np
1921
import librosa
2022
import soundfile as sf
@@ -219,6 +221,7 @@ def __init__(self, speech_config: dict):
219221
self.normalize_signal = speech_config.get("normalize_signal", True)
220222
self.normalize_feature = speech_config.get("normalize_feature", True)
221223
self.normalize_per_feature = speech_config.get("normalize_per_feature", False)
224+
self.center = speech_config.get("center", True)
222225
# Length
223226
self.max_length = 0
224227

@@ -232,6 +235,11 @@ def shape(self) -> list:
232235
""" The shape of extracted features """
233236
raise NotImplementedError()
234237

238+
def get_length_from_duration(self, duration):
239+
nsamples = math.ceil(float(duration) * self.sample_rate)
240+
if self.center: nsamples += self.nfft
241+
return 1 + (nsamples - self.nfft) // self.frame_step # https://www.tensorflow.org/api_docs/python/tf/signal/frame
242+
235243
def update_length(self, length: int):
236244
self.max_length = max(self.max_length, length)
237245

@@ -280,7 +288,7 @@ def shape(self) -> list:
280288
def stft(self, signal):
281289
return np.square(
282290
np.abs(librosa.core.stft(signal, n_fft=self.nfft, hop_length=self.frame_step,
283-
win_length=self.frame_length, center=False, window="hann")))
291+
win_length=self.frame_length, center=self.center, window="hann")))
284292

285293
def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
286294
return librosa.power_to_db(S, ref=ref, amin=amin, top_db=top_db)
@@ -409,9 +417,14 @@ def shape(self) -> list:
409417
return [length, self.num_feature_bins, 1]
410418

411419
def stft(self, signal):
412-
return tf.square(
413-
tf.abs(tf.signal.stft(signal, frame_length=self.frame_length,
414-
frame_step=self.frame_step, fft_length=self.nfft, pad_end=True)))
420+
if self.center: signal = tf.pad(signal, [[self.nfft // 2, self.nfft // 2]], mode="REFLECT")
421+
window = tf.signal.hann_window(self.frame_length, periodic=True)
422+
left_pad = (self.nfft - self.frame_length) // 2
423+
right_pad = self.nfft - self.frame_length - left_pad
424+
window = tf.pad(window, [[left_pad, right_pad]])
425+
framed_signals = tf.signal.frame(signal, frame_length=self.nfft, frame_step=self.frame_step)
426+
framed_signals *= window
427+
return tf.square(tf.abs(tf.signal.rfft(framed_signals, [self.nfft])))
415428

416429
def power_to_db(self, S, ref=1.0, amin=1e-10, top_db=80.0):
417430
if amin <= 0:

tensorflow_asr/losses/keras/ctc_losses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818

1919
class CtcLoss(tf.keras.losses.Loss):
20-
def __init__(self, blank=0, global_batch_size=None, reduction=tf.keras.losses.Reduction.NONE, name=None):
21-
super(CtcLoss, self).__init__(reduction=reduction, name=name)
20+
def __init__(self, blank=0, global_batch_size=None, name=None):
21+
super(CtcLoss, self).__init__(reduction=tf.keras.losses.Reduction.NONE, name=name)
2222
self.blank = blank
2323
self.global_batch_size = global_batch_size
2424

tensorflow_asr/losses/keras/rnnt_losses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@
1717

1818

1919
class RnntLoss(tf.keras.losses.Loss):
20-
def __init__(self, blank=0, global_batch_size=None, reduction=tf.keras.losses.Reduction.NONE, name=None):
21-
super(RnntLoss, self).__init__(reduction=reduction, name=name)
20+
def __init__(self, blank=0, global_batch_size=None, name=None):
21+
super(RnntLoss, self).__init__(reduction=tf.keras.losses.Reduction.NONE, name=name)
2222
self.blank = blank
2323
self.global_batch_size = global_batch_size
2424

tensorflow_asr/utils/metrics.py

Lines changed: 40 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,12 @@
1313
# limitations under the License.
1414

1515
from typing import Tuple
16-
import numpy as np
1716
import tensorflow as tf
1817
from nltk.metrics import distance
1918
from .utils import bytes_to_string
2019

2120

22-
def wer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
23-
"""Word Error Rate
24-
25-
Args:
26-
decode (np.ndarray): array of prediction texts
27-
target (np.ndarray): array of groundtruth texts
28-
29-
Returns:
30-
tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text
31-
"""
21+
def _wer(decode, target):
3222
decode = bytes_to_string(decode)
3323
target = bytes_to_string(target)
3424
dis = 0.0
@@ -45,16 +35,20 @@ def wer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
4535
return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
4636

4737

48-
def cer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
49-
"""Character Error Rate
38+
def wer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
39+
"""Word Error Rate
5040
5141
Args:
5242
decode (np.ndarray): array of prediction texts
5343
target (np.ndarray): array of groundtruth texts
5444
5545
Returns:
56-
tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
46+
tuple: a tuple of tf.Tensor of (edit distances, number of words) of each text
5747
"""
48+
return tf.numpy_function(_wer, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
49+
50+
51+
def _cer(decode, target):
5852
decode = bytes_to_string(decode)
5953
target = bytes_to_string(target)
6054
dis = 0
@@ -65,6 +59,36 @@ def cer(decode: np.ndarray, target: np.ndarray) -> Tuple[tf.Tensor, tf.Tensor]:
6559
return tf.convert_to_tensor(dis, tf.float32), tf.convert_to_tensor(length, tf.float32)
6660

6761

62+
def cer(_decode: tf.Tensor, _target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
63+
"""Character Error Rate
64+
65+
Args:
66+
decode (np.ndarray): array of prediction texts
67+
target (np.ndarray): array of groundtruth texts
68+
69+
Returns:
70+
tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
71+
"""
72+
return tf.numpy_function(_cer, inp=[_decode, _target], Tout=[tf.float32, tf.float32])
73+
74+
75+
def tf_cer(decode: tf.Tensor, target: tf.Tensor) -> Tuple[tf.Tensor, tf.Tensor]:
76+
"""Tensorflwo Charactor Error rate
77+
78+
Args:
79+
decoder (tf.Tensor): tensor shape [B]
80+
target (tf.Tensor): tensor shape [B]
81+
82+
Returns:
83+
tuple: a tuple of tf.Tensor of (edit distances, number of characters) of each text
84+
"""
85+
decode = tf.strings.bytes_split(decode) # [B, N]
86+
target = tf.strings.bytes_split(target) # [B, M]
87+
distances = tf.edit_distance(decode.to_sparse(), target.to_sparse(), normalize=False) # [B]
88+
lengths = tf.cast(target.row_lengths(axis=1), dtype=tf.float32) # [B]
89+
return tf.reduce_sum(distances), tf.reduce_sum(lengths)
90+
91+
6892
class ErrorRate(tf.keras.metrics.Metric):
6993
""" Metric for WER and CER """
7094

@@ -75,10 +99,9 @@ def __init__(self, func, name="error_rate", **kwargs):
7599
self.func = func
76100

77101
def update_state(self, decode: tf.Tensor, target: tf.Tensor):
78-
n, d = tf.numpy_function(self.func, inp=[decode, target], Tout=[tf.float32, tf.float32])
102+
n, d = self.func(decode, target)
79103
self.numerator.assign_add(n)
80104
self.denominator.assign_add(d)
81105

82106
def result(self):
83-
if self.denominator == 0.0: return 0.0
84-
return (self.numerator / self.denominator) * 100
107+
return tf.math.divide_no_nan(self.numerator, self.denominator) * 100

tensorflow_asr/utils/utils.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,3 @@ def body(index, tfarray):
238238

239239
index, tfarray = tf.while_loop(condition, body, loop_vars=[index, tfarray], swap_memory=False)
240240
return tfarray
241-
242-
243-
def get_nsamples_from_duration(duration, sample_rate=16000):
244-
return math.ceil(float(duration) * sample_rate)

0 commit comments

Comments
 (0)