Skip to content

Commit b10cb03

Browse files
committed
✍️ tensorflow-io does not support tpu yet
1 parent b0cc951 commit b10cb03

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

examples/conformer/train_tpu_keras_subword_conformer.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,4 @@
125125
conformer.fit(
126126
train_data_loader, epochs=config.learning_config.running_config.num_epochs,
127127
validation_data=eval_data_loader, callbacks=callbacks,
128-
steps_per_epoch=train_dataset.total_steps
129128
)

tensorflow_asr/featurizers/speech_featurizers.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,11 @@
2121
import tensorflow as tf
2222
import tensorflow_io as tfio
2323

24-
from ..utils.utils import log10
24+
from ..utils.utils import log10, has_tpu
2525
from .gammatone import fft_weights
2626

27+
tpu = has_tpu()
28+
2729

2830
def load_and_convert_to_wav(path: str) -> tf.Tensor:
2931
wave, rate = librosa.load(os.path.expanduser(path), sr=None, mono=True)
@@ -48,8 +50,10 @@ def read_raw_audio(audio, sample_rate=16000):
4850

4951
def tf_read_raw_audio(audio: tf.Tensor, sample_rate=16000):
5052
wave, rate = tf.audio.decode_wav(audio, desired_channels=1, desired_samples=-1)
51-
resampled = tfio.audio.resample(wave, rate_in=tf.cast(rate, dtype=tf.int64), rate_out=sample_rate)
52-
return tf.reshape(resampled, shape=[-1]) # reshape for using tf.signal
53+
if not tpu:
54+
resampled = tfio.audio.resample(wave, rate_in=tf.cast(rate, dtype=tf.int64), rate_out=sample_rate)
55+
return tf.reshape(resampled, shape=[-1]) # reshape for using tf.signal
56+
return tf.reshape(wave, shape=[-1]) # reshape for using tf.signal
5357

5458

5559
def slice_signal(signal, window_size, stride=0.5) -> np.ndarray:

tensorflow_asr/utils/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,6 +169,12 @@ def has_gpu_or_tpu():
169169
return True
170170

171171

172+
def has_tpu():
173+
tpus = tf.config.list_logical_devices("TPU")
174+
if len(tpus) == 0: return False
175+
return True
176+
177+
172178
def find_max_length_prediction_tfarray(tfarray: tf.TensorArray) -> tf.Tensor:
173179
with tf.name_scope("find_max_length_prediction_tfarray"):
174180
index = tf.constant(0, dtype=tf.int32)

0 commit comments

Comments
 (0)