Skip to content

Commit 0176ebc

Browse files
support for sampling non keywords
1 parent ae2668e commit 0176ebc

File tree

1 file changed

+43
-6
lines changed

1 file changed

+43
-6
lines changed

examples/pytorch/FastCells/train_classifier.py

Lines changed: 43 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
import torch.nn.functional as F
2323
import torch.optim as optim
2424
import torch.onnx
25+
import random
2526

2627
from torch.autograd import Variable, Function
2728
from torch.utils.data import Dataset, DataLoader
@@ -306,7 +307,7 @@ class AudioDataset(Dataset):
306307
mini-batch training.
307308
"""
308309

309-
def __init__(self, filename, config, keywords):
310+
def __init__(self, filename, config, keywords, training=False):
310311
""" Initialize the AudioDataset from the given *.npz file """
311312
self.dataset = np.load(filename)
312313

@@ -331,34 +332,59 @@ def __init__(self, filename, config, keywords):
331332
else:
332333
self.mean = None
333334
self.std = None
335+
334336
self.label_names = self.dataset["labels"]
335337
self.keywords = keywords
336338
self.num_keywords = len(self.keywords)
337339
self.labels = self.to_long_vector()
340+
341+
self.keywords_idx = None
342+
self.non_keywords_idx = None
343+
if training and config.sample_non_kw is not None:
344+
self.keywords_idx, self.non_keywords_idx = self.get_keyword_idx(config.sample_non_kw)
345+
self.sample_non_kw_probability = config.sample_non_kw_probability
346+
338347
msg = "Loaded dataset {} and found sample rate {}, audio_size {}, input_size {}, window_size {} and shift {}"
339348
print(msg.format(os.path.basename(filename), self.sample_rate, self.audio_size, self.input_size,
340349
self.window_size, self.shift))
341350

342351
def get_data_loader(self, batch_size):
343352
""" Get a DataLoader that can enumerate shuffled batches of data in this dataset """
344353
return DataLoader(self, batch_size=batch_size, shuffle=True, drop_last=True)
345-
354+
346355
def to_long_vector(self):
347356
""" convert the expected labels to a list of integer indexes into the array of keywords """
348357
indexer = [(0 if x == "<null>" else self.keywords.index(x)) for x in self.label_names]
349358
return np.array(indexer, dtype=np.longlong)
350359

360+
def get_keyword_idx(self, non_kw_label):
361+
""" find the keywords and store there index """
362+
indexer = [ids for ids, label in enumerate(self.label_names) if label != non_kw_label]
363+
non_indexer = [ids for ids, label in enumerate(self.label_names) if label == non_kw_label]
364+
return (np.array(indexer, dtype=np.longlong), np.array(non_indexer, dtype=np.longlong))
365+
351366
def __len__(self):
352367
""" Return the number of rows in this Dataset """
353-
return self.num_rows
368+
if self.non_keywords_idx is None:
369+
return self.num_rows
370+
else:
371+
return int(len(self.keywords_idx) / (1-self.sample_non_kw_probability))
354372

355373
def __getitem__(self, idx):
356374
""" Return a single labelled sample here as a tuple """
357-
audio = self.features[idx] # batch index is second dimension
358-
label = self.labels[idx]
375+
if self.non_keywords_idx is None:
376+
updated_idx=idx
377+
else:
378+
if idx < len(self.keywords_idx):
379+
updated_idx=self.keywords_idx[idx]
380+
else:
381+
updated_idx=np.random.choice(self.non_keywords_idx)
382+
audio = self.features[updated_idx] # batch index is second dimension
383+
label = self.labels[updated_idx]
359384
sample = (audio, label)
360385
return sample
361386

387+
362388

363389
def create_model(model_config, input_size, num_keywords):
364390
ModelClass = get_model_class(KeywordSpotter)
@@ -453,7 +479,7 @@ def train(config, evaluate_only=False, outdir=".", detail=False, azureml=False):
453479
log = None
454480
if not evaluate_only:
455481
print("Loading {}...".format(training_file))
456-
training_data = AudioDataset(training_file, config.dataset, keywords)
482+
training_data = AudioDataset(training_file, config.dataset, keywords, training=True)
457483

458484
print("Loading {}...".format(validation_file))
459485
validation_data = AudioDataset(validation_file, config.dataset, keywords)
@@ -556,6 +582,8 @@ def str2bool(v):
556582
parser.add_argument("--rolling", help="Whether to train model in rolling fashion or not", action="store_true")
557583
parser.add_argument("--max_rolling_length", help="Max number of epochs you want to roll the rolling training"
558584
" default is 100", type=int)
585+
parser.add_argument("--sample_non_kw", "-sl", type=str, help="Sample data for this label with probability sample_prob")
586+
parser.add_argument("--sample_non_kw_probability", "-spr", type=float, help="Sample from scl with this probability")
559587

560588
# arguments for fastgrnn
561589
parser.add_argument("--wRank", "-wr", help="Rank of W in 1st layer of FastGRNN default is None", type=int)
@@ -645,6 +673,15 @@ def str2bool(v):
645673
config.dataset.categories = args.categories
646674
if args.dataset:
647675
config.dataset.path = args.dataset
676+
if args.sample_non_kw:
677+
config.dataset.sample_non_kw = args.sample_non_kw
678+
if args.sample_non_kw_probability is None:
679+
config.dataset.sample_non_kw_probability = 0.5
680+
else:
681+
config.dataset.sample_non_kw_probability = args.sample_non_kw_probability
682+
else:
683+
config.dataset.sample_non_kw = None
684+
648685
if args.wRank:
649686
config.model.wRank = args.wRank
650687
if args.uRank:

0 commit comments

Comments
 (0)