Skip to content

Commit 5784b81

Browse files
Added HTK format support to sox_io's save & info (#1308)
* Cherry pick 3488f31 from #1276 * Cherry pick d2861fc from #1291
1 parent b423dec commit 5784b81

File tree

7 files changed

+52
-3
lines changed

7 files changed

+52
-3
lines changed

test/torchaudio_unittest/backend/sox_io/info_test.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ def test_ulaw(self):
205205
assert info.encoding == "ULAW"
206206

207207
def test_alaw(self):
208-
"""`sox_io_backend.info` can check ulaw file correctly"""
208+
"""`sox_io_backend.info` can check alaw file correctly"""
209209
duration = 1
210210
num_channels = 1
211211
sample_rate = 8000
@@ -221,6 +221,22 @@ def test_alaw(self):
221221
assert info.bits_per_sample == 8
222222
assert info.encoding == "ALAW"
223223

224+
def test_htk(self):
225+
"""`sox_io_backend.info` can check HTK file correctly"""
226+
duration = 1
227+
num_channels = 1
228+
sample_rate = 8000
229+
path = self.get_temp_path('data.htk')
230+
sox_utils.gen_audio_file(
231+
path, sample_rate=sample_rate, num_channels=num_channels,
232+
bit_depth=16, duration=duration)
233+
info = sox_io_backend.info(path)
234+
assert info.sample_rate == sample_rate
235+
assert info.num_frames == sample_rate * duration
236+
assert info.num_channels == num_channels
237+
assert info.bits_per_sample == 16
238+
assert info.encoding == "PCM_S"
239+
224240
def test_gsm(self):
225241
"""`sox_io_backend.info` can check gsm file correctly"""
226242
duration = 1

test/torchaudio_unittest/backend/sox_io/save_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,6 +317,12 @@ def test_save_gsm(self, test_mode):
317317
self.assert_save_consistency(
318318
"gsm", test_mode=test_mode)
319319

320+
@nested_params(
321+
["path", "fileobj", "bytesio"],
322+
)
323+
def test_save_htk(self, test_mode):
324+
self.assert_save_consistency("htk", test_mode=test_mode, num_channels=1)
325+
320326
@parameterized.expand([
321327
("wav", "PCM_S", 16),
322328
("mp3", ),

torchaudio/backend/sox_io_backend.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -195,7 +195,7 @@ def save(
195195
When ``filepath`` argument is file-like object, this argument is required.
196196
197197
Valid values are ``"wav"``, ``"mp3"``, ``"ogg"``, ``"vorbis"``, ``"amr-nb"``,
198-
``"amb"``, ``"flac"``, ``"sph"`` and ``"gsm"``.
198+
``"amb"``, ``"flac"``, ``"sph"``, ``"gsm"``, and ``"htk"``.
199199
encoding (str, optional): Changes the encoding for the supported formats.
200200
This argument is effective only for supported formats, such as ``"wav"``, ``"amb"``
201201
and ``"sph"``. Valid values are:
@@ -294,6 +294,9 @@ def save(
294294
``"gsm"``
295295
Lossy Speech Compression, CPU intensive.
296296
297+
``"htk"``
298+
Uses its default single-channel 16-bit PCM format.
299+
297300
Note:
298301
To save into formats that ``libsox`` does not handle natively, (such as ``"mp3"``,
299302
``"flac"``, ``"ogg"`` and ``"vorbis"``), your installation of ``torchaudio`` has

torchaudio/csrc/sox/io.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,10 @@ void save_audio_file(
9797
const auto num_channels = tensor.size(channels_first ? 0 : 1);
9898
TORCH_CHECK(
9999
num_channels == 1, "amr-nb format only supports single channel audio.");
100+
} else if (filetype == "htk") {
101+
const auto num_channels = tensor.size(channels_first ? 0 : 1);
102+
TORCH_CHECK(
103+
num_channels == 1, "htk format only supports single channel audio.");
100104
}
101105
const auto signal_info =
102106
get_signalinfo(&tensor, sample_rate, filetype, channels_first);
@@ -233,6 +237,12 @@ void save_audio_fileobj(
233237
throw std::runtime_error(
234238
"amr-nb format only supports single channel audio.");
235239
}
240+
} else if (filetype == "htk") {
241+
const auto num_channels = tensor.size(channels_first ? 0 : 1);
242+
if (num_channels != 1) {
243+
throw std::runtime_error(
244+
"htk format only supports single channel audio.");
245+
}
236246
}
237247
const auto signal_info =
238248
get_signalinfo(&tensor, sample_rate, filetype, channels_first);

torchaudio/csrc/sox/types.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@ Format get_format_from_string(const std::string& format) {
2020
return Format::AMB;
2121
if (format == "sph")
2222
return Format::SPHERE;
23+
if (format == "htk")
24+
return Format::HTK;
2325
if (format == "gsm")
2426
return Format::GSM;
2527
std::ostringstream stream;

torchaudio/csrc/sox/types.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ enum class Format {
1717
AMB,
1818
SPHERE,
1919
GSM,
20+
HTK,
2021
};
2122

2223
Format get_format_from_string(const std::string& format);

torchaudio/csrc/sox/utils.cpp

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,13 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
314314
throw std::runtime_error(
315315
"mp3 does not support `bits_per_sample` option.");
316316
return std::make_tuple<>(SOX_ENCODING_MP3, 16);
317+
case Format::HTK:
318+
if (enc != Encoding::NOT_PROVIDED)
319+
throw std::runtime_error("htk does not support `encoding` option.");
320+
if (bps != BitDepth::NOT_PROVIDED)
321+
throw std::runtime_error(
322+
"htk does not support `bits_per_sample` option.");
323+
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
317324
case Format::VORBIS:
318325
if (enc != Encoding::NOT_PROVIDED)
319326
throw std::runtime_error("vorbis does not support `encoding` option.");
@@ -417,8 +424,12 @@ unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
417424
if (filetype == "amr-nb") {
418425
return 16;
419426
}
420-
if (filetype == "gsm")
427+
if (filetype == "gsm") {
421428
return 16;
429+
}
430+
if (filetype == "htk") {
431+
return 16;
432+
}
422433
throw std::runtime_error("Unsupported file type: " + filetype);
423434
}
424435

0 commit comments

Comments
 (0)