Skip to content

Commit 1b2dc35

Browse files
committed
Support ONNX export
1 parent 379bfca commit 1b2dc35

File tree

6 files changed

+287
-14
lines changed

6 files changed

+287
-14
lines changed

deployment/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
from .base_onnx_module import BaseONNXModule
2+
from .me_onnx_module import MIDIExtractionONNXModule
3+
from .me_quant_onnx_module import QuantizedMIDIExtractionONNXModule
4+
5+
# Registry keyed by the dotted path of a training task class; each value is
# the dotted path of the deployment module class associated with that task.
task_module_mapping = {
    'training.MIDIExtractionTask':
        'deployment.MIDIExtractionONNXModule',
    'training.QuantizedMIDIExtractionTask':
        'deployment.QuantizedMIDIExtractionONNXModule',
}

deployment/base_onnx_module.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import pathlib
2+
from collections import OrderedDict
3+
4+
from librosa.filters import mel
5+
import torch
6+
from torch import nn
7+
8+
from utils import build_object_from_class_name
9+
10+
11+
class BaseONNXModule(nn.Module):
    """Base wrapper that builds a model from a config and loads its checkpoint.

    Subclasses add their own ``forward``; this class only stores the
    configuration, resolves the device and restores the network weights.
    """

    def __init__(self, config: dict, model_path: pathlib.Path, device=None):
        """Store settings, pick a device and build the wrapped model.

        :param config: configuration dict; ``hop_size``, ``audio_sample_rate``
            and ``model_cls`` are read by this class
        :param model_path: path of the checkpoint to restore weights from
        :param device: target device; when ``None``, ``'cuda'`` is used if
            available, otherwise ``'cpu'``
        """
        super().__init__()
        self.config = config
        self.model_path = model_path
        self.device = device if device is not None else (
            'cuda' if torch.cuda.is_available() else 'cpu'
        )
        # Duration of one frame: hop size divided by the sample rate.
        self.timestep = config['hop_size'] / config['audio_sample_rate']
        self.model: torch.nn.Module = self.build_model()

    def build_model(self) -> nn.Module:
        """Instantiate ``config['model_cls']`` and load its weights.

        Only checkpoint entries whose key starts with ``model.`` are kept, with
        that prefix stripped, and they are loaded with ``strict=True``.

        :return: the network in eval mode on ``self.device``
        """
        network: nn.Module = build_object_from_class_name(
            self.config['model_cls'], nn.Module, config=self.config
        )
        network = network.eval().to(self.device)

        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(self.model_path, map_location=self.device)
        prefix_in_ckpt = 'model'
        dotted_prefix = f'{prefix_in_ckpt}.'
        weights = OrderedDict()
        for key, tensor in checkpoint['state_dict'].items():
            if key.startswith(dotted_prefix):
                weights[key[len(dotted_prefix):]] = tensor

        network.load_state_dict(weights, strict=True)
        print(f'| load \'{prefix_in_ckpt}\' from \'{self.model_path}\'.')
        return network
35+
36+
37+
class MelSpectrogram_ONNX(nn.Module):
    """Log-mel spectrogram extractor built from an STFT and a mel filterbank."""

    def __init__(
            self,
            n_mel_channels,
            sampling_rate,
            win_length,
            hop_length,
            n_fft=None,
            mel_fmin=0,
            mel_fmax=None,
            clamp=1e-5
    ):
        """Precompute the mel filterbank and remember the STFT settings.

        :param n_mel_channels: number of mel bands
        :param sampling_rate: audio sample rate passed to the filterbank
        :param win_length: analysis window length
        :param hop_length: hop between successive frames
        :param n_fft: FFT size; falls back to ``win_length`` when ``None``
        :param mel_fmin: lowest filterbank frequency
        :param mel_fmax: highest filterbank frequency
        :param clamp: lower bound applied before taking the logarithm
        """
        super().__init__()
        if n_fft is None:
            n_fft = win_length
        filterbank = mel(
            sr=sampling_rate,
            n_fft=n_fft,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
            htk=True,
        )
        # Registered as a buffer so it follows the module across devices.
        self.register_buffer("mel_basis", torch.from_numpy(filterbank).float())
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sampling_rate = sampling_rate
        self.n_mel_channels = n_mel_channels
        self.clamp = clamp

    def forward(self, audio, center=True):
        """Return the clamped log-mel spectrogram of ``audio``.

        :param audio: waveform tensor handed to ``torch.stft``
        :param center: forwarded to ``torch.stft``
        :return: ``log(clamp(mel_basis @ magnitude, min=clamp))``
        """
        window = torch.hann_window(self.win_length, device=audio.device)
        # return_complex=False yields real/imaginary parts on the last axis.
        fft = torch.stft(
            audio,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            center=center,
            return_complex=False,
        )
        magnitude = torch.sqrt(torch.sum(fft ** 2, dim=-1))
        mel_energy = torch.matmul(self.mel_basis, magnitude)
        return torch.log(torch.clamp(mel_energy, min=self.clamp))

