1+ import random
2+ from os .path import join
3+
4+ import numpy as np
5+
6+ from precise .params import pr
7+ from precise .util import save_audio
8+ from test .scripts .test_utils .temp_folder import TempFolder
9+
10+
11+ class DummyTrainFolder (TempFolder ):
12+ def __init__ (self , root = None ):
13+ super ().__init__ (root )
14+ self .model = self .path ('model.net' )
15+
16+ def generate_samples (self , count , subfolder , name , generator ):
17+ """
18+ Generate sample audio files in a folder
19+
20+ The file is generated in the specified folder, with the specified
21+ name and generated value.
22+
23+ Args:
24+ count: Number of samples to generate
25+ subfolder: String or list of subfolder path
26+ name: Format string used to generate each sample
27+ generator: Function called to get the data for each sample
28+ """
29+ if isinstance (subfolder , str ):
30+ subfolder = [subfolder ]
31+ for i in range (count ):
32+ save_audio (join (self .subdir (* subfolder ), name .format (i )), generator ())
33+
34+ def get_duration (self ):
35+ """Generate a random sample duration"""
36+ return int (random .random () * 2 * pr .buffer_samples )
37+
38+ def generate_default (self , count = 10 ):
39+ self .generate_samples (
40+ count , 'wake-word' , 'ww-{}.wav' ,
41+ lambda : np .ones (self .get_duration (), dtype = float )
42+ )
43+ self .generate_samples (
44+ count , 'not-wake-word' , 'nww-{}.wav' ,
45+ lambda : np .random .random (self .get_duration ()) * 2 - 1
46+ )
47+ self .generate_samples (
48+ count , ('test' , 'wake-word' ), 'ww-{}.wav' ,
49+ lambda : np .ones (self .get_duration (), dtype = float )
50+ )
51+ self .generate_samples (
52+ count , ('test' , 'not-wake-word' ), 'nww-{}.wav' ,
53+ lambda : np .random .random (self .get_duration ()) * 2 - 1
54+ )
55+ self .model = self .path ('model.net' )
0 commit comments