22import os
33import random
44import unittest
5+ import urllib
56
67import numpy as np
78import requests
2324from torch .export import export , ExportedProgram
2425from torch .utils ._pytree import tree_flatten
2526
26- os .environ ["https_proxy" ] = "http://fwdproxy:8080"
27+ proxies = {
28+ "http" : "http://fwdproxy:8080" ,
29+ "https" : "http://fwdproxy:8080" ,
30+ }
2731
2832
2933def compute_sqnr (x : torch .Tensor , y : torch .Tensor ) -> float :
@@ -38,7 +42,12 @@ def compute_sqnr(x: torch.Tensor, y: torch.Tensor) -> float:
3842
3943
4044def read_mp3_from_url (url ):
41- response = requests .get (url )
45+ try :
46+ response = requests .get (url )
47+ except :
48+ # FB-only hack, need to use a forwarding proxy to get url
49+ response = requests .get (url , proxies = proxies )
50+
4251 response .raise_for_status () # Ensure request is successful
4352 audio_stream = io .BytesIO (response .content )
4453 waveform , sample_rate = torchaudio .load (audio_stream , format = "mp3" )
@@ -68,7 +77,13 @@ def seed_all(seed):
6877 seed_all (42424242 )
6978
7079 if mimi_weight is None :
71- mimi_weight = hf_hub_download (hf_repo , loaders .MIMI_NAME )
80+ try :
81+ mimi_weight = hf_hub_download (hf_repo , loaders .MIMI_NAME )
82+ except :
83+ mimi_weight = hf_hub_download (
84+ hf_repo , loaders .MIMI_NAME , proxies = proxies
85+ )
86+
7287 cls .mimi = loaders .get_mimi (mimi_weight , device )
7388 cls .device = device
7489 cls .sample_pcm , cls .sample_sr = read_mp3_from_url (
0 commit comments