deployment/me_onnx_module.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import pathlib
2+
3+
import torch
4+
5+
from utils.infer_utils import decode_bounds_to_alignment, decode_gaussian_blurred_probs, decode_note_sequence
6+
from .base_onnx_module import BaseONNXModule, MelSpectrogram_ONNX
7+
8+
9+
class MIDIExtractionONNXModule(BaseONNXModule):
    """Waveform-to-notes pipeline: mel extraction, model call and decoding."""

    def __init__(self, config: dict, model_path: pathlib.Path, device=None):
        """Build the base module, the mel extractor and cache decoding settings.

        :param config: configuration dict; besides the keys read by the base
            class, ``units_dim``, ``win_size``, ``fmin``, ``fmax``,
            ``midi_min``, ``midi_max``, ``midi_prob_deviation`` and
            ``rest_threshold`` are read here
        :param model_path: checkpoint path forwarded to the base class
        :param device: device forwarded to the base class
        """
        super().__init__(config, model_path, device=device)
        cfg = self.config
        extractor = MelSpectrogram_ONNX(
            n_mel_channels=cfg['units_dim'],
            sampling_rate=cfg['audio_sample_rate'],
            win_length=cfg['win_size'],
            hop_length=cfg['hop_size'],
            mel_fmin=cfg['fmin'],
            mel_fmax=cfg['fmax'],
        )
        self.mel_extractor = extractor.to(self.device)
        self.rmvpe = None  # not used anywhere in this class
        self.midi_min = cfg['midi_min']
        self.midi_max = cfg['midi_max']
        self.midi_deviation = cfg['midi_prob_deviation']
        self.rest_threshold = cfg['rest_threshold']

    def forward(self, waveform: torch.Tensor):
        """Run the full pipeline on a waveform.

        :param waveform: audio tensor fed to the mel extractor
        :return: tuple ``(note_midi, note_rest, note_dur)`` where the
            durations are frame counts multiplied by ``self.timestep``
        """
        units = self.mel_extractor(waveform).transpose(1, 2)
        # The model is always given an all-zero f0 track and an all-True mask.
        pitch = torch.zeros(units.shape[:2], dtype=torch.float32, device=self.device)
        masks = torch.ones_like(pitch, dtype=torch.bool)

        probs, bounds = self.model(x=units, f0=pitch, mask=masks, sig=True)
        probs *= masks[..., None]
        bounds *= masks

        unit2note_pred = decode_bounds_to_alignment(bounds, use_diff=False) * masks
        midi_pred, rest_pred = decode_gaussian_blurred_probs(
            probs,
            vmin=self.midi_min,
            vmax=self.midi_max,
            deviation=self.midi_deviation,
            threshold=self.rest_threshold,
        )
        voiced = ~rest_pred & masks
        note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence(
            unit2note_pred, midi_pred, voiced
        )
        return note_midi_pred, ~note_mask_pred, note_dur_pred * self.timestep

deployment/me_quant_onnx_module.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
import pathlib
2+
3+
import torch
4+
5+
from utils.infer_utils import decode_bounds_to_alignment, decode_note_sequence
6+
from .base_onnx_module import BaseONNXModule, MelSpectrogram_ONNX
7+
8+
9+
class QuantizedMIDIExtractionONNXModule(BaseONNXModule):
    """Waveform-to-notes pipeline for models that emit quantized MIDI classes."""

    def __init__(self, config: dict, model_path: pathlib.Path, device=None):
        """Build the base module and the mel extractor.

        :param config: configuration dict; besides the keys read by the base
            class, ``units_dim``, ``win_size``, ``fmin`` and ``fmax`` are
            read here
        :param model_path: checkpoint path forwarded to the base class
        :param device: device forwarded to the base class
        """
        super().__init__(config, model_path, device=device)
        self.mel_extractor = MelSpectrogram_ONNX(
            n_mel_channels=self.config['units_dim'], sampling_rate=self.config['audio_sample_rate'],
            win_length=self.config['win_size'], hop_length=self.config['hop_size'],
            mel_fmin=self.config['fmin'], mel_fmax=self.config['fmax']
        ).to(self.device)
        self.rmvpe = None  # not used anywhere in this class

    def forward(self, waveform: torch.Tensor):
        """Run the full pipeline on a waveform.

        :param waveform: audio tensor fed to the mel extractor
        :return: tuple ``(note_midi, note_rest, note_dur)`` where the
            durations are frame counts multiplied by ``self.timestep``
        """
        units = self.mel_extractor(waveform).transpose(1, 2)
        # The model is always given an all-zero f0 track and an all-True mask.
        pitch = torch.zeros(units.shape[:2], dtype=torch.float32, device=self.device)
        masks = torch.ones_like(pitch, dtype=torch.bool)
        probs, bounds = self.model(x=units, f0=pitch, mask=masks, sig=True)
        probs *= masks[..., None]
        bounds *= masks
        # use_diff=False selects the pad/compare formulation instead of
        # torch.diff, matching MIDIExtractionONNXModule.
        # NOTE(review): presumably torch.diff is the op the ONNX exporter
        # cannot handle - confirm against the exporter's supported-op list.
        unit2note_pred = decode_bounds_to_alignment(bounds, use_diff=False) * masks
        midi_pred = probs.argmax(dim=-1)
        # Class index 128 is treated as "rest"; the rest are clipped to 0..127.
        rest_pred = midi_pred == 128
        note_midi_pred, note_dur_pred, note_mask_pred = decode_note_sequence(
            unit2note_pred, midi_pred.clip(min=0, max=127), ~rest_pred & masks
        )
        note_rest_pred = ~note_mask_pred
        return note_midi_pred, note_rest_pred, note_dur_pred * self.timestep

