Skip to content

Commit 6731348

Browse files
authored
Add Suno bark (#45)
* add bark * update bark * fix pipeline and add config dataclass * fix audio * fix codec * update voices for suno * add temperature * add voice * add tests * bump version * format * remove lm_heads * migrate to mlx encodec * format * fix tests and voice * add todo * format * fix test
1 parent 79f0c1a commit 6731348

File tree

15 files changed

+1287
-41
lines changed

15 files changed

+1287
-41
lines changed
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .encodec import Encodec
1+
from .encodec import Encodec, EncodecConfig

mlx_audio/codec/models/encodec/encodec.py

Lines changed: 38 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import functools
22
import json
33
import math
4+
from dataclasses import dataclass
45
from pathlib import Path
56
from types import SimpleNamespace
67
from typing import List, Optional, Tuple, Union
@@ -11,6 +12,40 @@
1112
from huggingface_hub import snapshot_download
1213

1314

15+
def filter_dataclass_fields(data_dict, dataclass_type):
16+
"""Filter a dictionary to only include keys that are fields in the dataclass."""
17+
valid_fields = {f.name for f in dataclass_type.__dataclass_fields__.values()}
18+
return {k: v for k, v in data_dict.items() if k in valid_fields}
19+
20+
21+
@dataclass
22+
class EncodecConfig:
23+
model_type: str = "encodec"
24+
audio_channels: int = 1
25+
num_filters: int = 32
26+
kernel_size: int = 7
27+
num_residual_layers: int = 1
28+
dilation_growth_rate: int = 2
29+
codebook_size: int = 1024
30+
codebook_dim: int = 128
31+
hidden_size: int = 128
32+
num_lstm_layers: int = 2
33+
residual_kernel_size: int = 3
34+
use_causal_conv: bool = True
35+
normalize: bool = False
36+
pad_mode: str = "reflect"
37+
norm_type: str = "weight_norm"
38+
last_kernel_size: int = 7
39+
trim_right_ratio: float = 1.0
40+
compress: int = 2
41+
upsampling_ratios: List[int] = None
42+
target_bandwidths: List[float] = None
43+
sampling_rate: int = 24000
44+
chunk_length_s: Optional[float] = None
45+
overlap: Optional[float] = None
46+
architectures: List[str] = None
47+
48+
1449
def preprocess_audio(
1550
raw_audio: Union[mx.array, List[mx.array]],
1651
sampling_rate: int = 24000,
@@ -513,7 +548,7 @@ def decode(self, codes: mx.array) -> mx.array:
513548
class Encodec(nn.Module):
514549
def __init__(self, config):
515550
super().__init__()
516-
self.config = SimpleNamespace(**config)
551+
self.config = config
517552
self.encoder = EncodecEncoder(self.config)
518553
self.decoder = EncodecDecoder(self.config)
519554
self.quantizer = EncodecResidualVectorQuantizer(self.config)
@@ -689,6 +724,8 @@ def from_pretrained(cls, path_or_repo: str):
689724
with open(path / "config.json", "r") as f:
690725
config = json.load(f)
691726

727+
filtered_config = filter_dataclass_fields(config, EncodecConfig)
728+
config = EncodecConfig(**filtered_config)
692729
model = cls(config)
693730
model.load_weights(str(path / "model.safetensors"))
694731
processor = functools.partial(

mlx_audio/codec/tests/test_encodec.py

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -2,33 +2,33 @@
22

33
import mlx.core as mx
44

5-
from ..models.encodec import Encodec
6-
7-
config = {
8-
"audio_channels": 1,
9-
"chunk_length_s": None,
10-
"codebook_dim": 128,
11-
"codebook_size": 1024,
12-
"compress": 2,
13-
"dilation_growth_rate": 2,
14-
"hidden_size": 128,
15-
"kernel_size": 7,
16-
"last_kernel_size": 7,
17-
"model_type": "encodec",
18-
"norm_type": "weight_norm",
19-
"normalize": False,
20-
"num_filters": 32,
21-
"num_lstm_layers": 2,
22-
"num_residual_layers": 1,
23-
"overlap": None,
24-
"pad_mode": "reflect",
25-
"residual_kernel_size": 3,
26-
"sampling_rate": 24000,
27-
"target_bandwidths": [1.5, 3.0, 6.0, 12.0, 24.0],
28-
"trim_right_ratio": 1.0,
29-
"upsampling_ratios": [8, 5, 4, 2],
30-
"use_causal_conv": True,
31-
}
5+
from ..models.encodec import Encodec, EncodecConfig
6+
7+
config = EncodecConfig(
8+
audio_channels=1,
9+
chunk_length_s=None,
10+
codebook_dim=128,
11+
codebook_size=1024,
12+
compress=2,
13+
dilation_growth_rate=2,
14+
hidden_size=128,
15+
kernel_size=7,
16+
last_kernel_size=7,
17+
model_type="encodec",
18+
norm_type="weight_norm",
19+
normalize=False,
20+
num_filters=32,
21+
num_lstm_layers=2,
22+
num_residual_layers=1,
23+
overlap=None,
24+
pad_mode="reflect",
25+
residual_kernel_size=3,
26+
sampling_rate=24000,
27+
target_bandwidths=[1.5, 3.0, 6.0, 12.0, 24.0],
28+
trim_right_ratio=1.0,
29+
upsampling_ratios=[8, 5, 4, 2],
30+
use_causal_conv=True,
31+
)
3232

3333

3434
class TesEncodec(unittest.TestCase):

mlx_audio/tts/generate.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@ def generate_audio(
2424
join_audio: bool = False,
2525
play: bool = False,
2626
verbose: bool = True,
27-
from_cli: bool = False,
27+
temperature: float = 0.7,
28+
**kwargs,
2829
) -> None:
2930
"""
3031
Generates audio from text using a specified TTS model.
@@ -85,6 +86,7 @@ def generate_audio(
8586
lang_code=lang_code,
8687
ref_audio=ref_audio,
8788
ref_text=ref_text,
89+
temperature=temperature,
8890
verbose=True,
8991
)
9092

@@ -154,7 +156,7 @@ def parse_args():
154156
default=None,
155157
help="Text to generate (leave blank to input via stdin)",
156158
)
157-
parser.add_argument("--voice", type=str, default="af_heart", help="Voice name")
159+
parser.add_argument("--voice", type=str, default=None, help="Voice name")
158160
parser.add_argument("--speed", type=float, default=1.0, help="Speed of the audio")
159161
parser.add_argument("--lang_code", type=str, default="a", help="Language code")
160162
parser.add_argument(
@@ -177,6 +179,9 @@ def parse_args():
177179
parser.add_argument(
178180
"--ref_text", type=str, default=None, help="Caption for reference audio"
179181
)
182+
parser.add_argument(
183+
"--temperature", type=float, default=0.7, help="Temperature for the model"
184+
)
180185

181186
args = parser.parse_args()
182187

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from .bark import Model, ModelConfig
2+
from .pipeline import Pipeline
3+
4+
__all__ = ["Model", "Pipeline", "ModelConfig"]

0 commit comments

Comments
 (0)