Skip to content

Commit 6cb3a49

Browse files
committed
✍️ fix some minor bugs
1 parent 3525598 commit 6cb3a49

File tree

4 files changed

+8
-9
lines changed

4 files changed

+8
-9
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636

3737
setuptools.setup(
3838
name="TensorFlowASR",
39-
version="0.7.1",
39+
version="0.7.2",
4040
author="Huy Le Nguyen",
4141
author_email="[email protected]",
4242
description="Almost State-of-the-art Automatic Speech Recognition using Tensorflow 2",

tensorflow_asr/datasets/asr_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def process(self, dataset: tf.data.Dataset, batch_size: int):
140140

141141
# PREFETCH to improve speed of input length
142142
dataset = dataset.prefetch(AUTOTUNE)
143-
self.total_steps = get_num_batches(self.total_steps, batch_size)
143+
self.total_steps = get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder)
144144
return dataset
145145

146146
@tf.function

tensorflow_asr/datasets/keras/asr_dataset.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def process(self, dataset, batch_size):
9393

9494
# PREFETCH to improve speed of input length
9595
dataset = dataset.prefetch(AUTOTUNE)
96-
self.total_steps = get_num_batches(self.total_steps, batch_size)
96+
self.total_steps = get_num_batches(self.total_steps, batch_size, drop_remainders=self.drop_remainder)
9797
return dataset
9898

9999

tensorflow_asr/runners/base_runners.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -370,13 +370,12 @@ def __init__(self,
370370
"greed_cer": ErrorRate(func=cer, name="test_greed_cer", dtype=tf.float32)
371371
}
372372

373-
def set_output_file(self):
373+
def set_output_file(self, batch_size: int = 1):
374+
if not batch_size: batch_size = self.config.batch_size
374375
if os.path.exists(self.output_file_path):
375376
with open(self.output_file_path, "r", encoding="utf-8") as out:
376-
self.processed_records = get_num_batches(
377-
len(out.read().splitlines()) - 1,
378-
batch_size=1
379-
)
377+
self.processed_records = get_num_batches(len(out.read().splitlines()) - 1, batch_size=batch_size,
378+
drop_remainders=False)
380379
else:
381380
with open(self.output_file_path, "w") as out:
382381
out.write("PATH\tGROUNDTRUTH\tGREEDY\tBEAMSEARCH\tBEAMSEARCHLM\n")
@@ -396,7 +395,7 @@ def compile(self, trained_model: tf.keras.Model):
396395
self.model = trained_model
397396

398397
def run(self, test_dataset, batch_size=None):
399-
self.set_output_file()
398+
self.set_output_file(batch_size=batch_size)
400399
self.set_test_data_loader(test_dataset, batch_size=batch_size)
401400
self._test_epoch()
402401
self._finish()

0 commit comments

Comments (0)