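"""Generate a speaker reference file, used by the llama-tts-csm example.

Reads a wav file, downmixes it to mono, resamples it to the 24 kHz rate
expected by the Mimi codec, encodes it, and prints the resulting audio
codes to stdout.
"""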
import argparse
from pathlib import Path

import numpy as np
import torch
from scipy.io.wavfile import read
from scipy.signal import resample
from transformers import MimiModel, AutoFeatureExtractor
from transformers.models.mimi.modeling_mimi import MimiEncoderOutput


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Generate speaker reference file, used by llama-tts-csm example",
    )
    parser.add_argument(
        "--model-path", type=Path,
        help="custom Mimi model path (safetensors model). If not specified, will use the default model from the Hugging Face hub",
    )
    parser.add_argument(
        "infile", type=Path,
        help="the wav input file to use for generating the speaker reference file",
        nargs="?",
    )
    # parser.add_argument(
    #     "outfile", type=Path,
    #     help="the output file, defaults to the input file with .codes suffix",
    #     nargs="?",
    # )

    return parser.parse_args()


def main() -> None:
    args = parse_args()

    if args.infile is None:
        raise ValueError("Input file is required")

    if not args.infile.exists():
        raise FileNotFoundError(f"Input file {args.infile} not found")

    # if args.outfile is None:
    #     args.outfile = args.infile.with_suffix(".codes")

    # load the Mimi codec and its feature extractor (defaults to the model on the Hugging Face hub)
    model = MimiModel.from_pretrained(args.model_path or "kyutai/mimi")
    feature_extractor = AutoFeatureExtractor.from_pretrained(args.model_path or "kyutai/mimi")

    # scipy.io.wavfile.read returns a (sample_rate, data) tuple
    original_sample_rate, audio_data = read(args.infile)

    # If stereo (or multi-channel), keep only the first channel
    if len(audio_data.shape) > 1 and audio_data.shape[1] >= 2:
        audio_data = audio_data[:, 0]

    # resample to the 24 kHz rate expected by Mimi, then normalize to [-1, 1]
    target_sample_rate = 24000
    number_of_samples = round(len(audio_data) * float(target_sample_rate) / original_sample_rate)
    resampled_audio = resample(audio_data, number_of_samples)
    resampled_audio = resampled_audio / max(np.max(np.abs(resampled_audio)), 1e-10)

    # pre-process the inputs (feature_extractor.sampling_rate is 24 kHz for Mimi)
    audio_sample = np.array(resampled_audio, dtype=float)
    inputs = feature_extractor(raw_audio=audio_sample, sampling_rate=feature_extractor.sampling_rate, return_tensors="pt")
    print("inputs", inputs["input_values"], inputs["input_values"].shape)

    # encode; gradients are not needed for inference
    with torch.no_grad():
        encoder_outputs = model.encode(inputs["input_values"])
    assert isinstance(encoder_outputs, MimiEncoderOutput), "encoder_outputs should be of type MimiEncoderOutput"

    # output: audio_codes has shape (batch, num_quantizers, frames); transposing
    # the last two dimensions before flattening interleaves the codes frame by
    # frame, printed 16 codes per line
    flattened_audio_codes = encoder_outputs.audio_codes.transpose(-1, -2).flatten()
    for i in range(0, len(flattened_audio_codes), 16):
        for code in flattened_audio_codes[i:i+16].tolist():
            print(f"{code:<5}", end=",")
        print()


if __name__ == "__main__":
    main()
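
# Example invocation (the file names below are illustrative), redirecting
# stdout to capture the printed codes:
#   python generate_speaker_codes.py speaker.wav > speaker.codes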