@@ -59,7 +59,7 @@ def setUpClass(cls):
5959 """Setup once for all tests: Load model and prepare test data."""
6060
6161 # Get environment variables (if set), otherwise use default values
62- mimi_weight = os .getenv ("MIMI_WEIGHT" , None )
62+ cls . mimi_weight = os .getenv ("MIMI_WEIGHT" , None )
6363 hf_repo = os .getenv ("HF_REPO" , loaders .DEFAULT_REPO )
6464 device = "cuda" if torch .cuda .device_count () else "cpu"
6565
@@ -75,15 +75,15 @@ def seed_all(seed):
7575
7676 seed_all (42424242 )
7777
78- if mimi_weight is None :
78+ if cls . mimi_weight is None :
7979 try :
80- mimi_weight = hf_hub_download (hf_repo , loaders .MIMI_NAME )
80+ cls . mimi_weight = hf_hub_download (hf_repo , loaders .MIMI_NAME )
8181 except :
82- mimi_weight = hf_hub_download (
82+ cls . mimi_weight = hf_hub_download (
8383 hf_repo , loaders .MIMI_NAME , proxies = proxies
8484 )
8585
86- cls .mimi = loaders .get_mimi (mimi_weight , device )
86+ cls .mimi = loaders .get_mimi (cls . mimi_weight , device )
8787 cls .device = device
8888 cls .sample_pcm , cls .sample_sr = read_mp3_from_url (
8989 "https://huggingface.co/lmz/moshi-swift/resolve/main/bria-24khz.mp3"
@@ -182,8 +182,8 @@ def forward(self, x):
182182 return out
183183
184184 emb_input = torch .rand (1 , 1 , 512 , device = "cpu" )
185-
186- mimi_decode = MimiDecode (self . mimi )
185+ mimi_cpu = loaders . get_mimi ( self . mimi_weight , "cpu" )
186+ mimi_decode = MimiDecode (mimi_cpu )
187187 mimi_decode .eval ()
188188 mimi_decode (emb_input )
189189
@@ -225,7 +225,9 @@ def forward(self, x):
225225 # Compare results
226226 sqnr = compute_sqnr (eager_res , res [0 ])
227227 print (f"SQNR: { sqnr } " )
228- torch .testing .assert_close (eager_res , res [0 ], atol = 4e-3 , rtol = 1e-3 )
228+ # Don't check for exact equality, but check that the SQNR is high enough
229+ # torch.testing.assert_close(eager_res, res[0], atol=4e-3, rtol=1e-3)
230+ self .assertGreater (sqnr , 25.0 )
229231
230232
231233if __name__ == "__main__" :
0 commit comments