Skip to content

Commit 67fa652

Browse files
authored
Merge pull request #1680 from davidslater/mp3-compression-bug
Mp3Compression defense bug fix - handle case when input is 2D but not an object array
2 parents cb95baf + b3b9cb6 commit 67fa652

File tree

1 file changed

+7
-0
lines changed

1 file changed

+7
-0
lines changed

art/defences/preprocessor/mp3_compression.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,10 @@ def wav_to_mp3(x, sample_rate):
115115
x_mp3 = x_mp3 * 2 ** -15
116116
return x_mp3.astype(x_dtype)
117117

118+
x_orig_type = x.dtype
119+
if x.dtype != object and x.ndim == 2:
120+
x = x.astype(object)
121+
118122
if x.dtype != object and x.ndim != 3:
119123
raise ValueError("Mp3 compression can only be applied to temporal data across at least one channel.")
120124

@@ -146,6 +150,9 @@ def wav_to_mp3(x, sample_rate):
146150
if x.dtype != object and self.channels_first:
147151
x_mp3 = np.swapaxes(x_mp3, 1, 2)
148152

153+
if x_orig_type != object and x.dtype == object and x.ndim == 2:
154+
x_mp3 = x_mp3.astype(x_orig_type)
155+
149156
return x_mp3, y
150157

151158
def _check_params(self) -> None:

0 commit comments

Comments
 (0)