4747from .logging import get_logger
4848
4949
50+ # Set global timeout
51+ request_timeout = int (os .environ .get ("DIFFUSERS_REQUEST_TIMEOUT" , 60 ))
52+
5053global_rng = random .Random ()
5154
5255logger = get_logger (__name__ )
@@ -594,7 +597,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
594597 # local_path can be passed to correct images of tests
595598 return Path (local_path , arry .split ("/" )[- 5 ], arry .split ("/" )[- 2 ], arry .split ("/" )[- 1 ]).as_posix ()
596599 elif arry .startswith ("http://" ) or arry .startswith ("https://" ):
597- response = requests .get (arry )
600+ response = requests .get (arry , timeout = request_timeout )
598601 response .raise_for_status ()
599602 arry = np .load (BytesIO (response .content ))
600603 elif os .path .isfile (arry ):
@@ -615,7 +618,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
615618
616619
617620def load_pt (url : str , map_location : str ):
618- response = requests .get (url )
621+ response = requests .get (url , timeout = request_timeout )
619622 response .raise_for_status ()
620623 arry = torch .load (BytesIO (response .content ), map_location = map_location )
621624 return arry
@@ -634,7 +637,7 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
634637 """
635638 if isinstance (image , str ):
636639 if image .startswith ("http://" ) or image .startswith ("https://" ):
637- image = PIL .Image .open (requests .get (image , stream = True ).raw )
640+ image = PIL .Image .open (requests .get (image , stream = True , timeout = request_timeout ).raw )
638641 elif os .path .isfile (image ):
639642 image = PIL .Image .open (image )
640643 else :
0 commit comments