
Commit d1de6cc

add speaker reference
1 parent 1219827 commit d1de6cc

File tree

4 files changed (+1753, -29 lines)

Lines changed: 79 additions & 0 deletions
@@ -0,0 +1,79 @@
import argparse
import sys
from pathlib import Path

import numpy as np
import torch
from scipy.io.wavfile import read
from scipy.signal import resample
from transformers import AutoFeatureExtractor, MimiModel
from transformers.models.mimi.modeling_mimi import MimiEncoderOutput


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Generate speaker reference file, used by the llama-tts-csm example",
    )
    parser.add_argument(
        "--model-path", type=Path,
        help="custom Mimi model path (safetensors model); if not specified, the default model from the Hugging Face hub is used",
    )
    parser.add_argument(
        "infile", type=Path,
        help="the wav input file to use for generating the speaker reference file",
        nargs="?",
    )
    # parser.add_argument(
    #     "outfile", type=Path,
    #     help="the output file, defaults to the input file with .codes suffix",
    #     nargs="?",
    # )

    return parser.parse_args()


def main() -> None:
    args = parse_args()

    if args.infile is None:
        raise ValueError("Input file is required")

    if not args.infile.exists():
        raise FileNotFoundError(f"Input file {args.infile} not found")

    # if args.outfile is None:
    #     args.outfile = args.infile.with_suffix(".codes")

    model = MimiModel.from_pretrained(args.model_path or "kyutai/mimi")
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_path or "kyutai/mimi")

    original_sample_rate, audio_data = read(args.infile)

    # If stereo, keep only the first channel
    if audio_data.ndim > 1 and audio_data.shape[1] >= 2:
        audio_data = audio_data[:, 0]

    # Resample to Mimi's expected 24 kHz and normalize to [-1, 1]
    target_sample_rate = 24000
    number_of_samples = round(len(audio_data) * float(target_sample_rate) / original_sample_rate)
    resampled_audio = resample(audio_data, number_of_samples)
    resampled_audio = resampled_audio / max(np.max(np.abs(resampled_audio)), 1e-10)

    # Pre-process the inputs
    audio_sample = np.array(resampled_audio, dtype=float)
    inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
    # Debug output goes to stderr so it does not end up in a redirected codes file
    print("inputs", inputs["input_values"], inputs["input_values"].shape, file=sys.stderr)

    # Encode the waveform into Mimi audio codes (no gradients needed for inference)
    with torch.no_grad():
        encoder_outputs = model.encode(inputs["input_values"])
    assert isinstance(encoder_outputs, MimiEncoderOutput), "encoder_outputs should be of type MimiEncoderOutput"

    # Output: transpose the codes to frame-major order, flatten, and print
    # them as comma-separated integers, 16 per line
    flattened_audio_codes = encoder_outputs.audio_codes.transpose(-1, -2).flatten()
    for i in range(0, len(flattened_audio_codes), 16):
        for code in flattened_audio_codes[i:i + 16].tolist():
            print(f"{code:<5}", end=",")
        print()


if __name__ == "__main__":
    main()
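A hedged usage sketch (the file's path is not visible in this excerpt, so the script name below is hypothetical): the codes are printed to stdout, so they can be redirected into a reference file for the llama-tts-csm example:

    python speaker-reference.py speaker.wav > speaker.codes

Passing --model-path points the script at a local Mimi checkpoint instead of downloading kyutai/mimi from the Hugging Face hub.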
