22
22
import torch .nn .functional as F
23
23
import torch .optim as optim
24
24
import torch .onnx
25
+ import random
25
26
26
27
from torch .autograd import Variable , Function
27
28
from torch .utils .data import Dataset , DataLoader
@@ -306,7 +307,7 @@ class AudioDataset(Dataset):
306
307
mini-batch training.
307
308
"""
308
309
309
- def __init__ (self , filename , config , keywords ):
310
+ def __init__ (self , filename , config , keywords , training = False ):
310
311
""" Initialize the AudioDataset from the given *.npz file """
311
312
self .dataset = np .load (filename )
312
313
@@ -331,34 +332,59 @@ def __init__(self, filename, config, keywords):
331
332
else :
332
333
self .mean = None
333
334
self .std = None
335
+
334
336
self .label_names = self .dataset ["labels" ]
335
337
self .keywords = keywords
336
338
self .num_keywords = len (self .keywords )
337
339
self .labels = self .to_long_vector ()
340
+
341
+ self .keywords_idx = None
342
+ self .non_keywords_idx = None
343
+ if training and config .sample_non_kw is not None :
344
+ self .keywords_idx , self .non_keywords_idx = self .get_keyword_idx (config .sample_non_kw )
345
+ self .sample_non_kw_probability = config .sample_non_kw_probability
346
+
338
347
msg = "Loaded dataset {} and found sample rate {}, audio_size {}, input_size {}, window_size {} and shift {}"
339
348
print (msg .format (os .path .basename (filename ), self .sample_rate , self .audio_size , self .input_size ,
340
349
self .window_size , self .shift ))
341
350
342
351
def get_data_loader(self, batch_size):
    """Build a DataLoader over this dataset.

    Batches are shuffled each epoch and any trailing partial batch is
    dropped so every mini-batch has exactly `batch_size` rows.
    """
    return DataLoader(
        self,
        batch_size=batch_size,
        shuffle=True,
        drop_last=True,
    )
345
-
354
+
346
355
def to_long_vector(self):
    """Map each expected label name to an integer class index.

    "<null>" maps to 0; every other name maps to its position in
    self.keywords.  Returns a 1-D numpy array of dtype longlong.
    """
    indexes = []
    for name in self.label_names:
        if name == "<null>":
            indexes.append(0)
        else:
            indexes.append(self.keywords.index(name))
    return np.array(indexes, dtype=np.longlong)
350
359
360
def get_keyword_idx(self, non_kw_label):
    """Partition row indexes into keyword rows and non-keyword rows.

    Args:
        non_kw_label: the label name that marks a non-keyword row.

    Returns:
        A tuple (keywords_idx, non_keywords_idx) of numpy longlong arrays:
        indexes of rows whose label differs from non_kw_label, and indexes
        of rows whose label equals it, respectively.
    """
    keyword_rows = []
    non_keyword_rows = []
    # Single pass over the labels instead of scanning the list twice
    # with two separate comprehensions.
    for row, label in enumerate(self.label_names):
        if label == non_kw_label:
            non_keyword_rows.append(row)
        else:
            keyword_rows.append(row)
    return (np.array(keyword_rows, dtype=np.longlong),
            np.array(non_keyword_rows, dtype=np.longlong))
365
+
351
366
def __len__(self):
    """Return the number of rows this Dataset exposes per epoch."""
    # Plain case: no non-keyword sampling configured, expose every row.
    if self.non_keywords_idx is None:
        return self.num_rows
    # Sampling case: inflate the epoch length so that, of the indexes a
    # DataLoader draws, roughly sample_non_kw_probability of them fall
    # past the keyword rows and trigger non-keyword sampling.
    keyword_count = len(self.keywords_idx)
    return int(keyword_count / (1 - self.sample_non_kw_probability))
354
372
355
373
def __getitem__(self, idx):
    """ Return a single labelled sample here as a tuple """
    if self.non_keywords_idx is None:
        # No sampling configured: idx addresses the dataset directly.
        row = idx
    elif idx < len(self.keywords_idx):
        # The first len(keywords_idx) positions map onto keyword rows.
        row = self.keywords_idx[idx]
    else:
        # Positions past the keyword rows draw a random non-keyword row.
        row = np.random.choice(self.non_keywords_idx)
    audio = self.features[row]  # batch index is second dimension
    label = self.labels[row]
    return (audio, label)
361
386
387
+
362
388
363
389
def create_model (model_config , input_size , num_keywords ):
364
390
ModelClass = get_model_class (KeywordSpotter )
@@ -453,7 +479,7 @@ def train(config, evaluate_only=False, outdir=".", detail=False, azureml=False):
453
479
log = None
454
480
if not evaluate_only :
455
481
print ("Loading {}..." .format (training_file ))
456
- training_data = AudioDataset (training_file , config .dataset , keywords )
482
+ training_data = AudioDataset (training_file , config .dataset , keywords , training = True )
457
483
458
484
print ("Loading {}..." .format (validation_file ))
459
485
validation_data = AudioDataset (validation_file , config .dataset , keywords )
@@ -556,6 +582,8 @@ def str2bool(v):
556
582
parser .add_argument ("--rolling" , help = "Whether to train model in rolling fashion or not" , action = "store_true" )
557
583
parser .add_argument ("--max_rolling_length" , help = "Max number of epochs you want to roll the rolling training"
558
584
" default is 100" , type = int )
585
+ parser .add_argument ("--sample_non_kw" , "-sl" , type = str , help = "Sample data for this label with probability sample_prob" )
586
+ parser .add_argument ("--sample_non_kw_probability" , "-spr" , type = float , help = "Sample from scl with this probability" )
559
587
560
588
# arguments for fastgrnn
561
589
parser .add_argument ("--wRank" , "-wr" , help = "Rank of W in 1st layer of FastGRNN default is None" , type = int )
@@ -645,6 +673,15 @@ def str2bool(v):
645
673
config .dataset .categories = args .categories
646
674
if args .dataset :
647
675
config .dataset .path = args .dataset
676
+ if args .sample_non_kw :
677
+ config .dataset .sample_non_kw = args .sample_non_kw
678
+ if args .sample_non_kw_probability is None :
679
+ config .dataset .sample_non_kw_probability = 0.5
680
+ else :
681
+ config .dataset .sample_non_kw_probability = args .sample_non_kw_probability
682
+ else :
683
+ config .dataset .sample_non_kw = None
684
+
648
685
if args .wRank :
649
686
config .model .wRank = args .wRank
650
687
if args .uRank :
0 commit comments