Skip to content

Commit 6e86f2a

Browse files
Make tests more concise/consistent
Fixtures are used consistently A new TempFolder class has been created for reusability A session-wide trained_model fixture has been added to speed up testing
1 parent 1c7cd51 commit 6e86f2a

File tree

10 files changed

+155
-79
lines changed

10 files changed

+155
-79
lines changed

test/scripts/conftest.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,21 +12,51 @@
1212
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
15+
import shutil
16+
1517
import pytest
1618

1719
from precise.scripts.train import TrainScript
18-
from test.scripts.test_train import DummyTrainFolder
20+
from test.scripts.test_utils.temp_folder import TempFolder
21+
from test.scripts.test_utils.dummy_train_folder import DummyTrainFolder
1922

2023

2124
@pytest.fixture()
2225
def train_folder():
23-
folder = DummyTrainFolder(10)
26+
folder = DummyTrainFolder()
27+
folder.generate_default()
2428
try:
2529
yield folder
2630
finally:
2731
folder.cleanup()
2832

2933

3034
@pytest.fixture()
31-
def train_script(train_folder):
32-
return TrainScript.create(model=train_folder.model, folder=train_folder.root, epochs=1)
35+
def temp_folder():
36+
folder = TempFolder()
37+
try:
38+
yield folder
39+
finally:
40+
folder.cleanup()
41+
42+
43+
@pytest.fixture(scope='session')
44+
def _trained_model():
45+
"""Session wide model that gets trained once"""
46+
folder = DummyTrainFolder()
47+
folder.generate_default()
48+
script = TrainScript.create(model=folder.model, folder=folder.root, epochs=100)
49+
script.run()
50+
try:
51+
yield folder.model
52+
finally:
53+
folder.cleanup()
54+
55+
56+
@pytest.fixture()
57+
def trained_model(_trained_model, temp_folder):
58+
"""Copy of session wide model"""
59+
model = temp_folder.path('trained_model.net')
60+
shutil.copy(_trained_model, model)
61+
shutil.copy(_trained_model + '.params', model + '.params')
62+
return model

test/scripts/test_add_noise.py

Lines changed: 5 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -13,25 +13,13 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515
from precise.scripts.add_noise import AddNoiseScript
16-
17-
from test.scripts.dummy_audio_folder import DummyAudioFolder
18-
19-
20-
class DummyNoiseFolder(DummyAudioFolder):
21-
def __init__(self, count=10):
22-
super().__init__(count)
23-
self.source = self.subdir('source')
24-
self.noise = self.subdir('noise')
25-
self.output = self.subdir('output')
26-
27-
self.generate_samples(self.subdir('source', 'wake-word'), 'ww-{}.wav', 1.0, self.rand(0, 2))
28-
self.generate_samples(self.subdir('source', 'not-wake-word'), 'nww-{}.wav', 0.0, self.rand(0, 2))
29-
self.generate_samples(self.noise, 'noise-{}.wav', 0.5, self.rand(10, 20))
16+
from test.scripts.test_utils.dummy_noise_folder import DummyNoiseFolder
3017

3118

3219
class TestAddNoise:
3320
def get_base_data(self, count):
34-
folders = DummyNoiseFolder(count)
21+
folders = DummyNoiseFolder()
22+
folders.generate_default(count)
3523
base_args = dict(
3624
folder=folders.source, noise_folder=folders.noise,
3725
output_folder=folders.output
@@ -42,10 +30,10 @@ def test_run_basic(self):
4230
folders, base_args = self.get_base_data(10)
4331
script = AddNoiseScript.create(inflation_factor=1, **base_args)
4432
script.run()
45-
assert folders.count_files(folders.output) == 20
33+
assert folders.count_files(folders.output) == 40
4634

4735
def test_run_basic_2(self):
4836
folders, base_args = self.get_base_data(10)
4937
script = AddNoiseScript.create(inflation_factor=2, **base_args)
5038
script.run()
51-
assert folders.count_files(folders.output) == 40
39+
assert folders.count_files(folders.output) == 80

test/scripts/test_combined.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,35 +18,35 @@
1818
from precise.scripts.calc_threshold import CalcThresholdScript
1919
from precise.scripts.eval import EvalScript
2020
from precise.scripts.graph import GraphScript
21+
from test.scripts.test_utils.dummy_train_folder import DummyTrainFolder
2122

2223

2324
def read_content(filename):
2425
with open(filename) as f:
2526
return f.read()
2627

2728

28-
def test_combined(train_folder, train_script):
29+
def test_combined(train_folder: DummyTrainFolder, trained_model: str):
2930
"""Test a "normal" development cycle, train, evaluate and calc threshold.
3031
"""
31-
train_script.run()
32-
params_file = train_folder.model + '.params'
33-
assert isfile(train_folder.model)
32+
params_file = trained_model + '.params'
33+
assert isfile(trained_model)
3434
assert isfile(params_file)
3535

3636
EvalScript.create(folder=train_folder.root,
37-
models=[train_folder.model]).run()
37+
models=[trained_model]).run()
3838

