# Default number of shard files when serializing a dataset to TFRecord
# (used as the default for ASRTFRecordDataset's tfrecords_shards parameter).
TFRECORD_SHARDS = 16
2929
3030
31- def write_tfrecord_file (splitted_entries ):
32- shard_path , entries = splitted_entries
33- with tf .io .TFRecordWriter (shard_path , options = 'ZLIB' ) as out :
34- for path , audio , indices in entries :
35- feature = {
36- "path" : bytestring_feature ([bytes (path , "utf-8" )]),
37- "audio" : bytestring_feature ([audio ]),
38- "indices" : bytestring_feature ([bytes (indices , "utf-8" )])
39- }
40- example = tf .train .Example (features = tf .train .Features (feature = feature ))
41- out .write (example .SerializeToString ())
42- print_one_line ("Processed:" , path )
43- print (f"\n Created { shard_path } " )
44-
45-
4631class ASRDataset (BaseDataset ):
4732 """ Dataset for ASR using Generator """
4833
@@ -54,40 +39,39 @@ def __init__(self,
5439 augmentations : Augmentation = Augmentation (None ),
5540 cache : bool = False ,
5641 shuffle : bool = False ,
57- use_tf : bool = False ,
5842 drop_remainder : bool = True ,
5943 buffer_size : int = BUFFER_SIZE ):
6044 super (ASRDataset , self ).__init__ (
6145 data_paths = data_paths , augmentations = augmentations ,
6246 cache = cache , shuffle = shuffle , stage = stage , buffer_size = buffer_size ,
63- use_tf = use_tf , drop_remainder = drop_remainder
47+ drop_remainder = drop_remainder
6448 )
6549 self .speech_featurizer = speech_featurizer
6650 self .text_featurizer = text_featurizer
6751
6852 def read_entries (self ):
69- self .lines = []
53+ self .entries = []
7054 for file_path in self .data_paths :
7155 print (f"Reading { file_path } ..." )
7256 with tf .io .gfile .GFile (file_path , "r" ) as f :
7357 temp_lines = f .read ().splitlines ()
7458 # Skip the header of tsv file
75- self .lines += temp_lines [1 :]
59+ self .entries += temp_lines [1 :]
7660 # The files is "\t" seperated
77- self .lines = [line .split ("\t " , 2 ) for line in self .lines ]
78- self . lines = np . array (self .lines )
79- for i , line in enumerate ( self .lines ):
80- self .lines [ i ][ - 1 ] = " " . join ([ str ( x ) for x in self .text_featurizer . extract ( line [ - 1 ]). numpy ()] )
81- if self .shuffle : np .random .shuffle (self .lines ) # Mix transcripts.tsv
82- self .total_steps = len (self .lines )
61+ self .entries = [line .split ("\t " , 2 ) for line in self .entries ]
62+ for i , line in enumerate (self .entries ):
63+ self . entries [ i ][ - 1 ] = " " . join ([ str ( x ) for x in self .text_featurizer . extract ( line [ - 1 ]). numpy ()])
64+ self .entries = np . array ( self .entries )
65+ if self .shuffle : np .random .shuffle (self .entries ) # Mix transcripts.tsv
66+ self .total_steps = len (self .entries )
8367
8468 def generator (self ):
85- for path , _ , indices in self .lines :
86- audio = load_and_convert_to_wav (path )
87- yield path , audio , indices
69+ for path , _ , indices in self .entries :
70+ audio = load_and_convert_to_wav (path ). numpy ()
71+ yield bytes ( path , "utf-8" ), audio , bytes ( indices , "utf-8" )
8872
89- def preprocess (self , path , audio , indices ):
90- def fn (_path , _audio , _indices ):
73+ def preprocess (self , path : tf . Tensor , audio : tf . Tensor , indices : tf . Tensor ):
74+ def fn (_path : bytes , _audio : bytes , _indices : bytes ):
9175 with tf .device ("/CPU:0" ):
9276 signal = read_raw_audio (_audio , self .speech_featurizer .sample_rate )
9377
@@ -111,7 +95,7 @@ def fn(_path, _audio, _indices):
11195 Tout = [tf .string , tf .float32 , tf .int32 , tf .int32 , tf .int32 , tf .int32 , tf .int32 ]
11296 )
11397
114- def tf_preprocess (self , path , audio , indices ):
98+ def tf_preprocess (self , path : tf . Tensor , audio : tf . Tensor , indices : tf . Tensor ):
11599 with tf .device ("/CPU:0" ):
116100 signal = tf_read_raw_audio (audio , self .speech_featurizer .sample_rate )
117101
@@ -130,7 +114,7 @@ def tf_preprocess(self, path, audio, indices):
130114
131115 return path , features , input_length , label , label_length , prediction , prediction_length
132116
133- def process (self , dataset , batch_size ):
117+ def process (self , dataset : tf . data . Dataset , batch_size : int ):
134118 dataset = dataset .map (self .parse , num_parallel_calls = AUTOTUNE )
135119
136120 if self .cache :
@@ -193,18 +177,34 @@ def __init__(self,
193177 tfrecords_shards : int = TFRECORD_SHARDS ,
194178 cache : bool = False ,
195179 shuffle : bool = False ,
196- use_tf : bool = False ,
180+ drop_remainder : bool = True ,
197181 buffer_size : int = BUFFER_SIZE ):
198182 super (ASRTFRecordDataset , self ).__init__ (
199183 stage = stage , speech_featurizer = speech_featurizer , text_featurizer = text_featurizer ,
200184 data_paths = data_paths , augmentations = augmentations , cache = cache , shuffle = shuffle , buffer_size = buffer_size ,
201- use_tf = use_tf
185+ drop_remainder = drop_remainder
202186 )
203187 self .tfrecords_dir = tfrecords_dir
204188 if tfrecords_shards <= 0 : raise ValueError ("tfrecords_shards must be positive" )
205189 self .tfrecords_shards = tfrecords_shards
206190 if not tf .io .gfile .exists (self .tfrecords_dir ): tf .io .gfile .makedirs (self .tfrecords_dir )
207191
192+ @staticmethod
193+ def write_tfrecord_file (splitted_entries ):
194+ shard_path , entries = splitted_entries
195+ with tf .io .TFRecordWriter (shard_path , options = 'ZLIB' ) as out :
196+ for path , _ , indices in entries :
197+ audio = load_and_convert_to_wav (path ).numpy ()
198+ feature = {
199+ "path" : bytestring_feature ([bytes (path , "utf-8" )]),
200+ "audio" : bytestring_feature ([audio ]),
201+ "indices" : bytestring_feature ([bytes (indices , "utf-8" )])
202+ }
203+ example = tf .train .Example (features = tf .train .Features (feature = feature ))
204+ out .write (example .SerializeToString ())
205+ print_one_line ("Processed:" , path )
206+ print (f"\n Created { shard_path } " )
207+
208208 def create_tfrecords (self ):
209209 if not tf .io .gfile .exists (self .tfrecords_dir ):
210210 tf .io .gfile .makedirs (self .tfrecords_dir )
@@ -217,16 +217,15 @@ def create_tfrecords(self):
217217
218218 self .read_entries ()
219219 if not self .total_steps or self .total_steps == 0 : return False
220- entries = np .fromiter (self .generator (), dtype = str )
221220
222221 def get_shard_path (shard_id ):
223222 return os .path .join (self .tfrecords_dir , f"{ self .stage } _{ shard_id } .tfrecord" )
224223
225224 shards = [get_shard_path (idx ) for idx in range (1 , self .tfrecords_shards + 1 )]
226225
227- splitted_entries = np .array_split (entries , self .tfrecords_shards )
226+ splitted_entries = np .array_split (self . entries , self .tfrecords_shards )
228227 with multiprocessing .Pool (self .tfrecords_shards ) as pool :
229- pool .map (write_tfrecord_file , zip (shards , splitted_entries ))
228+ pool .map (self . write_tfrecord_file , zip (shards , splitted_entries ))
230229
231230 return True
232231
@@ -260,12 +259,13 @@ class ASRSliceDataset(ASRDataset):
260259
261260 @staticmethod
262261 def load (record : tf .Tensor ):
263- audio = load_and_convert_to_wav (record [0 ])
262+ def fn (path : bytes ): return load_and_convert_to_wav (path .decode ("utf-8" )).numpy ()
263+ audio = tf .numpy_function (fn , inp = [record [0 ]], Tout = tf .string )
264264 return record [0 ], audio , record [2 ]
265265
266266 def create (self , batch_size : int ):
267267 self .read_entries ()
268268 if not self .total_steps or self .total_steps == 0 : return None
269- dataset = tf .data .Dataset .from_tensor_slices (self .lines )
269+ dataset = tf .data .Dataset .from_tensor_slices (self .entries )
270270 dataset = dataset .map (self .load , num_parallel_calls = AUTOTUNE )
271271 return self .process (dataset , batch_size )
0 commit comments