@@ -210,22 +210,73 @@ def forward(self, x):
210210
211211
212212def mimi_decode (
213- mimi , encode_res_list , pcm_chunk_size , skip_node_id_set , skip_node_op_set
213+ args , mimi , encode_res_list , pcm_chunk_size , skip_node_id_set , skip_node_op_set
214214) -> torch .Tensor :
215- class MimiDecode (nn .Module ):
216- def __init__ (self , mimi : nn .Module ):
217- super ().__init__ ()
218- self .mimi_model = mimi
219-
215+ from pathlib import Path
216+ from safetensors .torch import load_model
217+ def _is_safetensors (path : Path | str ) -> bool :
218+ return Path (path ).suffix in (".safetensors" , ".sft" , ".sfts" )
219+ from moshi .models .compression import MimiModel
220+ from moshi .modules .seanet import SEANetEncoder , SEANetDecoder
221+ from moshi .modules import transformer
222+ from moshi .models .loaders import _seanet_kwargs , _quantizer_kwargs , _transformer_kwargs
223+ from moshi .quantization .vq import SplitResidualVectorQuantizer
224+
225+ class MimiDecode (MimiModel ):
220226 def forward (self , x ):
221- return self . mimi_model .decode (x )
227+ return super () .decode (x )
222228
223- mimi_decode_model = MimiDecode (mimi )
224- decode_inputs , decode_input_list = [], ""
225- for index , encoder_res in enumerate (encode_res_list ):
226- decode_inputs .append ((encoder_res .to (torch .int32 ),))
227- decode_input_list += f"input_{ index } _0.raw\n "
229+ encoder = SEANetEncoder (** _seanet_kwargs )
230+ decoder = SEANetDecoder (** _seanet_kwargs )
231+ encoder_transformer = transformer .ProjectedTransformer (
232+ device = 'cpu' , ** _transformer_kwargs
233+ )
234+ decoder_transformer = transformer .ProjectedTransformer (
235+ device = 'cpu' , ** _transformer_kwargs
236+ )
237+ quantizer = SplitResidualVectorQuantizer (
238+ ** _quantizer_kwargs ,
239+ )
228240
241+ mimi_decode_model = MimiDecode (
242+ encoder ,
243+ decoder ,
244+ quantizer ,
245+ channels = 1 ,
246+ sample_rate = 24000 ,
247+ frame_rate = 12.5 ,
248+ encoder_frame_rate = 24000 / encoder .hop_length ,
249+ causal = True ,
250+ resample_method = "conv" ,
251+ encoder_transformer = encoder_transformer ,
252+ decoder_transformer = decoder_transformer ,)
253+ mimi_decode_model .eval ()
254+ if _is_safetensors (args .mimi_weight ):
255+ load_model (mimi_decode_model , args .mimi_weight , strict = False )
256+
257+ decode_inputs , decode_input_list = [], ""
258+
259+
260+ all_codes = []
261+ sample_input = encode_res_list [..., 0 : 1 ]
262+ with mimi_decode_model .streaming (1 ):
263+ #---------------------------------------------Works fine below with nn.Module---------------------------------------------
264+ # for i in range(encode_res_list.shape[-1]):
265+ # codes = encode_res_list[..., i : i + 1]
266+ # pcm = mimi_decode_model(codes)
267+ # all_codes.append(pcm)
268+ #---------------------------------------------SQNR drops to 8.5 after export---------------------------------------------
269+ captured_model = torch .export .export (mimi_decode_model , (sample_input ,), strict = False ).module ()
270+ for i in range (encode_res_list .shape [- 1 ]):
271+ codes = encode_res_list [..., i : i + 1 ]
272+ pcm = captured_model (codes )
273+ all_codes .append (pcm )
274+
275+
276+
277+ cpu_decode_res = torch .cat (all_codes , dim = - 1 )
278+ return cpu_decode_res
279+
229280 pte_filename = "mimi_decoder_qnn"
230281
231282 quantizer = make_quantizer (
@@ -314,14 +365,14 @@ def export_mimi(mimi, args, max_duration_sec=10.0):
314365
315366 print ("streaming encoding..." )
316367 cpu_encode_res = mimi .encode (sample_pcm )
317- htp_encode_res = mimi_encode (
318- mimi ,
319- encoder_inputs ,
320- encoder_input_list ,
321- pcm_chunk_size ,
322- skip_node_id_set ,
323- skip_node_op_set ,
324- )
368+ # htp_encode_res = mimi_encode(
369+ # mimi,
370+ # encoder_inputs,
371+ # encoder_input_list,
372+ # pcm_chunk_size,
373+ # skip_node_id_set,
374+ # skip_node_op_set,
375+ # )
325376
326377 # Leave it here for now, uncomment this to check htp_encoder with cpu_decoder
327378 # htp_res = torch.cat(htp_encode_res, dim=-1)
@@ -332,19 +383,19 @@ def export_mimi(mimi, args, max_duration_sec=10.0):
332383 cpu_decode_res = mimi .decode (cpu_encode_res )
333384 # TODO: Enable streaming mode, which is the correct way to execute 1 chunk at a time.
334385 # with mimi.streaming(1):
335- htp_decode_res = mimi_decode (
336- mimi , htp_encode_res , pcm_chunk_size , skip_node_id_set , skip_node_op_set
386+ cpu_streaming_decode_res = mimi_decode (
387+ args , mimi , cpu_encode_res , pcm_chunk_size , skip_node_id_set , skip_node_op_set
337388 )
338- compute_scores (cpu_decode_res , htp_decode_res )
389+ compute_scores (cpu_decode_res , cpu_streaming_decode_res )
339390
340391 sphn .write_wav (
341392 f"{ args .artifact } /cpu_decode_res.wav" ,
342393 cpu_decode_res [0 , 0 ].cpu ().numpy (),
343394 sample_rate ,
344395 )
345396 sphn .write_wav (
346- f"{ args .artifact } /htp_decode_res .wav" ,
347- htp_decode_res [0 , 0 ].cpu ().numpy (),
397+ f"{ args .artifact } /cpu_streaming_decode_res .wav" ,
398+ cpu_streaming_decode_res [0 , 0 ].cpu ().numpy (),
348399 sample_rate ,
349400 )
350401