3939
# Ensure that the graph script generates a numpy savez file
4040
out_file = train_folder.path('outputs.npz')
4141
graph_script = GraphScript.create(folder=train_folder.root,
42-
models=[train_folder.model],
42+
models=[trained_model],
4343
output_file=out_file)
4444
graph_script.run()
4545
assert isfile(out_file)
4646

4747
# Esure the params are updated after threshold is calculated
4848
params_before = read_content(params_file)
4949
CalcThresholdScript.create(folder=train_folder.root,
50-
model=train_folder.model,
50+
model=trained_model,
5151
input_file=out_file).run()
5252
assert params_before != read_content(params_file)

test/scripts/test_convert.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515
from os.path import isfile
1616

1717
from precise.scripts.convert import ConvertScript
18+
from test.scripts.test_utils.temp_folder import TempFolder
1819

1920

20-
def test_convert(train_folder, train_script):
21-
train_script.run()
22-
23-
ConvertScript.create(model=train_folder.model, out=train_folder.model + '.pb').run()
24-
assert isfile(train_folder.model + '.pb')
21+
def test_convert(temp_folder: TempFolder, trained_model: str):
22+
pb_model = temp_folder.path('model.pb')
23+
ConvertScript.create(model=trained_model, out=pb_model).run()
24+
assert isfile(pb_model)

test/scripts/test_engine.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
from precise.scripts.engine import EngineScript
2222
from runner.precise_runner import ReadWriteStream
2323

24+
from test.scripts.test_utils.dummy_train_folder import DummyTrainFolder
25+
2426

2527
class FakeStdin:
2628
def __init__(self, data: bytes):
@@ -35,18 +37,17 @@ def __init__(self):
3537
self.buffer = ReadWriteStream()
3638

3739

38-
def test_engine(train_folder, train_script):
40+
def test_engine(train_folder: DummyTrainFolder, trained_model: str):
3941
"""
40-
Test t hat the output format of the engina matches a decimal form in the
42+
Test t hat the output format of the engine matches a decimal form in the
4143
range 0.0 - 1.0.
4244
"""
43-
train_script.run()
4445
with open(glob.glob(join(train_folder.root, 'wake-word', '*.wav'))[0], 'rb') as f:
4546
data = f.read()
4647
try:
4748
sys.stdin = FakeStdin(data)
4849
sys.stdout = FakeStdout()
49-
EngineScript.create(model_name=train_folder.model).run()
50+
EngineScript.create(model_name=trained_model).run()
5051
assert re.match(rb'[01]\.[0-9]+', sys.stdout.buffer.buffer)
5152
finally:
5253
sys.stdin = sys.__stdin__

test/scripts/test_train.py

Lines changed: 6 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,29 +14,14 @@
1414
# limitations under the License.
1515
from os.path import isfile
1616

17-
from precise.params import pr
1817
from precise.scripts.train import TrainScript
19-
from test.scripts.dummy_audio_folder import DummyAudioFolder
20-
21-
22-
class DummyTrainFolder(DummyAudioFolder):
23-
def __init__(self, count=10):
24-
super().__init__(count)
25-
self.generate_samples(self.subdir('wake-word'), 'ww-{}.wav', 1.0,
26-
self.rand(0, 2 * pr.buffer_t))
27-
self.generate_samples(self.subdir('not-wake-word'), 'nww-{}.wav', 0.0,
28-
self.rand(0, 2 * pr.buffer_t))
29-
self.generate_samples(self.subdir('test', 'wake-word'), 'ww-{}.wav',
30-
1.0, self.rand(0, 2 * pr.buffer_t))
31-
self.generate_samples(self.subdir('test', 'not-wake-word'),
32-
'nww-{}.wav', 0.0, self.rand(0, 2 * pr.buffer_t))
33-
self.model = self.path('model.net')
18+
from test.scripts.test_utils.dummy_train_folder import DummyTrainFolder
3419

3520

