Commit cce78cb

Merge branch 'main' of github.com:TensorSpeech/TensorFlowASR

2 parents: 6cb3a49 + 5971819

4 files changed: 96 additions, 6 deletions

examples/demonstration/conformer.py

Lines changed: 7 additions & 2 deletions

@@ -38,19 +38,24 @@
 
 parser.add_argument("--subwords", type=str, default=None, help="Path to file that stores generated subwords")
 
+parser.add_argument("--sentence_piece", default=False, action="store_true", help="Whether to use `SentencePiece` model")
+
 args = parser.parse_args()
 
 setup_devices([args.device], cpu=args.cpu)
 
 from tensorflow_asr.configs.config import Config
 from tensorflow_asr.featurizers.speech_featurizers import read_raw_audio
 from tensorflow_asr.featurizers.speech_featurizers import TFSpeechFeaturizer
-from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SubwordFeaturizer
+from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SubwordFeaturizer, SentencePieceFeaturizer
 from tensorflow_asr.models.conformer import Conformer
 
 config = Config(args.config)
 speech_featurizer = TFSpeechFeaturizer(config.speech_config)
-if args.subwords and os.path.exists(args.subwords):
+if args.sentence_piece:
+    print("Loading SentencePiece model ...")
+    text_featurizer = SentencePieceFeaturizer.load_from_file(config.decoder_config, args.subwords)
+elif args.subwords and os.path.exists(args.subwords):
     print("Loading subwords ...")
     text_featurizer = SubwordFeaturizer.load_from_file(config.decoder_config, args.subwords)
 else:
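For context, the selection logic this hunk introduces reads as a small helper. The sketch below is illustrative only: the two load_from_file calls appear verbatim in the diff above, while the function itself and the CharFeaturizer fallback are assumptions, since the else: branch is truncated here.

# Illustrative sketch of the new featurizer selection (not part of the commit).
import os

from tensorflow_asr.configs.config import Config
from tensorflow_asr.featurizers.text_featurizers import CharFeaturizer, SubwordFeaturizer, SentencePieceFeaturizer


def build_text_featurizer(config: Config, subwords: str = None, sentence_piece: bool = False):
    if sentence_piece:
        # --sentence_piece reuses the --subwords path to locate the SentencePiece model
        return SentencePieceFeaturizer.load_from_file(config.decoder_config, subwords)
    if subwords and os.path.exists(subwords):
        return SubwordFeaturizer.load_from_file(config.decoder_config, subwords)
    return CharFeaturizer(config.decoder_config)  # assumed fallback for the truncated else: branch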

tensorflow_asr/featurizers/text_featurizers.py

Lines changed: 3 additions & 2 deletions

@@ -323,8 +323,7 @@ def __init__(self, decoder_config: dict, model=None):
         self.upoints = None
         # vocab size
         self.num_classes = self.model.get_piece_size()
-        # create upoints
-        self.__init_upoints()
+        self.upoints = None
 
     def __init_upoints(self):
         text = [""]
@@ -451,6 +450,8 @@ def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
         Returns:
             unicode code points transcript with dtype tf.int32 and shape [None]
         """
+        if self.upoints is None:
+            self.__init_upoints()
         with tf.name_scope("indices2upoints"):
             indices = self.normalize_indices(indices)
             upoints = tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))
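Taken together, the two hunks turn upoints from an eagerly built table into a lazily built one: __init__ now only records self.upoints = None, and indices2upoints constructs the table on first use. A minimal sketch of that pattern, with a made-up stand-in table instead of the real unicode-codepoint construction:

# Minimal sketch of the lazy-initialization pattern introduced above.
# The class and the tiny table are illustrative; only the "if self.upoints is None"
# guard mirrors the actual change.
import tensorflow as tf


class LazyUpointsSketch:
    def __init__(self):
        self.upoints = None  # deferred: nothing TF-heavy happens in __init__

    def _init_upoints(self):
        # stand-in for the real unicode-codepoint table construction
        self.upoints = tf.constant([[32], [97], [98], [99]], dtype=tf.int32)

    def indices2upoints(self, indices: tf.Tensor) -> tf.Tensor:
        if self.upoints is None:  # build the table once, on first call
            self._init_upoints()
        indices = tf.cast(indices, tf.int32)
        return tf.gather_nd(self.upoints, tf.expand_dims(indices, axis=-1))


# e.g. LazyUpointsSketch().indices2upoints(tf.constant([1, 3])) -> [[97], [99]]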

tensorflow_asr/losses/rnnt_losses.py

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@
 # RNNT loss implementation in pure TensorFlow is borrowed from [iamjanvijay's repo](https://github.com/iamjanvijay/rnnt)
 
 import tensorflow as tf
-
+from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2
 from ..utils.utils import has_gpu_or_tpu
 
 use_cpu = not has_gpu_or_tpu()
@@ -24,7 +24,7 @@
     use_warprnnt = True
 except ImportError:
     print("Cannot import RNNT loss in warprnnt. Falls back to RNNT in TensorFlow")
-    from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2
+    print("Note: The RNNT in Tensorflow is not supported for CPU yet")
     use_warprnnt = False
 
 
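After this change, matrix_diag_part_v2 is imported unconditionally at module level, and the except ImportError branch only reports the fallback. A sketch of the resulting module preamble; the warprnnt import itself is outside the hunks shown above and is assumed here:

# Sketch of the module preamble after this commit (warprnnt_tensorflow import is assumed,
# not shown in the diff).
import tensorflow as tf
from tensorflow.python.ops.gen_array_ops import matrix_diag_part_v2  # now always imported

try:
    from warprnnt_tensorflow import rnnt_loss as warp_rnnt_loss
    use_warprnnt = True
except ImportError:
    print("Cannot import RNNT loss in warprnnt. Falls back to RNNT in TensorFlow")
    print("Note: The RNNT in Tensorflow is not supported for CPU yet")
    use_warprnnt = False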

tests/losses/test_rnnt_loss.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
import tensorflow_asr.losses.rnnt_losses as rnnt_losses
import numpy as np
import tensorflow as tf


class WarpRNNTTest(tf.test.TestCase):
    def _run_rnnt(self, acts, labels, input_lengths, label_lengths,
                  expected_costs, expected_grads, use_gpu=False):
        self.assertEquals(acts.shape, expected_grads.shape)
        acts_t = tf.constant(acts)
        labels_t = tf.constant(labels)
        input_lengths_t = tf.constant(input_lengths)
        label_lengths_t = tf.constant(label_lengths)

        with tf.GradientTape() as tape:
            # by default, GradientTape doesn’t track constants
            tape.watch(acts_t)
            tape.watch(labels_t)
            tape.watch(input_lengths_t)
            tape.watch(label_lengths_t)
            logits = acts_t if use_gpu else tf.nn.log_softmax(acts_t)
            costs = rnnt_losses.rnnt_loss_tf(logits=logits,
                                             labels=labels_t,
                                             label_length=label_lengths_t,
                                             logit_length=input_lengths_t,
                                             name=None)

        grads = tape.gradient(costs, [acts_t])[0]
        self.assertAllClose(costs, expected_costs, atol=1e-6)
        self.assertAllClose(grads, expected_grads, atol=1e-6)

    def _test_multiple_batches(self, use_gpu):
        B = 2; T = 4; U = 3; V = 3

        acts = np.array([0.065357, 0.787530, 0.081592, 0.529716, 0.750675, 0.754135,
                         0.609764, 0.868140, 0.622532, 0.668522, 0.858039, 0.164539,
                         0.989780, 0.944298, 0.603168, 0.946783, 0.666203, 0.286882,
                         0.094184, 0.366674, 0.736168, 0.166680, 0.714154, 0.399400,
                         0.535982, 0.291821, 0.612642, 0.324241, 0.800764, 0.524106,
                         0.779195, 0.183314, 0.113745, 0.240222, 0.339470, 0.134160,
                         0.505562, 0.051597, 0.640290, 0.430733, 0.829473, 0.177467,
                         0.320700, 0.042883, 0.302803, 0.675178, 0.569537, 0.558474,
                         0.083132, 0.060165, 0.107958, 0.748615, 0.943918, 0.486356,
                         0.418199, 0.652408, 0.024243, 0.134582, 0.366342, 0.295830,
                         0.923670, 0.689929, 0.741898, 0.250005, 0.603430, 0.987289,
                         0.592606, 0.884672, 0.543450, 0.660770, 0.377128, 0.358021], dtype=np.float32).reshape(B, T, U, V);

        expected_costs = np.array([4.28065, 3.93844], dtype=np.float32)
        expected_grads = np.array([-0.186844, -0.062555, 0.249399, -0.203377, 0.202399, 0.000977,
                                   -0.141016, 0.079123, 0.061893, -0.011552, -0.081280, 0.092832,
                                   -0.154257, 0.229433, -0.075176, -0.246593, 0.146405, 0.100188,
                                   -0.012918, -0.061593, 0.074512, -0.055986, 0.219831, -0.163845,
                                   -0.497627, 0.209240, 0.288387, 0.013605, -0.030220, 0.016615,
                                   0.113925, 0.062781, -0.176706, -0.667078, 0.367659, 0.299419,
                                   -0.356344, -0.055347, 0.411691, -0.096922, 0.029459, 0.067463,
                                   -0.063518, 0.027654, 0.035863, -0.154499, -0.073942, 0.228441,
                                   -0.166790, -0.000088, 0.166878, -0.172370, 0.105565, 0.066804,
                                   0.023875, -0.118256, 0.094381, -0.104707, -0.108934, 0.213642,
                                   -0.369844, 0.180118, 0.189726, 0.025714, -0.079462, 0.053748,
                                   0.122328, -0.238789, 0.116460, -0.598687, 0.302203, 0.296484], dtype=np.float32).reshape(B, T, U, V);

        labels = np.array([[1, 2], [1, 1]], dtype=np.int32)
        input_lengths = np.array([4, 4], dtype=np.int32)
        label_lengths = np.array([2, 2], dtype=np.int32)

        self._run_rnnt(acts,
                       labels,
                       input_lengths,
                       label_lengths,
                       expected_costs,
                       expected_grads)

    def test_multiple_batches_gpu(self):
        rnnt_losses.use_warprnnt = False
        self._test_multiple_batches(use_gpu=True)

    def test_multiple_batches_cpu(self):
        rnnt_losses.use_warprnnt = False
        self._test_multiple_batches(use_gpu=False)


if __name__ == "__main__":
    tf.test.main()
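The test watches plain constants on the tape before computing the loss; without tape.watch, gradients with respect to tf.constant inputs come back as None. A standalone illustration of that behavior (not part of the commit):

# Why the test calls tape.watch(): GradientTape tracks trainable variables
# automatically, but constants must be watched explicitly.
import tensorflow as tf

x = tf.constant([1.0, 2.0, 3.0])
with tf.GradientTape() as tape:
    tape.watch(x)                  # omit this and the gradient below is None
    y = tf.reduce_sum(tf.square(x))
print(tape.gradient(y, x))         # tf.Tensor([2. 4. 6.], shape=(3,), dtype=float32)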
