@@ -135,16 +135,28 @@ def test_streaming_encoding_decoding(self):
135135
136136 all_codes_th = torch .cat (all_codes , dim = - 1 )
137137
138+ pcm_ref = self .mimi .decode (all_codes_th )
139+
138140 all_pcms = []
141+ for i in range (all_codes_th .shape [- 1 ]):
142+ codes = all_codes_th [..., i : i + 1 ]
143+ pcm = self .mimi .decode (codes )
144+ all_pcms .append (pcm )
145+ all_pcms = torch .cat (all_pcms , dim = - 1 )
146+ sqnr = compute_sqnr (pcm_ref , all_pcms )
147+ print (f"sqnr = { sqnr } dB" )
148+ self .assertTrue (sqnr > 4 )
149+
150+ all_pcms_streaming = []
139151 with self .mimi .streaming (1 ):
140152 for i in range (all_codes_th .shape [- 1 ]):
141153 codes = all_codes_th [..., i : i + 1 ]
142- pcm = self .mimi .decode (codes )
143- all_pcms .append (pcm )
144- all_pcms = torch .cat (all_pcms , dim = - 1 )
145-
146- pcm_ref = self . mimi . decode ( all_codes_th )
147- self .assertTrue (torch . allclose ( pcm_ref , all_pcms , atol = 1e-5 ) )
154+ pcm_streaming = self .mimi .decode (codes )
155+ all_pcms_streaming .append (pcm_streaming )
156+ all_pcms_streaming = torch .cat (all_pcms_streaming , dim = - 1 )
157+ sqnr_streaming = compute_sqnr ( pcm_ref , all_pcms_streaming )
158+ print ( f"sqnr_streaming = { sqnr_streaming } dB" )
159+ self .assertTrue (sqnr_streaming > 100 )
148160
149161 def test_exported_encoding (self ):
150162 """Ensure exported encoding model is consistent with reference output."""
0 commit comments