Skip to content

Commit f8eac89

Browse files
authored
Add SignalInfo typedef, and extension module (#718)
This is a part of PRs to add new "sox_io" backend. #726 This PR adds `SignalInfo` structure, which is data exchange interface between Python and C++ in coming TorchScript-based sox IO backend. For the case, where C++ extension is not available (i.e. Windows), this PR also adds dummy class and module that will be substituted. This logic is implemented in `torchaudio.extension` moduel.
1 parent bc1df48 commit f8eac89

File tree

6 files changed

+121
-0
lines changed

6 files changed

+121
-0
lines changed

torchaudio/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
from . import extension
12
from torchaudio._internal import module_utils as _mod_utils
23
from torchaudio import (
34
compliance,

torchaudio/csrc/register.cpp

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#ifndef TORCHAUDIO_REGISTER_H
2+
#define TORCHAUDIO_REGISTER_H
3+
4+
#include <torchaudio/csrc/typedefs.h>
5+
6+
namespace torchaudio {
7+
namespace {
8+
9+
static auto registerSignalInfo =
10+
torch::class_<SignalInfo>("torchaudio", "SignalInfo")
11+
.def(torch::init<int64_t, int64_t, int64_t>())
12+
.def("get_sample_rate", &SignalInfo::getSampleRate)
13+
.def("get_num_channels", &SignalInfo::getNumChannels)
14+
.def("get_num_samples", &SignalInfo::getNumSamples);
15+
16+
} // namespace
17+
} // namespace torchaudio
18+
#endif

torchaudio/csrc/typedefs.cpp

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#include <torchaudio/csrc/typedefs.h>
2+
3+
namespace torchaudio {
4+
SignalInfo::SignalInfo(
5+
const int64_t sample_rate_,
6+
const int64_t num_channels_,
7+
const int64_t num_samples_)
8+
: sample_rate(sample_rate_),
9+
num_channels(num_channels_),
10+
num_samples(num_samples_){};
11+
12+
int64_t SignalInfo::getSampleRate() const {
13+
return sample_rate;
14+
}
15+
16+
int64_t SignalInfo::getNumChannels() const {
17+
return num_channels;
18+
}
19+
20+
int64_t SignalInfo::getNumSamples() const {
21+
return num_samples;
22+
}
23+
} // namespace torchaudio

torchaudio/csrc/typedefs.h

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
#ifndef TORCHAUDIO_TYPDEFS_H
2+
#define TORCHAUDIO_TYPDEFS_H
3+
4+
#include <torch/script.h>
5+
6+
namespace torchaudio {
7+
struct SignalInfo : torch::CustomClassHolder {
8+
int64_t sample_rate;
9+
int64_t num_channels;
10+
int64_t num_samples;
11+
12+
SignalInfo(
13+
const int64_t sample_rate_,
14+
const int64_t num_channels_,
15+
const int64_t num_samples_);
16+
int64_t getSampleRate() const;
17+
int64_t getNumChannels() const;
18+
int64_t getNumSamples() const;
19+
};
20+
21+
} // namespace torchaudio
22+
23+
#endif

torchaudio/extension/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
from .extension import (
2+
_init_extension,
3+
)
4+
5+
_init_extension()
6+
7+
del _init_extension

torchaudio/extension/extension.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import warnings
2+
import importlib
3+
from collections import namedtuple
4+
5+
import torch
6+
from torchaudio._internal import module_utils as _mod_utils
7+
8+
9+
def _init_extension():
10+
ext = 'torchaudio._torchaudio'
11+
if _mod_utils.is_module_available(ext):
12+
_init_script_module(ext)
13+
else:
14+
warnings.warn('torchaudio C++ extension is not available.')
15+
_init_dummy_module()
16+
17+
18+
def _init_script_module(module):
19+
path = importlib.util.find_spec(module).origin
20+
torch.classes.load_library(path)
21+
torch.ops.load_library(path)
22+
23+
24+
def _init_dummy_module():
25+
class SignalInfo:
26+
"""Data class for audio format information
27+
28+
Used when torchaudio C++ extension is not available for annotating
29+
sox_io backend functions so that torchaudio is still importable
30+
without extension.
31+
This class has to implement the same interface as C++ equivalent.
32+
"""
33+
def __init__(self, sample_rate: int, num_channels: int, num_samples: int):
34+
self.sample_rate = sample_rate
35+
self.num_channels = num_channels
36+
self.num_samples = num_samples
37+
38+
def get_sample_rate(self):
39+
return self.sample_rate
40+
41+
def get_num_channels(self):
42+
return self.num_channels
43+
44+
def get_num_samples(self):
45+
return self.num_samples
46+
47+
DummyModule = namedtuple('torchaudio', ['SignalInfo'])
48+
module = DummyModule(SignalInfo)
49+
setattr(torch.classes, 'torchaudio', module)

0 commit comments

Comments
 (0)