Skip to content

Commit b3b9cb6

Browse files
committed
fix output type
Signed-off-by: David Slater <[email protected]>
1 parent f442424 commit b3b9cb6

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

art/defences/preprocessor/mp3_compression.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -115,6 +115,7 @@ def wav_to_mp3(x, sample_rate):
115115
x_mp3 = x_mp3 * 2 ** -15
116116
return x_mp3.astype(x_dtype)
117117

118+
x_orig_type = x.dtype
118119
if x.dtype != object and x.ndim == 2:
119120
x = x.astype(object)
120121

@@ -149,6 +150,9 @@ def wav_to_mp3(x, sample_rate):
149150
if x.dtype != object and self.channels_first:
150151
x_mp3 = np.swapaxes(x_mp3, 1, 2)
151152

153+
if x_orig_type != object and x.dtype == object and x.ndim == 2:
154+
x_mp3 = x_mp3.astype(x_orig_type)
155+
152156
return x_mp3, y
153157

154158
def _check_params(self) -> None:

0 commit comments

Comments
 (0)