Skip to content

Commit 6196bdc

Browse files
Add a series of test cases using the new api
The primary purpose is to ensure all the scripts run end to end rather than completely verifying its functionality
1 parent a37958f commit 6196bdc

File tree

7 files changed

+291
-0
lines changed

7 files changed

+291
-0
lines changed

test/scripts/conftest.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2019 Mycroft AI Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import pytest
16+
17+
from precise.scripts.train import TrainScript
18+
from test.scripts.test_train import DummyTrainFolder
19+
20+
21+
@pytest.fixture()
22+
def train_folder():
23+
folder = DummyTrainFolder(10)
24+
try:
25+
yield folder
26+
finally:
27+
folder.cleanup()
28+
29+
30+
@pytest.fixture()
31+
def train_script(train_folder):
32+
return TrainScript.create(model=train_folder.model, folder=train_folder.root, epochs=1)

test/scripts/dummy_audio_folder.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2019 Mycroft AI Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import atexit
16+
17+
import numpy as np
18+
import os
19+
from os import makedirs
20+
from os.path import isdir, join
21+
from shutil import rmtree
22+
from tempfile import mkdtemp
23+
24+
from precise.params import pr
25+
from precise.util import save_audio
26+
27+
28+
class DummyAudioFolder:
29+
def __init__(self, count=10):
30+
self.count = count
31+
self.root = mkdtemp()
32+
atexit.register(self.cleanup)
33+
34+
def rand(self, min, max):
35+
return min + (max - min) * np.random.random() * pr.buffer_t
36+
37+
def generate_samples(self, folder, name, value, duration):
38+
for i in range(self.count):
39+
save_audio(join(folder, name.format(i)), np.array([value] * int(duration * pr.sample_rate)))
40+
41+
def subdir(self, *parts):
42+
folder = self.path(*parts)
43+
if not isdir(folder):
44+
makedirs(folder)
45+
return folder
46+
47+
def path(self, *path):
48+
return join(self.root, *path)
49+
50+
def count_files(self, folder):
51+
return sum([len(files) for r, d, files in os.walk(folder)])
52+
53+
def cleanup(self):
54+
if isdir(self.root):
55+
rmtree(self.root)

test/scripts/test_add_noise.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2019 Mycroft AI Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from precise.scripts.add_noise import AddNoiseScript
16+
17+
from test.scripts.dummy_audio_folder import DummyAudioFolder
18+
19+
20+
class DummyNoiseFolder(DummyAudioFolder):
21+
def __init__(self, count=10):
22+
super().__init__(count)
23+
self.source = self.subdir('source')
24+
self.noise = self.subdir('noise')
25+
self.output = self.subdir('output')
26+
27+
self.generate_samples(self.subdir('source', 'wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2))
28+
self.generate_samples(self.subdir('source', 'not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2))
29+
self.generate_samples(self.noise, 'noise-{}.wav', 0.5, self.rand(10, 20))
30+
31+
32+
class TestAddNoise:
33+
def get_base_data(self, count):
34+
folders = DummyNoiseFolder(count)
35+
base_args = dict(
36+
folder=folders.source, noise_folder=folders.noise,
37+
output_folder=folders.output
38+
)
39+
return folders, base_args
40+
41+
def test_run_basic(self):
42+
folders, base_args = self.get_base_data(10)
43+
script = AddNoiseScript.create(inflation_factor=1, **base_args)
44+
script.run()
45+
assert folders.count_files(folders.output) == 20
46+
47+
def test_run_basic_2(self):
48+
folders, base_args = self.get_base_data(10)
49+
script = AddNoiseScript.create(inflation_factor=2, **base_args)
50+
script.run()
51+
assert folders.count_files(folders.output) == 40

test/scripts/test_combined.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2019 Mycroft AI Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License
15+
16+
from os.path import isfile
17+
18+
from precise.scripts.calc_threshold import CalcThresholdScript
19+
from precise.scripts.eval import EvalScript
20+
from precise.scripts.graph import GraphScript
21+
22+
23+
def read_content(filename):
24+
with open(filename) as f:
25+
return f.read()
26+
27+
28+
def test_combined(train_folder, train_script):
29+
train_script.run()
30+
params_file = train_folder.model + '.params'
31+
assert isfile(train_folder.model)
32+
assert isfile(params_file)
33+
34+
EvalScript.create(folder=train_folder.root, models=[train_folder.model]).run()
35+
36+
out_file = train_folder.path('outputs.npz')
37+
graph_script = GraphScript.create(folder=train_folder.root, models=[train_folder.model], output_file=out_file)
38+
graph_script.run()
39+
assert isfile(out_file)
40+
41+
params_before = read_content(params_file)
42+
CalcThresholdScript.create(folder=train_folder.root, model=train_folder.model, input_file=out_file).run()
43+
assert params_before != read_content(params_file)

test/scripts/test_convert.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2019 Mycroft AI Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from os.path import isfile
16+
17+
from precise.scripts.convert import ConvertScript
18+
19+
20+
def test_convert(train_folder, train_script):
21+
train_script.run()
22+
23+
ConvertScript.create(model=train_folder.model, out=train_folder.model + '.pb').run()
24+
assert isfile(train_folder.model + '.pb')

test/scripts/test_engine.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2019 Mycroft AI Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import sys
16+
17+
import glob
18+
import re
19+
from os.path import join
20+
21+
from precise.scripts.engine import EngineScript
22+
from runner.precise_runner import ReadWriteStream
23+
24+
25+
class FakeStdin:
26+
def __init__(self, data: bytes):
27+
self.buffer = ReadWriteStream(data)
28+
29+
def isatty(self):
30+
return False
31+
32+
33+
class FakeStdout:
34+
def __init__(self):
35+
self.buffer = ReadWriteStream()
36+
37+
38+
def test_engine(train_folder, train_script):
39+
train_script.run()
40+
with open(glob.glob(join(train_folder.root, 'wake-word', '*.wav'))[0], 'rb') as f:
41+
data = f.read()
42+
try:
43+
sys.stdin = FakeStdin(data)
44+
sys.stdout = FakeStdout()
45+
EngineScript.create(model_name=train_folder.model).run()
46+
assert re.match(rb'[01]\.[0-9]+', sys.stdout.buffer.buffer)
47+
finally:
48+
sys.stdin = sys.__stdin__
49+
sys.stdout = sys.__stdout__

test/scripts/test_train.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#!/usr/bin/env python3
2+
# Copyright 2019 Mycroft AI Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
from os.path import isfile
16+
17+
from precise.params import pr
18+
from precise.scripts.train import TrainScript
19+
from test.scripts.dummy_audio_folder import DummyAudioFolder
20+
21+
22+
class DummyTrainFolder(DummyAudioFolder):
23+
def __init__(self, count=10):
24+
super().__init__(count)
25+
self.generate_samples(self.subdir('wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2 * pr.buffer_t))
26+
self.generate_samples(self.subdir('not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
27+
self.generate_samples(self.subdir('test', 'wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2 * pr.buffer_t))
28+
self.generate_samples(self.subdir('test', 'not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
29+
self.model = self.path('model.net')
30+
31+
32+
class TestTrain:
33+
def test_run_basic(self):
34+
folders = DummyTrainFolder(10)
35+
script = TrainScript.create(model=folders.model, folder=folders.root)
36+
script.run()
37+
assert isfile(folders.model)

0 commit comments

Comments
 (0)