2828 "https" : "http://fwdproxy:8080" ,
2929}
3030
31-
3231def compute_sqnr (x : torch .Tensor , y : torch .Tensor ) -> float :
3332 assert x .shape == y .shape , "Tensor shapes do not match"
3433 x = x .float ()
@@ -39,7 +38,6 @@ def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
3938 sqnr = 10 * torch .log10 (original_power / error_power )
4039 return sqnr .item ()
4140
42-
4341def read_mp3_from_url (url ):
4442 try :
4543 response = requests .get (url )
@@ -173,15 +171,23 @@ def __init__(self, mimi: nn.Module):
173171 self .mimi_model = mimi
174172
def forward(self, x):
    """Run the Mimi decode path stage by stage on an embedding tensor.

    Instead of a single ``self.mimi_model.decode(x)`` call, this invokes
    the individual stages directly: transpose, upsample, decoder
    transformer, then the decoder.

    Args:
        x: Input tensor whose dims 1 and 2 are swapped before upsampling.
            NOTE(review): presumably (batch, time, channels) -- confirm
            against the caller.

    Returns:
        Whatever ``self.mimi_model.decoder`` returns for the transformed
        embedding.
    """
    # Swap dims 1 and 2 before handing the tensor to the upsampler.
    x = x.transpose(1, 2)
    x = self.mimi_model.upsample(x)
    # The decoder transformer returns a 1-tuple; unpack its only element.
    (emb,) = self.mimi_model.decoder_transformer(x)
    # NOTE(review): the previous code had a bare `emb.transpose(1, 2)` here.
    # Tensor.transpose is out-of-place, so that statement discarded its
    # result and had no effect. It is removed to keep the behaviour that
    # actually ran. If a layout change was intended before the decoder,
    # write `emb = emb.transpose(1, 2)` -- confirm the layout the decoder
    # expects first.
    with self.mimi_model._context_for_encoder_decoder:
        out = self.mimi_model.decoder(emb)
    return out
177181
178- sample_pcm = torch .tensor (self .sample_pcm , device = self .device )[None ]
179- pcm_chunk_size = int (self .mimi .sample_rate / self .mimi .frame_rate )
180- chunk = sample_pcm [..., 0 :pcm_chunk_size ]
181- input = self .mimi .encode (chunk )
182+ emb_input = torch .rand (1 , 1 , 512 , device = "cpu" )
182183
183184 mimi_decode = MimiDecode (self .mimi )
184- exported_decode : ExportedProgram = export (mimi_decode , (input ,), strict = False )
185+ mimi_decode .eval ()
186+ mimi_decode (emb_input )
187+
188+ exported_decode : ExportedProgram = export (
189+ mimi_decode , (emb_input ,), strict = False
190+ )
185191 quantization_config = get_symmetric_quantization_config (
186192 is_per_channel = True ,
187193 is_dynamic = True ,
@@ -190,12 +196,12 @@ def forward(self, x):
190196 quantizer .set_global (quantization_config )
191197 m = exported_decode .module ()
192198 m = prepare_pt2e (m , quantizer )
193- m (input )
199+ m (emb_input )
194200 m = convert_pt2e (m )
195201 print ("quantized graph:" )
196202 print (m .graph )
197203 # Export quantized module
198- exported_decode : ExportedProgram = export (m , (input ,), strict = False )
204+ exported_decode : ExportedProgram = export (m , (emb_input ,), strict = False )
199205 # Lower
200206 edge_manager = to_edge_transform_and_lower (
201207 exported_decode ,
@@ -208,16 +214,16 @@ def forward(self, x):
208214 with open (output_file , "wb" ) as file :
209215 exec_prog .write_to_file (file )
210216
211- eager_res = mimi_decode (input )
217+ eager_res = mimi_decode (emb_input )
212218 runtime = Runtime .get ()
213219 program = runtime .load_program (output_file )
214220 method = program .load_method ("forward" )
215- flattened_x = tree_flatten (input )[0 ]
221+ flattened_x = tree_flatten (emb_input )[0 ]
216222 res = method .execute (flattened_x )
217223 # Compare results
218224 sqnr = compute_sqnr (eager_res , res [0 ])
219225 print (f"SQNR: { sqnr } " )
220- torch .testing .assert_close (eager_res , res [0 ], atol = 1e -3 , rtol = 1e-3 )
226+ torch .testing .assert_close (eager_res , res [0 ], atol = 4e -3 , rtol = 1e-3 )
221227
222228
223229if __name__ == "__main__" :
0 commit comments