Skip to content

Commit 4e49f1c

Browse files
authored
Add orpheus (#47)
* Add orpheus * package it * rename to llama and fix inference * add snac * change tokenizer * working audio * remove prompt cache * move to stream generate * update args * add processing time and rtf * format * add tests * format * add snac as dep (temp) * mock tokenizer * change toknizer name * mock toknizer * fix * fix gen
1 parent 9d03df4 commit 4e49f1c

File tree

5 files changed

+576
-15
lines changed

5 files changed

+576
-15
lines changed

mlx_audio/tts/generate.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ def generate_audio(
3434
- text (str): The input text to be converted to speech.
3535
- model (str): The TTS model to use.
3636
- voice (str): The voice style to use.
37+
- temperature (float): The temperature for the model.
3738
- speed (float): Playback speed multiplier.
3839
- lang_code (str): The language code.
3940
- ref_audio (mx.array): Reference audio you would like to clone the voice from.
@@ -182,6 +183,14 @@ def parse_args():
182183
parser.add_argument(
183184
"--temperature", type=float, default=0.7, help="Temperature for the model"
184185
)
186+
parser.add_argument("--top_p", type=float, default=0.9, help="Top-p for the model")
187+
parser.add_argument("--top_k", type=int, default=50, help="Top-k for the model")
188+
parser.add_argument(
189+
"--repetition_penalty",
190+
type=float,
191+
default=1.1,
192+
help="Repetition penalty for the model",
193+
)
185194

186195
args = parser.parse_args()
187196

@@ -198,21 +207,7 @@ def parse_args():
198207
def main():
199208
args = parse_args()
200209

201-
generate_audio(
202-
text=args.text,
203-
model_path=args.model,
204-
voice=args.voice,
205-
speed=args.speed,
206-
lang_code=args.lang_code,
207-
ref_audio=args.ref_audio,
208-
ref_text=args.ref_text,
209-
file_prefix=args.file_prefix,
210-
audio_format=args.audio_format,
211-
sample_rate=args.sample_rate,
212-
join_audio=args.join_audio,
213-
play=args.play,
214-
verbose=args.verbose,
215-
)
210+
generate_audio(model_path=args.model, **vars(args))
216211

217212

218213
if __name__ == "__main__":
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .llama import Model, ModelConfig
2+
3+
__all__ = ["Model", "ModelConfig"]

0 commit comments

Comments
 (0)