Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mlx_audio/codec/models/encodec/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .encodec import Encodec
from .encodec import Encodec, EncodecConfig
39 changes: 38 additions & 1 deletion mlx_audio/codec/models/encodec/encodec.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import functools
import json
import math
from dataclasses import dataclass
from pathlib import Path
from types import SimpleNamespace
from typing import List, Optional, Tuple, Union
Expand All @@ -11,6 +12,40 @@
from huggingface_hub import snapshot_download


def filter_dataclass_fields(data_dict, dataclass_type):
    """Drop every entry of ``data_dict`` whose key is not a dataclass field.

    Args:
        data_dict: Mapping of candidate keyword arguments (for example a
            parsed ``config.json``).
        dataclass_type: Dataclass whose declared field names decide which
            keys survive.

    Returns:
        A new dict holding only the recognised key/value pairs, in the
        order they appear in ``data_dict``. The input is not modified.
    """
    # ``__dataclass_fields__`` is keyed by field name, so its keys are the
    # same set of names as the ``name`` attribute of each field object.
    allowed = set(dataclass_type.__dataclass_fields__)
    kept = {}
    for key, value in data_dict.items():
        if key in allowed:
            kept[key] = value
    return kept


@dataclass
class EncodecConfig:
    """Typed container for the settings the Encodec model is built from.

    ``Encodec.from_pretrained`` creates an instance from the model's
    ``config.json`` after discarding keys that are not declared here, and
    the same object is handed to the encoder, decoder and quantizer.

    Every field has a default, so ``EncodecConfig()`` is valid. The fields
    that default to ``None`` carry no built-in value; they are annotated
    ``Optional[...]`` so the annotation matches that default.
    """

    model_type: str = "encodec"
    audio_channels: int = 1
    num_filters: int = 32
    kernel_size: int = 7
    num_residual_layers: int = 1
    dilation_growth_rate: int = 2
    codebook_size: int = 1024
    codebook_dim: int = 128
    hidden_size: int = 128
    num_lstm_layers: int = 2
    residual_kernel_size: int = 3
    use_causal_conv: bool = True
    normalize: bool = False
    pad_mode: str = "reflect"
    norm_type: str = "weight_norm"
    last_kernel_size: int = 7
    trim_right_ratio: float = 1.0
    compress: int = 2
    # No default list is provided for these two; presumably they must be
    # supplied by the loaded config or the caller -- TODO confirm.
    upsampling_ratios: Optional[List[int]] = None
    target_bandwidths: Optional[List[float]] = None
    sampling_rate: int = 24000
    chunk_length_s: Optional[float] = None
    overlap: Optional[float] = None
    architectures: Optional[List[str]] = None


def preprocess_audio(
raw_audio: Union[mx.array, List[mx.array]],
sampling_rate: int = 24000,
Expand Down Expand Up @@ -513,7 +548,7 @@ def decode(self, codes: mx.array) -> mx.array:
class Encodec(nn.Module):
def __init__(self, config):
super().__init__()
self.config = SimpleNamespace(**config)
self.config = config
self.encoder = EncodecEncoder(self.config)
self.decoder = EncodecDecoder(self.config)
self.quantizer = EncodecResidualVectorQuantizer(self.config)
Expand Down Expand Up @@ -689,6 +724,8 @@ def from_pretrained(cls, path_or_repo: str):
with open(path / "config.json", "r") as f:
config = json.load(f)

filtered_config = filter_dataclass_fields(config, EncodecConfig)
config = EncodecConfig(**filtered_config)
model = cls(config)
model.load_weights(str(path / "model.safetensors"))
processor = functools.partial(
Expand Down
54 changes: 27 additions & 27 deletions mlx_audio/codec/tests/test_encodec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,33 +2,33 @@

import mlx.core as mx

from ..models.encodec import Encodec

config = {
"audio_channels": 1,
"chunk_length_s": None,
"codebook_dim": 128,
"codebook_size": 1024,
"compress": 2,
"dilation_growth_rate": 2,
"hidden_size": 128,
"kernel_size": 7,
"last_kernel_size": 7,
"model_type": "encodec",
"norm_type": "weight_norm",
"normalize": False,
"num_filters": 32,
"num_lstm_layers": 2,
"num_residual_layers": 1,
"overlap": None,
"pad_mode": "reflect",
"residual_kernel_size": 3,
"sampling_rate": 24000,
"target_bandwidths": [1.5, 3.0, 6.0, 12.0, 24.0],
"trim_right_ratio": 1.0,
"upsampling_ratios": [8, 5, 4, 2],
"use_causal_conv": True,
}
from ..models.encodec import Encodec, EncodecConfig

# Shared test configuration. Keyword arguments are listed in the order the
# fields are declared on EncodecConfig so the two are easy to compare.
config = EncodecConfig(
    model_type="encodec",
    audio_channels=1,
    # Convolution stack
    num_filters=32,
    kernel_size=7,
    num_residual_layers=1,
    dilation_growth_rate=2,
    # Quantizer
    codebook_size=1024,
    codebook_dim=128,
    hidden_size=128,
    num_lstm_layers=2,
    residual_kernel_size=3,
    use_causal_conv=True,
    normalize=False,
    pad_mode="reflect",
    norm_type="weight_norm",
    last_kernel_size=7,
    trim_right_ratio=1.0,
    compress=2,
    upsampling_ratios=[8, 5, 4, 2],
    target_bandwidths=[1.5, 3.0, 6.0, 12.0, 24.0],
    # Audio / chunking
    sampling_rate=24000,
    chunk_length_s=None,
    overlap=None,
)


class TesEncodec(unittest.TestCase):
Expand Down
9 changes: 7 additions & 2 deletions mlx_audio/tts/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,8 @@ def generate_audio(
join_audio: bool = False,
play: bool = False,
verbose: bool = True,
from_cli: bool = False,
temperature: float = 0.7,
**kwargs,
) -> None:
"""
Generates audio from text using a specified TTS model.
Expand Down Expand Up @@ -85,6 +86,7 @@ def generate_audio(
lang_code=lang_code,
ref_audio=ref_audio,
ref_text=ref_text,
temperature=temperature,
verbose=True,
)

Expand Down Expand Up @@ -154,7 +156,7 @@ def parse_args():
default=None,
help="Text to generate (leave blank to input via stdin)",
)
parser.add_argument("--voice", type=str, default="af_heart", help="Voice name")
parser.add_argument("--voice", type=str, default=None, help="Voice name")
parser.add_argument("--speed", type=float, default=1.0, help="Speed of the audio")
parser.add_argument("--lang_code", type=str, default="a", help="Language code")
parser.add_argument(
Expand All @@ -177,6 +179,9 @@ def parse_args():
parser.add_argument(
"--ref_text", type=str, default=None, help="Caption for reference audio"
)
parser.add_argument(
"--temperature", type=float, default=0.7, help="Temperature for the model"
)

args = parser.parse_args()

Expand Down
4 changes: 4 additions & 0 deletions mlx_audio/tts/models/bark/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from .bark import Model, ModelConfig
from .pipeline import Pipeline

__all__ = ["Model", "Pipeline", "ModelConfig"]
Loading