1+ # Copyright 2020 Mycroft AI Inc.
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+ import random
15+ from os .path import join
16+
17+ import numpy as np
18+
19+ from precise .params import pr
20+ from precise .util import save_audio
21+ from test .scripts .test_utils .temp_folder import TempFolder
22+
23+
24+ class DummyTrainFolder (TempFolder ):
25+ def __init__ (self , root = None ):
26+ super ().__init__ (root )
27+ self .model = self .path ('model.net' )
28+
29+ def generate_samples (self , count , subfolder , name , generator ):
30+ """
31+ Generate sample audio files in a folder
32+
33+ The file is generated in the specified folder, with the specified
34+ name and generated value.
35+
36+ Args:
37+ count: Number of samples to generate
38+ subfolder: String or list of subfolder path
39+ name: Format string used to generate each sample
40+ generator: Function called to get the data for each sample
41+ """
42+ if isinstance (subfolder , str ):
43+ subfolder = [subfolder ]
44+ for i in range (count ):
45+ save_audio (join (self .subdir (* subfolder ), name .format (i )), generator ())
46+
47+ def get_duration (self ):
48+ """Generate a random sample duration"""
49+ return int (random .random () * 2 * pr .buffer_samples )
50+
51+ def generate_default (self , count = 10 ):
52+ self .generate_samples (
53+ count , 'wake-word' , 'ww-{}.wav' ,
54+ lambda : np .ones (self .get_duration (), dtype = float )
55+ )
56+ self .generate_samples (
57+ count , 'not-wake-word' , 'nww-{}.wav' ,
58+ lambda : np .random .random (self .get_duration ()) * 2 - 1
59+ )
60+ self .generate_samples (
61+ count , ('test' , 'wake-word' ), 'ww-{}.wav' ,
62+ lambda : np .ones (self .get_duration (), dtype = float )
63+ )
64+ self .generate_samples (
65+ count , ('test' , 'not-wake-word' ), 'nww-{}.wav' ,
66+ lambda : np .random .random (self .get_duration ()) * 2 - 1
67+ )
68+ self .model = self .path ('model.net' )
0 commit comments