@@ -173,15 +173,23 @@ def __init__(self, mimi: nn.Module):
173173 self .mimi_model = mimi
174174
175175 def forward (self , x ):
176- return self .mimi_model .decode (x )
176+ x = x .transpose (1 , 2 )
177+ x = self .mimi_model .upsample (x )
178+ (emb ,) = self .mimi_model .decoder_transformer (x )
179+ emb .transpose (1 , 2 )
180+ with self .mimi_model ._context_for_encoder_decoder :
181+ out = self .mimi_model .decoder (emb )
182+ return out
177183
178- sample_pcm = torch .tensor (self .sample_pcm , device = self .device )[None ]
179- pcm_chunk_size = int (self .mimi .sample_rate / self .mimi .frame_rate )
180- chunk = sample_pcm [..., 0 :pcm_chunk_size ]
181- input = self .mimi .encode (chunk )
184+ emb_input = torch .rand (1 , 1 , 512 , device = "cpu" )
182185
183186 mimi_decode = MimiDecode (self .mimi )
184- exported_decode : ExportedProgram = export (mimi_decode , (input ,), strict = False )
187+ mimi_decode .eval ()
188+ mimi_decode (emb_input )
189+
190+ exported_decode : ExportedProgram = export (
191+ mimi_decode , (emb_input ,), strict = False
192+ )
185193 quantization_config = get_symmetric_quantization_config (
186194 is_per_channel = True ,
187195 is_dynamic = True ,
@@ -190,12 +198,12 @@ def forward(self, x):
190198 quantizer .set_global (quantization_config )
191199 m = exported_decode .module ()
192200 m = prepare_pt2e (m , quantizer )
193- m (input )
201+ m (emb_input )
194202 m = convert_pt2e (m )
195203 print ("quantized graph:" )
196204 print (m .graph )
197205 # Export quantized module
198- exported_decode : ExportedProgram = export (m , (input ,), strict = False )
206+ exported_decode : ExportedProgram = export (m , (emb_input ,), strict = False )
199207 # Lower
200208 edge_manager = to_edge_transform_and_lower (
201209 exported_decode ,
@@ -208,16 +216,16 @@ def forward(self, x):
208216 with open (output_file , "wb" ) as file :
209217 exec_prog .write_to_file (file )
210218
211- eager_res = mimi_decode (input )
219+ eager_res = mimi_decode (emb_input )
212220 runtime = Runtime .get ()
213221 program = runtime .load_program (output_file )
214222 method = program .load_method ("forward" )
215- flattened_x = tree_flatten (input )[0 ]
223+ flattened_x = tree_flatten (emb_input )[0 ]
216224 res = method .execute (flattened_x )
217225 # Compare results
218226 sqnr = compute_sqnr (eager_res , res [0 ])
219227 print (f"SQNR: { sqnr } " )
220- torch .testing .assert_close (eager_res , res [0 ], atol = 1e -3 , rtol = 1e-3 )
228+ torch .testing .assert_close (eager_res , res [0 ], atol = 4e -3 , rtol = 1e-3 )
221229
222230
223231if __name__ == "__main__" :
0 commit comments