
Commit b5f2a75

Author: Martin Yuan
Export Mimi model to ExecuTorch
1 parent afcec1d commit b5f2a75

2 files changed: 160 additions, 0 deletions
File 1 of 2 (a new shell install script): 15 additions & 0 deletions
@@ -0,0 +1,15 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

set -x

pip install -U moshi
pip install bitsandbytes
# Run the llama example's install_requirements.sh for torchao deps
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )

bash "$SCRIPT_DIR"/../llama/install_requirements.sh
File 2 of 2 (a new Python export script): 145 additions & 0 deletions
@@ -0,0 +1,145 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import random
import time

from huggingface_hub import hf_hub_download
import numpy as np
import sphn
import torch
from torch.profiler import profile, ProfilerActivity

from moshi.models import loaders

import torch.nn as nn

from executorch.examples.models.llama.llama_transformer import Transformer

from executorch.examples.models.llama.model_args import ModelArgs

from torch.export import export, export_for_training, ExportedProgram

from executorch.exir import (
    EdgeCompileConfig,
    ExecutorchBackendConfig,
    to_edge_transform_and_lower,
)
from executorch.backends.xnnpack.partition.xnnpack_partitioner import XnnpackPartitioner

parser = argparse.ArgumentParser()
parser.add_argument("--mimi-weight", type=str)
parser.add_argument("--hf-repo", type=str, default=loaders.DEFAULT_REPO)
parser.add_argument(
    "--device", type=str, default="cuda" if torch.cuda.device_count() else "cpu"
)
parser.add_argument("--profile", action="store_true")
args = parser.parse_args()


def seed_all(seed):
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # for multi-GPU setups
    random.seed(seed)
    np.random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_all(42424242)


print("loading mimi")
if args.mimi_weight is None:
    args.mimi_weight = hf_hub_download(args.hf_repo, loaders.MIMI_NAME)
mimi = loaders.get_mimi(args.mimi_weight, args.device)
print("mimi loaded")
# emb = torch.load('emb.pt')


def mimi_test(mimi, max_duration_sec=10.0):
    pcm_chunk_size = int(mimi.sample_rate / mimi.frame_rate)
    sample_rate = mimi.sample_rate
    # Uncomment below to use real audio instead of the dummy waveform.
    # # wget https://github.com/metavoiceio/metavoice-src/raw/main/assets/bria.mp3
    # sample_pcm, sample_sr = sphn.read("/Users/myuan/src/moshi0/src/moshi/data/bria-24khz.mp3")
    # print("loaded pcm", sample_pcm.shape, sample_sr)
    # sample_pcm = sphn.resample(
    #     sample_pcm, src_sample_rate=sample_sr, dst_sample_rate=sample_rate
    # )
    # sample_pcm = torch.tensor(sample_pcm, device=args.device)
    # max_duration_len = int(sample_rate * max_duration_sec)
    # if sample_pcm.shape[-1] > max_duration_len:
    #     sample_pcm = sample_pcm[..., :max_duration_len]
    # print("resampled pcm", sample_pcm.shape, sample_sr)
    # sample_pcm = sample_pcm[None].to(device=args.device)
    #
    # Dummy input: 240000 samples, i.e. 10 s at Mimi's 24 kHz sample rate.
    sample_pcm = torch.ones(1, 1, 240000)

    print("streaming encoding...")
    start_time = time.time()
    all_codes = []

    def run_loop():
        # Encode the waveform one frame-sized chunk at a time.
        for start_idx in range(0, sample_pcm.shape[-1], pcm_chunk_size):
            end_idx = min(sample_pcm.shape[-1], start_idx + pcm_chunk_size)
            chunk = sample_pcm[..., start_idx:end_idx]
            codes = mimi.encode(chunk)
            if codes.shape[-1]:
                print(start_idx, codes.shape, end="\r")
                all_codes.append(codes)

    if args.profile:
        with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
            run_loop()
        prof.export_chrome_trace("trace.json")
    else:
        run_loop()
    all_codes_th = torch.cat(all_codes, dim=-1)
    print(f"codes {all_codes_th.shape} generated in {time.time() - start_time:.2f}s")
    print("streaming decoding...")
    all_pcms = []
    # with mimi.streaming(1):
    #     for i in range(all_codes_th.shape[-1]):
    #         codes = all_codes_th[..., i : i + 1]
    #         pcm = mimi.decode(codes)
    #         print(i, pcm.shape, end="\r")
    #         all_pcms.append(pcm)
    # all_pcms = torch.cat(all_pcms, dim=-1)
    # print("pcm", all_pcms.shape, all_pcms.dtype)
    # sphn.write_wav("streaming_out.wav", all_pcms[0, 0].cpu().numpy(), sample_rate)
    pcm_ref = mimi.decode(all_codes_th)

    # Wrap the decoder so torch.export sees a plain nn.Module forward().
    class MimiDecode(nn.Module):
        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.decode(x)

    mimi_decode = MimiDecode(mimi)

    # Export the decoder and lower it to the XNNPACK backend.
    ep: ExportedProgram = torch.export.export(mimi_decode, (all_codes_th,), strict=False)
    edge_prog = to_edge_transform_and_lower(
        ep,
        partitioner=[XnnpackPartitioner()],
    )
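
    # A possible next step (not part of this commit): serialize the lowered
    # decode program to a .pte file for the ExecuTorch runtime. This sketch
    # assumes the standard EdgeProgramManager.to_executorch() API; the output
    # filename is arbitrary.
    # exec_prog = edge_prog.to_executorch(ExecutorchBackendConfig())
    # with open("mimi_decode.pte", "wb") as f:
    #     f.write(exec_prog.buffer)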

    # Same wrapper pattern for the encoder.
    class MimiEncode(nn.Module):
        def __init__(self, mimi: nn.Module):
            super().__init__()
            self.mimi_model = mimi

        def forward(self, x):
            return self.mimi_model.encode(x)

    mimi_encode = MimiEncode(mimi)
    chunk = sample_pcm[..., 0:pcm_chunk_size]
    out = mimi_encode(chunk)
    exported_encode = torch.export.export(mimi_encode, (chunk,), strict=False).module()


with torch.no_grad():
    mimi_test(mimi)
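
Note that the encoder side exports the wrapped module but never compares it against eager execution: out is computed from mimi_encode(chunk) and then left unused. A minimal parity check one could append at the end of mimi_test, reusing chunk, out, and exported_encode from the code above (a sketch; it assumes mimi.encode is deterministic outside a streaming context, and tolerances may need loosening):

    torch.testing.assert_close(exported_encode(chunk), out)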
