Skip to content

Commit 3604494

Browse files
committed
refactored audio perturbation unit tests
Signed-off-by: Swanand Ravindra Kadhe <[email protected]>
1 parent 86ecb02 commit 3604494

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

tests/attacks/poison/test_audio_perturbations.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import logging
2121
import numpy as np
2222
import pytest
23+
import os, sys
2324

2425
from art.attacks.poisoning.perturbations.audio_perturbations import insert_tone_trigger, insert_audio_trigger
2526

@@ -71,37 +72,31 @@ def test_insert_tone_trigger(art_warning):
7172

7273
@pytest.mark.framework_agnostic
7374
def test_insert_audio_trigger(art_warning):
75+
file_path = os.path.join(os.getcwd(), "utils/data/backdoors/cough_trigger.wav")
7476
try:
75-
# TODO
7677
# test single example
77-
audio = insert_audio_trigger(
78-
x=np.zeros(32000), sampling_rate=16000, backdoor_path="/utils/data/backdoors/cough_trigger.wav"
79-
)
78+
audio = insert_audio_trigger(x=np.zeros(32000), sampling_rate=16000, backdoor_path=file_path)
8079
assert audio.shape == (32000,)
8180
assert np.max(audio) != 0
8281

8382
# test single example with differet duration and scale
8483
audio = insert_audio_trigger(
8584
x=np.zeros(32000),
8685
sampling_rate=16000,
87-
backdoor_path="/utils/data/backdoors/cough_trigger.wav",
86+
backdoor_path=file_path,
8887
duration=0.8,
8988
scale=0.5,
9089
)
9190
assert audio.shape == (32000,)
9291
assert np.max(audio) != 0
9392

9493
# test a batch of examples
95-
audio = insert_audio_trigger(
96-
x=np.zeros((10, 16000)), sampling_rate=16000, backdoor_path="/utils/data/backdoors/cough_trigger.wav"
97-
)
94+
audio = insert_audio_trigger(x=np.zeros((10, 16000)), sampling_rate=16000, backdoor_path=file_path)
9895
assert audio.shape == (10, 16000)
9996
assert np.max(audio) != 0
10097

10198
# test single example with shift
102-
audio = insert_audio_trigger(
103-
x=np.zeros(32000), sampling_rate=16000, backdoor_path="/utils/data/backdoors/cough_trigger.wav", shift=10
104-
)
99+
audio = insert_audio_trigger(x=np.zeros(32000), sampling_rate=16000, backdoor_path=file_path, shift=10)
105100
assert audio.shape == (32000,)
106101
assert np.max(audio) != 0
107102
assert np.sum(audio[:10]) == 0
@@ -110,24 +105,22 @@ def test_insert_audio_trigger(art_warning):
110105
audio = insert_audio_trigger(
111106
x=np.zeros((10, 32000)),
112107
sampling_rate=16000,
113-
backdoor_path="/utils/data/backdoors/cough_trigger.wav",
108+
backdoor_path=file_path,
114109
random=True,
115110
)
116111
assert audio.shape == (10, 32000)
117112
assert np.max(audio) != 0
118113

119114
# test when length of backdoor is larger than that of audio signal
120115
with pytest.raises(ValueError):
121-
_ = insert_audio_trigger(
122-
x=np.zeros(15000), sampling_rate=16000, backdoor_path="/utils/data/backdoors/cough_trigger.wav"
123-
)
116+
_ = insert_audio_trigger(x=np.zeros(15000), sampling_rate=16000, backdoor_path=file_path)
124117

125118
# test when shift + backdoor is larger than that of audio signal
126119
with pytest.raises(ValueError):
127120
_ = insert_audio_trigger(
128121
x=np.zeros(16000),
129122
sampling_rate=16000,
130-
backdoor_path="/utils/data/backdoors/cough_trigger.wav",
123+
backdoor_path=file_path,
131124
duration=1,
132125
shift=5,
133126
)

0 commit comments

Comments
 (0)