export.py

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
import importlib
2+
import pathlib
3+
from typing import Dict, Tuple, Union
4+
5+
import click
6+
import onnx
7+
import onnxsim
8+
import torch
9+
import yaml
10+
11+
import deployment
12+
from utils.config_utils import print_config
13+
14+
15+
def onnx_override_io_shapes(
        model,  # ModelProto
        input_shapes: Dict[str, Tuple[Union[str, int]]] = None,
        output_shapes: Dict[str, Tuple[Union[str, int]]] = None,
):
    """
    Rewrite the declared shapes of graph inputs/outputs in place.

    Each shape tuple entry is either an ``int`` (stored as a fixed
    ``dim_value``) or any other value, typically a ``str`` (stored as a
    symbolic ``dim_param``). Names missing from the dicts are left untouched.

    :param model: model whose ``graph.input`` / ``graph.output`` are edited
    :param input_shapes: maps input names to their new shape tuples
    :param output_shapes: maps output names to their new shape tuples
    :raises AssertionError: if a given shape has a different number of
        dimensions than the existing one
    """
    def _apply(value_infos, new_shapes):
        for info in value_infos:
            name = info.name
            if name not in new_shapes:
                continue
            target = new_shapes[name]
            dims = info.type.tensor_type.shape.dim
            assert len(target) == len(dims), \
                f'Number of given and existing dimensions mismatch: {name}'
            for axis, size in enumerate(target):
                slot = dims[axis]
                if isinstance(size, int):
                    # Fixed size: clear the symbolic name, then set the value.
                    slot.dim_param = ''
                    slot.dim_value = size
                else:
                    # Symbolic size: reset the value, then set the name.
                    slot.dim_value = 0
                    slot.dim_param = size

    if input_shapes is not None:
        _apply(model.graph.input, input_shapes)
    if output_shapes is not None:
        _apply(model.graph.output, output_shapes)
47+
48+
49+
@click.command(help='Export a trained model to ONNX format')
@click.option('--model', required=True, metavar='CKPT_PATH', help='Path to the model checkpoint (*.ckpt)')
@click.option('--out', required=False, metavar='ONNX_PATH', help='Path to the output model (*.onnx)')
def export(model, out):
    """Export the checkpoint at ``model`` to an ONNX file.

    Reads ``config.yaml`` from the checkpoint's directory, looks up the
    deployment module class for ``config['task_cls']``, exports it with
    ``torch.onnx.export``, overrides the output shapes, runs the ONNX
    simplifier and saves the result.

    :param model: path to the checkpoint file
    :param out: output path; defaults to the checkpoint path with an
        ``.onnx`` suffix
    """
    model_path = pathlib.Path(model)
    with open(model_path.with_name('config.yaml'), 'r', encoding='utf8') as f:
        config = yaml.safe_load(f)
    print_config(config)
    module_cls = deployment.task_module_mapping[config['task_cls']]

    # Split the dotted path into its package and class name, then import it.
    pkg, _, cls_name = module_cls.rpartition('.')
    module_cls = getattr(importlib.import_module(pkg), cls_name)
    assert issubclass(module_cls, deployment.BaseONNXModule), \
        f'Module class {module_cls} is not a subclass of {deployment.BaseONNXModule}.'
    module_ins = module_cls(config=config, model_path=model_path)

    # Dummy input used for tracing; axis 1 is declared dynamic below.
    waveform = torch.randn((1, 114514), dtype=torch.float32, device=module_ins.device)
    out_path = pathlib.Path(out) if out is not None else model_path.with_suffix('.onnx')
    output_names = ['note_midi', 'note_rest', 'note_dur']
    dynamic_axes = {'waveform': {1: 'n_samples'}}
    dynamic_axes.update({name: {1: 'n_notes'} for name in output_names})
    torch.onnx.export(
        module_ins,
        waveform,
        out_path,
        input_names=['waveform'],
        output_names=output_names,
        dynamic_axes=dynamic_axes,
        opset_version=17
    )

    # Reload the exported graph and set each output's declared shape to
    # (1, 'n_notes').
    onnx_model = onnx.load(out_path.as_posix())
    onnx_override_io_shapes(onnx_model, output_shapes={
        name: (1, 'n_notes') for name in output_names
    })
    print('Running ONNX Simplifier...')
    onnx_model, check = onnxsim.simplify(
        onnx_model,
        include_subgraph=True
    )
    assert check, 'Simplified ONNX model could not be validated'
    onnx.save(onnx_model, out_path)
