Skip to content

Commit 2e1b801

Browse files
committed
DRAFT PR to reproduce low SQNR
1 parent 08759d6 commit 2e1b801

File tree

1 file changed

+76
-25
lines changed
  • examples/qualcomm/oss_scripts/moshi

1 file changed

+76
-25
lines changed

examples/qualcomm/oss_scripts/moshi/mimi.py

Lines changed: 76 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -210,22 +210,73 @@ def forward(self, x):
210210

211211

212212
def mimi_decode(
213-
mimi, encode_res_list, pcm_chunk_size, skip_node_id_set, skip_node_op_set
213+
args, mimi, encode_res_list, pcm_chunk_size, skip_node_id_set, skip_node_op_set
214214
) -> torch.Tensor:
215-
class MimiDecode(nn.Module):
216-
def __init__(self, mimi: nn.Module):
217-
super().__init__()
218-
self.mimi_model = mimi
219-
215+
from pathlib import Path
216+
from safetensors.torch import load_model
217+
def _is_safetensors(path: Path | str) -> bool:
218+
return Path(path).suffix in (".safetensors", ".sft", ".sfts")
219+
from moshi.models.compression import MimiModel
220+
from moshi.modules.seanet import SEANetEncoder, SEANetDecoder
221+
from moshi.modules import transformer
222+
from moshi.models.loaders import _seanet_kwargs, _quantizer_kwargs, _transformer_kwargs
223+
from moshi.quantization.vq import SplitResidualVectorQuantizer
224+
225+
class MimiDecode(MimiModel):
220226
def forward(self, x):
221-
return self.mimi_model.decode(x)
227+
return super().decode(x)
222228

223-
mimi_decode_model = MimiDecode(mimi)
224-
decode_inputs, decode_input_list = [], ""
225-
for index, encoder_res in enumerate(encode_res_list):
226-
decode_inputs.append((encoder_res.to(torch.int32),))
227-
decode_input_list += f"input_{index}_0.raw\n"
229+
encoder = SEANetEncoder(**_seanet_kwargs)
230+
decoder = SEANetDecoder(**_seanet_kwargs)
231+
encoder_transformer = transformer.ProjectedTransformer(
232+
device='cpu', **_transformer_kwargs
233+
)
234+
decoder_transformer = transformer.ProjectedTransformer(
235+
device='cpu', **_transformer_kwargs
236+
)
237+
quantizer = SplitResidualVectorQuantizer(
238+
**_quantizer_kwargs,
239+
)
228240

241+
mimi_decode_model = MimiDecode(
242+
encoder,
243+
decoder,
244+
quantizer,
245+
channels=1,
246+
sample_rate=24000,
247+
frame_rate=12.5,
248+
encoder_frame_rate=24000 / encoder.hop_length,
249+
causal=True,
250+
resample_method="conv",
251+
encoder_transformer=encoder_transformer,
252+
decoder_transformer=decoder_transformer,)
253+
mimi_decode_model.eval()
254+
if _is_safetensors(args.mimi_weight):
255+
load_model(mimi_decode_model, args.mimi_weight, strict=False)
256+
257+
decode_inputs, decode_input_list = [], ""
258+
259+
260+
all_codes = []
261+
sample_input = encode_res_list[..., 0 : 1]
262+
with mimi_decode_model.streaming(1):
263+
#---------------------------------------------Works fine below with nn.Module---------------------------------------------
264+
# for i in range(encode_res_list.shape[-1]):
265+
# codes = encode_res_list[..., i : i + 1]
266+
# pcm = mimi_decode_model(codes)
267+
# all_codes.append(pcm)
268+
#---------------------------------------------SQNR drops to 8.5 after export---------------------------------------------
269+
captured_model = torch.export.export(mimi_decode_model, (sample_input,), strict=False).module()
270+
for i in range(encode_res_list.shape[-1]):
271+
codes = encode_res_list[..., i : i + 1]
272+
pcm = captured_model(codes)
273+
all_codes.append(pcm)
274+
275+
276+
277+
cpu_decode_res = torch.cat(all_codes, dim=-1)
278+
return cpu_decode_res
279+
229280
pte_filename = "mimi_decoder_qnn"
230281

231282
quantizer = make_quantizer(
@@ -314,14 +365,14 @@ def export_mimi(mimi, args, max_duration_sec=10.0):
314365

315366
print("streaming encoding...")
316367
cpu_encode_res = mimi.encode(sample_pcm)
317-
htp_encode_res = mimi_encode(
318-
mimi,
319-
encoder_inputs,
320-
encoder_input_list,
321-
pcm_chunk_size,
322-
skip_node_id_set,
323-
skip_node_op_set,
324-
)
368+
# htp_encode_res = mimi_encode(
369+
# mimi,
370+
# encoder_inputs,
371+
# encoder_input_list,
372+
# pcm_chunk_size,
373+
# skip_node_id_set,
374+
# skip_node_op_set,
375+
# )
325376

326377
# Leave it here for now, uncomment this to check htp_encoder with cpu_decoder
327378
# htp_res = torch.cat(htp_encode_res, dim=-1)
@@ -332,19 +383,19 @@ def export_mimi(mimi, args, max_duration_sec=10.0):
332383
cpu_decode_res = mimi.decode(cpu_encode_res)
333384
# TODO: Enable streaming mode, which is the correct way to execute 1 chunk at a time.
334385
# with mimi.streaming(1):
335-
htp_decode_res = mimi_decode(
336-
mimi, htp_encode_res, pcm_chunk_size, skip_node_id_set, skip_node_op_set
386+
cpu_streaming_decode_res = mimi_decode(
387+
args, mimi, cpu_encode_res, pcm_chunk_size, skip_node_id_set, skip_node_op_set
337388
)
338-
compute_scores(cpu_decode_res, htp_decode_res)
389+
compute_scores(cpu_decode_res, cpu_streaming_decode_res)
339390

340391
sphn.write_wav(
341392
f"{args.artifact}/cpu_decode_res.wav",
342393
cpu_decode_res[0, 0].cpu().numpy(),
343394
sample_rate,
344395
)
345396
sphn.write_wav(
346-
f"{args.artifact}/htp_decode_res.wav",
347-
htp_decode_res[0, 0].cpu().numpy(),
397+
f"{args.artifact}/cpu_streaming_decode_res.wav",
398+
cpu_streaming_decode_res[0, 0].cpu().numpy(),
348399
sample_rate,
349400
)
350401

0 commit comments

Comments (0)