Skip to content

Commit e5c4de8

Browse files
authored
Replace dtype if-elseif-else with switch (#1270)
1 parent d58ac21 commit e5c4de8

File tree

2 files changed

+87
-65
lines changed

2 files changed

+87
-65
lines changed

torchaudio/csrc/sox/effects_chain.cpp

Lines changed: 27 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -80,29 +80,37 @@ int tensor_input_drain(sox_effect_t* effp, sox_sample_t* obuf, size_t* osamp) {
8080

8181
// Convert to sox_sample_t (int32_t) and write to buffer
8282
SOX_SAMPLE_LOCALS;
83-
const auto dtype = tensor_.dtype();
84-
if (dtype == torch::kFloat32) {
85-
auto ptr = tensor_.data_ptr<float_t>();
86-
for (size_t i = 0; i < *osamp; ++i) {
87-
obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips);
83+
switch (tensor_.dtype().toScalarType()) {
84+
case c10::ScalarType::Float: {
85+
auto ptr = tensor_.data_ptr<float_t>();
86+
for (size_t i = 0; i < *osamp; ++i) {
87+
obuf[i] = SOX_FLOAT_32BIT_TO_SAMPLE(ptr[i], effp->clips);
88+
}
89+
break;
8890
}
89-
} else if (dtype == torch::kInt32) {
90-
auto ptr = tensor_.data_ptr<int32_t>();
91-
for (size_t i = 0; i < *osamp; ++i) {
92-
obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips);
91+
case c10::ScalarType::Int: {
92+
auto ptr = tensor_.data_ptr<int32_t>();
93+
for (size_t i = 0; i < *osamp; ++i) {
94+
obuf[i] = SOX_SIGNED_32BIT_TO_SAMPLE(ptr[i], effp->clips);
95+
}
96+
break;
9397
}
94-
} else if (dtype == torch::kInt16) {
95-
auto ptr = tensor_.data_ptr<int16_t>();
96-
for (size_t i = 0; i < *osamp; ++i) {
97-
obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips);
98+
case c10::ScalarType::Short: {
99+
auto ptr = tensor_.data_ptr<int16_t>();
100+
for (size_t i = 0; i < *osamp; ++i) {
101+
obuf[i] = SOX_SIGNED_16BIT_TO_SAMPLE(ptr[i], effp->clips);
102+
}
103+
break;
98104
}
99-
} else if (dtype == torch::kUInt8) {
100-
auto ptr = tensor_.data_ptr<uint8_t>();
101-
for (size_t i = 0; i < *osamp; ++i) {
102-
obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips);
105+
case c10::ScalarType::Byte: {
106+
auto ptr = tensor_.data_ptr<uint8_t>();
107+
for (size_t i = 0; i < *osamp; ++i) {
108+
obuf[i] = SOX_UNSIGNED_8BIT_TO_SAMPLE(ptr[i], effp->clips);
109+
}
110+
break;
103111
}
104-
} else {
105-
throw std::runtime_error("Unexpected dtype.");
112+
default:
113+
throw std::runtime_error("Unexpected dtype.");
106114
}
107115
priv->index += *osamp;
108116
return (priv->index == num_samples) ? SOX_EOF : SOX_SUCCESS;

torchaudio/csrc/sox/utils.cpp

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,15 @@ void validate_input_tensor(const torch::Tensor tensor) {
102102
throw std::runtime_error("Input tensor has to be 2D.");
103103
}
104104

105-
const auto dtype = tensor.dtype();
106-
if (!(dtype == torch::kFloat32 || dtype == torch::kInt32 ||
107-
dtype == torch::kInt16 || dtype == torch::kUInt8)) {
108-
throw std::runtime_error(
109-
"Input tensor has to be one of float32, int32, int16 or uint8 type.");
105+
switch (tensor.dtype().toScalarType()) {
106+
case c10::ScalarType::Byte:
107+
case c10::ScalarType::Short:
108+
case c10::ScalarType::Int:
109+
case c10::ScalarType::Float:
110+
break;
111+
default:
112+
throw std::runtime_error(
113+
"Input tensor has to be one of float32, int32, int16 or uint8 type.");
110114
}
111115
}
112116

@@ -209,22 +213,25 @@ namespace {
209213

210214
std::tuple<sox_encoding_t, unsigned> get_save_encoding_for_wav(
211215
const std::string format,
212-
const caffe2::TypeMeta dtype,
216+
caffe2::TypeMeta dtype,
213217
const Encoding& encoding,
214218
const BitDepth& bits_per_sample) {
215219
switch (encoding) {
216220
case Encoding::NOT_PROVIDED:
217221
switch (bits_per_sample) {
218222
case BitDepth::NOT_PROVIDED:
219-
if (dtype == torch::kFloat32)
220-
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
221-
if (dtype == torch::kInt32)
222-
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
223-
if (dtype == torch::kInt16)
224-
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
225-
if (dtype == torch::kUInt8)
226-
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
227-
throw std::runtime_error("Internal Error: Unexpected dtype.");
223+
switch (dtype.toScalarType()) {
224+
case c10::ScalarType::Float:
225+
return std::make_tuple<>(SOX_ENCODING_FLOAT, 32);
226+
case c10::ScalarType::Int:
227+
return std::make_tuple<>(SOX_ENCODING_SIGN2, 32);
228+
case c10::ScalarType::Short:
229+
return std::make_tuple<>(SOX_ENCODING_SIGN2, 16);
230+
case c10::ScalarType::Byte:
231+
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
232+
default:
233+
throw std::runtime_error("Internal Error: Unexpected dtype.");
234+
}
228235
case BitDepth::B8:
229236
return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8);
230237
default:
@@ -376,25 +383,26 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
376383
}
377384
}
378385

379-
unsigned get_precision(
380-
const std::string filetype,
381-
const caffe2::TypeMeta dtype) {
386+
unsigned get_precision(const std::string filetype, caffe2::TypeMeta dtype) {
382387
if (filetype == "mp3")
383388
return SOX_UNSPEC;
384389
if (filetype == "flac")
385390
return 24;
386391
if (filetype == "ogg" || filetype == "vorbis")
387392
return SOX_UNSPEC;
388393
if (filetype == "wav" || filetype == "amb") {
389-
if (dtype == torch::kUInt8)
390-
return 8;
391-
if (dtype == torch::kInt16)
392-
return 16;
393-
if (dtype == torch::kInt32)
394-
return 32;
395-
if (dtype == torch::kFloat32)
396-
return 32;
397-
throw std::runtime_error("Unsupported dtype.");
394+
switch (dtype.toScalarType()) {
395+
case c10::ScalarType::Byte:
396+
return 8;
397+
case c10::ScalarType::Short:
398+
return 16;
399+
case c10::ScalarType::Int:
400+
return 32;
401+
case c10::ScalarType::Float:
402+
return 32;
403+
default:
404+
throw std::runtime_error("Unsupported dtype.");
405+
}
398406
}
399407
if (filetype == "sph")
400408
return 32;
@@ -419,28 +427,34 @@ sox_signalinfo_t get_signalinfo(
419427
/*length=*/static_cast<uint64_t>(waveform->numel())};
420428
}
421429

422-
sox_encodinginfo_t get_tensor_encodinginfo(const caffe2::TypeMeta dtype) {
430+
sox_encodinginfo_t get_tensor_encodinginfo(caffe2::TypeMeta dtype) {
423431
sox_encoding_t encoding = [&]() {
424-
if (dtype == torch::kUInt8)
425-
return SOX_ENCODING_UNSIGNED;
426-
if (dtype == torch::kInt16)
427-
return SOX_ENCODING_SIGN2;
428-
if (dtype == torch::kInt32)
429-
return SOX_ENCODING_SIGN2;
430-
if (dtype == torch::kFloat32)
431-
return SOX_ENCODING_FLOAT;
432-
throw std::runtime_error("Unsupported dtype.");
432+
switch (dtype.toScalarType()) {
433+
case c10::ScalarType::Byte:
434+
return SOX_ENCODING_UNSIGNED;
435+
case c10::ScalarType::Short:
436+
return SOX_ENCODING_SIGN2;
437+
case c10::ScalarType::Int:
438+
return SOX_ENCODING_SIGN2;
439+
case c10::ScalarType::Float:
440+
return SOX_ENCODING_FLOAT;
441+
default:
442+
throw std::runtime_error("Unsupported dtype.");
443+
}
433444
}();
434445
unsigned bits_per_sample = [&]() {
435-
if (dtype == torch::kUInt8)
436-
return 8;
437-
if (dtype == torch::kInt16)
438-
return 16;
439-
if (dtype == torch::kInt32)
440-
return 32;
441-
if (dtype == torch::kFloat32)
442-
return 32;
443-
throw std::runtime_error("Unsupported dtype.");
446+
switch (dtype.toScalarType()) {
447+
case c10::ScalarType::Byte:
448+
return 8;
449+
case c10::ScalarType::Short:
450+
return 16;
451+
case c10::ScalarType::Int:
452+
return 32;
453+
case c10::ScalarType::Float:
454+
return 32;
455+
default:
456+
throw std::runtime_error("Unsupported dtype.");
457+
}
444458
}();
445459
return sox_encodinginfo_t{
446460
/*encoding=*/encoding,

0 commit comments

Comments
 (0)