107+
108+
109+
if __name__ == '__main__':
110+
export()

utils/infer_utils.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88

99
def decode_gaussian_blurred_probs(probs, vmin, vmax, deviation, threshold):
10-
num_bins = probs.shape[-1]
10+
num_bins = int(probs.shape[-1])
1111
interval = (vmax - vmin) / (num_bins - 1)
1212
width = int(3 * deviation / interval) # 3 * sigma
1313
idx = torch.arange(num_bins, device=probs.device)[None, None, :] # [1, 1, N]
@@ -24,14 +24,17 @@ def decode_gaussian_blurred_probs(probs, vmin, vmax, deviation, threshold):
2424
return values, rest
2525

2626

27-
def decode_bounds_to_alignment(bounds, use_diff=True):
    """Turn per-frame boundary values into a per-frame item index.

    The values are accumulated along dim 1 and rounded; every frame where
    the rounded total increases starts a new item, and the first frame always
    does when the totals are non-negative. Item indices start at 1.

    :param bounds: 2-D tensor of boundary values, accumulated along dim 1
    :param use_diff: when True, detect increases with ``torch.diff``;
        otherwise compare shifted slices and pad the first frame with True
    :return: long tensor mapping each frame to its item index
    """
    steps = bounds.cumsum(dim=1).round().long()
    if not use_diff:
        rising = steps[:, 1:] > steps[:, :-1]
        starts_item = F.pad(rising, [1, 0], value=True)
    else:
        # Prepending -1 makes the first frame compare against a lower value.
        sentinel = torch.full(
            (bounds.shape[0], 1), fill_value=-1,
            dtype=steps.dtype, device=steps.device
        )
        starts_item = torch.diff(steps, dim=1, prepend=sentinel) > 0
    return starts_item.long().cumsum(dim=1)
3740

@@ -48,25 +51,25 @@ def decode_note_sequence(frame2item, values, masks, threshold=0.5):
4851
b = frame2item.shape[0]
4952
space = frame2item.max() + 1
5053

51-
item_dur = frame2item.new_zeros(b, space).scatter_add(
54+
item_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add(
5255
1, frame2item, torch.ones_like(frame2item)
5356
)[:, 1:]
54-
item_unmasked_dur = frame2item.new_zeros(b, space).scatter_add(
57+
item_unmasked_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add(
5558
1, frame2item, masks.long()
5659
)[:, 1:]
5760
item_masks = item_unmasked_dur / item_dur >= threshold
5861

5962
values_quant = values.round().long()
60-
histogram = frame2item.new_zeros(b, space * 128).scatter_add(
63+
histogram = frame2item.new_zeros(b, space * 128, dtype=frame2item.dtype).scatter_add(
6164
1, frame2item * 128 + values_quant, torch.ones_like(frame2item) * masks
6265
).unflatten(1, [space, 128])[:, 1:, :]
63-
item_values_center = histogram.argmax(dim=2).to(dtype=values.dtype)
66+
item_values_center = histogram.float().argmax(dim=2).to(dtype=values.dtype)
6467
values_center = torch.gather(F.pad(item_values_center, [1, 0]), 1, frame2item)
6568
values_near_center = masks & (values >= values_center - 0.5) & (values <= values_center + 0.5)
66-
item_valid_dur = frame2item.new_zeros(b, space).scatter_add(
69+
item_valid_dur = frame2item.new_zeros(b, space, dtype=frame2item.dtype).scatter_add(
6770
1, frame2item, values_near_center.long()
6871
)[:, 1:]
69-
item_values = values.new_zeros(b, space).scatter_add(
72+
item_values = values.new_zeros(b, space, dtype=values.dtype).scatter_add(
7073
1, frame2item, values * values_near_center
7174
)[:, 1:] / (item_valid_dur + (item_valid_dur == 0))
7275

0 commit comments

Comments
 (0)