2828 "https" : "http://fwdproxy:8080" ,
2929}
3030
31-
3231def compute_sqnr (x : torch .Tensor , y : torch .Tensor ) -> float :
3332 assert x .shape == y .shape , "Tensor shapes do not match"
3433 x = x .float ()
@@ -173,15 +172,21 @@ def __init__(self, mimi: nn.Module):
173172 self .mimi_model = mimi
174173
175174 def forward (self , x ):
176- return self .mimi_model .decode (x )
175+ x = x .transpose (1 , 2 )
176+ x = self .mimi_model .upsample (x )
177+ (emb ,) = self .mimi_model .decoder_transformer (x )
178+ emb .transpose (1 , 2 )
179+ with self .mimi_model ._context_for_encoder_decoder :
180+ out = self .mimi_model .decoder (emb )
181+ return out
177182
178- sample_pcm = torch .tensor (self .sample_pcm , device = self .device )[None ]
179- pcm_chunk_size = int (self .mimi .sample_rate / self .mimi .frame_rate )
180- chunk = sample_pcm [..., 0 :pcm_chunk_size ]
181- input = self .mimi .encode (chunk )
183+ emb_input = torch .rand (1 , 1 , 512 , device = 'cpu' )
182184
183185 mimi_decode = MimiDecode (self .mimi )
184- exported_decode : ExportedProgram = export (mimi_decode , (input ,), strict = False )
186+ mimi_decode .eval ()
187+ mimi_decode (emb_input )
188+
189+ exported_decode : ExportedProgram = export (mimi_decode , (emb_input ,), strict = False )
185190 quantization_config = get_symmetric_quantization_config (
186191 is_per_channel = True ,
187192 is_dynamic = True ,
@@ -190,12 +195,12 @@ def forward(self, x):
190195 quantizer .set_global (quantization_config )
191196 m = exported_decode .module ()
192197 m = prepare_pt2e (m , quantizer )
193- m (input )
198+ m (emb_input )
194199 m = convert_pt2e (m )
195200 print ("quantized graph:" )
196201 print (m .graph )
197202 # Export quantized module
198- exported_decode : ExportedProgram = export (m , (input ,), strict = False )
203+ exported_decode : ExportedProgram = export (m , (emb_input ,), strict = False )
199204 # Lower
200205 edge_manager = to_edge_transform_and_lower (
201206 exported_decode ,
@@ -208,16 +213,16 @@ def forward(self, x):
208213 with open (output_file , "wb" ) as file :
209214 exec_prog .write_to_file (file )
210215
211- eager_res = mimi_decode (input )
216+ eager_res = mimi_decode (emb_input )
212217 runtime = Runtime .get ()
213218 program = runtime .load_program (output_file )
214219 method = program .load_method ("forward" )
215- flattened_x = tree_flatten (input )[0 ]
220+ flattened_x = tree_flatten (emb_input )[0 ]
216221 res = method .execute (flattened_x )
217222 # Compare results
218223 sqnr = compute_sqnr (eager_res , res [0 ])
219224 print (f"SQNR: { sqnr } " )
220- torch .testing .assert_close (eager_res , res [0 ], atol = 1e -3 , rtol = 1e-3 )
225+ torch .testing .assert_close (eager_res , res [0 ], atol = 4e -3 , rtol = 1e-3 )
221226
222227
223228if __name__ == "__main__" :
0 commit comments