@@ -102,11 +102,15 @@ void validate_input_tensor(const torch::Tensor tensor) {
102102 throw std::runtime_error (" Input tensor has to be 2D." );
103103 }
104104
105- const auto dtype = tensor.dtype ();
106- if (!(dtype == torch::kFloat32 || dtype == torch::kInt32 ||
107- dtype == torch::kInt16 || dtype == torch::kUInt8 )) {
108- throw std::runtime_error (
109- " Input tensor has to be one of float32, int32, int16 or uint8 type." );
105+ switch (tensor.dtype ().toScalarType ()) {
106+ case c10::ScalarType::Byte:
107+ case c10::ScalarType::Short:
108+ case c10::ScalarType::Int:
109+ case c10::ScalarType::Float:
110+ break ;
111+ default :
112+ throw std::runtime_error (
113+ " Input tensor has to be one of float32, int32, int16 or uint8 type." );
110114 }
111115}
112116
@@ -209,22 +213,25 @@ namespace {
209213
210214std::tuple<sox_encoding_t , unsigned > get_save_encoding_for_wav (
211215 const std::string format,
212- const caffe2::TypeMeta dtype,
216+ caffe2::TypeMeta dtype,
213217 const Encoding& encoding,
214218 const BitDepth& bits_per_sample) {
215219 switch (encoding) {
216220 case Encoding::NOT_PROVIDED:
217221 switch (bits_per_sample) {
218222 case BitDepth::NOT_PROVIDED:
219- if (dtype == torch::kFloat32 )
220- return std::make_tuple<>(SOX_ENCODING_FLOAT, 32 );
221- if (dtype == torch::kInt32 )
222- return std::make_tuple<>(SOX_ENCODING_SIGN2, 32 );
223- if (dtype == torch::kInt16 )
224- return std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
225- if (dtype == torch::kUInt8 )
226- return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
227- throw std::runtime_error (" Internal Error: Unexpected dtype." );
223+ switch (dtype.toScalarType ()) {
224+ case c10::ScalarType::Float:
225+ return std::make_tuple<>(SOX_ENCODING_FLOAT, 32 );
226+ case c10::ScalarType::Int:
227+ return std::make_tuple<>(SOX_ENCODING_SIGN2, 32 );
228+ case c10::ScalarType::Short:
229+ return std::make_tuple<>(SOX_ENCODING_SIGN2, 16 );
230+ case c10::ScalarType::Byte:
231+ return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
232+ default :
233+ throw std::runtime_error (" Internal Error: Unexpected dtype." );
234+ }
228235 case BitDepth::B8:
229236 return std::make_tuple<>(SOX_ENCODING_UNSIGNED, 8 );
230237 default :
@@ -376,25 +383,26 @@ std::tuple<sox_encoding_t, unsigned> get_save_encoding(
376383 }
377384}
378385
379- unsigned get_precision (
380- const std::string filetype,
381- const caffe2::TypeMeta dtype) {
386+ unsigned get_precision (const std::string filetype, caffe2::TypeMeta dtype) {
382387 if (filetype == " mp3" )
383388 return SOX_UNSPEC;
384389 if (filetype == " flac" )
385390 return 24 ;
386391 if (filetype == " ogg" || filetype == " vorbis" )
387392 return SOX_UNSPEC;
388393 if (filetype == " wav" || filetype == " amb" ) {
389- if (dtype == torch::kUInt8 )
390- return 8 ;
391- if (dtype == torch::kInt16 )
392- return 16 ;
393- if (dtype == torch::kInt32 )
394- return 32 ;
395- if (dtype == torch::kFloat32 )
396- return 32 ;
397- throw std::runtime_error (" Unsupported dtype." );
394+ switch (dtype.toScalarType ()) {
395+ case c10::ScalarType::Byte:
396+ return 8 ;
397+ case c10::ScalarType::Short:
398+ return 16 ;
399+ case c10::ScalarType::Int:
400+ return 32 ;
401+ case c10::ScalarType::Float:
402+ return 32 ;
403+ default :
404+ throw std::runtime_error (" Unsupported dtype." );
405+ }
398406 }
399407 if (filetype == " sph" )
400408 return 32 ;
@@ -419,28 +427,34 @@ sox_signalinfo_t get_signalinfo(
419427 /* length=*/ static_cast <uint64_t >(waveform->numel ())};
420428}
421429
422- sox_encodinginfo_t get_tensor_encodinginfo (const caffe2::TypeMeta dtype) {
430+ sox_encodinginfo_t get_tensor_encodinginfo (caffe2::TypeMeta dtype) {
423431 sox_encoding_t encoding = [&]() {
424- if (dtype == torch::kUInt8 )
425- return SOX_ENCODING_UNSIGNED;
426- if (dtype == torch::kInt16 )
427- return SOX_ENCODING_SIGN2;
428- if (dtype == torch::kInt32 )
429- return SOX_ENCODING_SIGN2;
430- if (dtype == torch::kFloat32 )
431- return SOX_ENCODING_FLOAT;
432- throw std::runtime_error (" Unsupported dtype." );
432+ switch (dtype.toScalarType ()) {
433+ case c10::ScalarType::Byte:
434+ return SOX_ENCODING_UNSIGNED;
435+ case c10::ScalarType::Short:
436+ return SOX_ENCODING_SIGN2;
437+ case c10::ScalarType::Int:
438+ return SOX_ENCODING_SIGN2;
439+ case c10::ScalarType::Float:
440+ return SOX_ENCODING_FLOAT;
441+ default :
442+ throw std::runtime_error (" Unsupported dtype." );
443+ }
433444 }();
434445 unsigned bits_per_sample = [&]() {
435- if (dtype == torch::kUInt8 )
436- return 8 ;
437- if (dtype == torch::kInt16 )
438- return 16 ;
439- if (dtype == torch::kInt32 )
440- return 32 ;
441- if (dtype == torch::kFloat32 )
442- return 32 ;
443- throw std::runtime_error (" Unsupported dtype." );
446+ switch (dtype.toScalarType ()) {
447+ case c10::ScalarType::Byte:
448+ return 8 ;
449+ case c10::ScalarType::Short:
450+ return 16 ;
451+ case c10::ScalarType::Int:
452+ return 32 ;
453+ case c10::ScalarType::Float:
454+ return 32 ;
455+ default :
456+ throw std::runtime_error (" Unsupported dtype." );
457+ }
444458 }();
445459 return sox_encodinginfo_t {
446460 /* encoding=*/ encoding,
0 commit comments