3621
class TestTrain:
37-
def test_run_basic(self):
22+
def test_run_basic(self, train_folder: DummyTrainFolder):
3823
"""Run a training and check that a model is generated."""
39-
folders = DummyTrainFolder(10)
40-
script = TrainScript.create(model=folders.model, folder=folders.root)
41-
script.run()
42-
assert isfile(folders.model)
24+
train_script = TrainScript.create(model=train_folder.model, folder=train_folder.root, epochs=10)
25+
train_script.run()
26+
assert isfile(train_folder.model)
27+
assert isfile(train_folder.model + '.params')
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright 2020 Mycroft AI Inc.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import numpy as np
2+
3+
from test.scripts.test_utils.temp_folder import TempFolder
4+
from test.scripts.test_utils.dummy_train_folder import DummyTrainFolder
5+
6+
7+
class DummyNoiseFolder(TempFolder):
8+
def __init__(self):
9+
super().__init__()
10+
self.source = self.subdir('source')
11+
self.noise = self.subdir('noise')
12+
self.output = self.subdir('output')
13+
14+
self.source_folder = DummyTrainFolder(root=self.source)
15+
self.noise_folder = DummyTrainFolder(root=self.noise)
16+
17+
def generate_default(self, count=10):
18+
self.source_folder.generate_default(count)
19+
self.noise_folder.generate_samples(
20+
count, [], 'noise-{}.wav',
21+
lambda: np.ones(self.noise_folder.get_duration(), dtype=float)
22+
)
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import random
2+
from os.path import join
3+
4+
import numpy as np
5+
6+
from precise.params import pr
7+
from precise.util import save_audio
8+
from test.scripts.test_utils.temp_folder import TempFolder
9+
10+
11+
class DummyTrainFolder(TempFolder):
12+
def __init__(self, root=None):
13+
super().__init__(root)
14+
self.model = self.path('model.net')
15+
16+
def generate_samples(self, count, subfolder, name, generator):
17+
"""
18+
Generate sample audio files in a folder
19+
20+
The file is generated in the specified folder, with the specified
21+
name and generated value.
22+
23+
Args:
24+
count: Number of samples to generate
25+
subfolder: String or list of subfolder path
26+
name: Format string used to generate each sample
27+
generator: Function called to get the data for each sample
28+
"""
29+
if isinstance(subfolder, str):
30+
subfolder = [subfolder]
31+
for i in range(count):
32+
save_audio(join(self.subdir(*subfolder), name.format(i)), generator())
33+
34+
def get_duration(self):
35+
"""Generate a random sample duration"""
36+
return int(random.random() * 2 * pr.buffer_samples)
37+
38+
def generate_default(self, count=10):
39+
self.generate_samples(
40+
count, 'wake-word', 'ww-{}.wav',
41+
lambda: np.ones(self.get_duration(), dtype=float)
42+
)
43+
self.generate_samples(
44+
count, 'not-wake-word', 'nww-{}.wav',
45+
lambda: np.random.random(self.get_duration()) * 2 - 1
46+
)
47+
self.generate_samples(
48+
count, ('test', 'wake-word'), 'ww-{}.wav',
49+
lambda: np.ones(self.get_duration(), dtype=float)
50+
)
51+
self.generate_samples(
52+
count, ('test', 'not-wake-word'), 'nww-{}.wav',
53+
lambda: np.random.random(self.get_duration()) * 2 - 1
54+
)
55+
self.model = self.path('model.net')
Lines changed: 3 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -14,36 +14,18 @@
1414
# limitations under the License.
1515
import atexit
1616

17-
import numpy as np
1817
import os
1918
from os import makedirs
2019
from os.path import isdir, join
2120
from shutil import rmtree
2221
from tempfile import mkdtemp
2322

24-
from precise.params import pr
25-
from precise.util import save_audio
2623

27-
28-
class DummyAudioFolder:
29-
def __init__(self, count=10):
30-
self.count = count
31-
self.root = mkdtemp()
24+
class TempFolder:
25+
def __init__(self, root=None):
26+
self.root = mkdtemp() if root is None else root
3227
atexit.register(self.cleanup)
3328

34-
def rand(self, min, max):
35-
return min + (max - min) * np.random.random() * pr.buffer_t
36-
37-
def generate_samples(self, folder, name, value, duration):
38-
"""Generate sample file.
39-
40-
The file is generated in the specified folder, with the specified name,
41-
dummy value and duration.
42-
"""
43-
for i in range(self.count):
44-
save_audio(join(folder, name.format(i)),
45-
np.array([value] * int(duration * pr.sample_rate)))
46-
4729
def subdir(self, *parts):
4830
folder = self.path(*parts)
4931
if not isdir(folder):

0 commit comments

Comments
 (0)