Skip to content

Commit 4dd115c

Browse files
Sai-Suraj-27kghamilton89
authored andcommitted
Fixed requests.get function call by adding timeout parameter.
1 parent 739d6ec commit 4dd115c

File tree

11 files changed

+60
-14
lines changed

11 files changed

+60
-14
lines changed

examples/instruct_pix2pix/train_instruct_pix2pix.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,10 @@
6767
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
6868

6969

70+
# Set global timeout
71+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
72+
73+
7074
def log_validation(
7175
pipeline,
7276
args,
@@ -418,7 +422,7 @@ def convert_to_np(image, resolution):
418422

419423

420424
def download_image(url):
421-
image = PIL.Image.open(requests.get(url, stream=True).raw)
425+
image = PIL.Image.open(requests.get(url, stream=True, timeout=request_timeout).raw)
422426
image = PIL.ImageOps.exif_transpose(image)
423427
image = image.convert("RGB")
424428
return image

examples/research_projects/instructpix2pix_lora/train_instruct_pix2pix_lora.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,10 @@
7474
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]
7575

7676

77+
# Set global timeout
78+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
79+
80+
7781
def save_model_card(
7882
repo_id: str,
7983
images: list = None,
@@ -475,7 +479,7 @@ def convert_to_np(image, resolution):
475479

476480

477481
def download_image(url):
478-
image = PIL.Image.open(requests.get(url, stream=True).raw)
482+
image = PIL.Image.open(requests.get(url, stream=True, timeout=request_timeout).raw)
479483
image = PIL.ImageOps.exif_transpose(image)
480484
image = image.convert("RGB")
481485
return image

examples/research_projects/promptdiffusion/convert_original_promptdiffusion_to_diffusers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
"""Conversion script for stable diffusion checkpoints which _only_ contain a controlnet."""
1616

1717
import argparse
18+
import os
1819
import re
1920
from contextlib import nullcontext
2021
from io import BytesIO
@@ -68,6 +69,10 @@
6869
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6970

7071

72+
# Set global timeout
73+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
74+
75+
7176
def shave_segments(path, n_shave_prefix_segments=1):
7277
"""
7378
Removes segments. Positive values shave the first segments, negative shave the last segments.
@@ -1435,7 +1440,7 @@ def download_from_original_stable_diffusion_ckpt(
14351440
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
14361441

14371442
if config_url is not None:
1438-
original_config_file = BytesIO(requests.get(config_url).content)
1443+
original_config_file = BytesIO(requests.get(config_url, timeout=request_timeout).content)
14391444
else:
14401445
with open(original_config_file, "r") as f:
14411446
original_config_file = f.read()

scripts/convert_dance_diffusion_to_diffusers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@
4646
},
4747
}
4848

49+
# Set global timeout
50+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
51+
4952

5053
def alpha_sigma_to_t(alpha, sigma):
5154
"""Returns a timestep, given the scaling factors for the clean image and for
@@ -74,7 +77,7 @@ def __init__(self, global_args):
7477

7578
def download(model_name):
7679
url = MODELS_MAP[model_name]["url"]
77-
r = requests.get(url, stream=True)
80+
r = requests.get(url, stream=True, timeout=request_timeout)
7881

7982
local_filename = f"./{model_name}.ckpt"
8083
with open(local_filename, "wb") as fp:

scripts/convert_vae_pt_to_diffusers.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import argparse
22
import io
3+
import os
34

45
import requests
56
import torch
@@ -15,6 +16,10 @@
1516
)
1617

1718

19+
# Set global timeout
20+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
21+
22+
1823
def custom_convert_ldm_vae_checkpoint(checkpoint, config):
1924
vae_state_dict = checkpoint
2025

@@ -122,7 +127,7 @@ def vae_pt_to_vae_diffuser(
122127
):
123128
# Only support V1
124129
r = requests.get(
125-
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
130+
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml", timeout=request_timeout
126131
)
127132
io_obj = io.BytesIO(r.content)
128133

src/diffusers/loaders/single_file_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@
5757

5858
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
5959

60+
# Set global timeout
61+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
62+
6063
CHECKPOINT_KEY_NAMES = {
6164
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
6265
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
@@ -443,7 +446,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
443446
"Please provide a valid local file path."
444447
)
445448

446-
original_config_file = BytesIO(requests.get(original_config_file).content)
449+
original_config_file = BytesIO(requests.get(original_config_file, timeout=request_timeout).content)
447450

448451
else:
449452
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")

src/diffusers/pipelines/stable_diffusion/convert_from_ckpt.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# limitations under the License.
1515
"""Conversion script for the Stable Diffusion checkpoints."""
1616

17+
import os
1718
import re
1819
from contextlib import nullcontext
1920
from io import BytesIO
@@ -66,6 +67,10 @@
6667
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
6768

6869

70+
# Set global timeout
71+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
72+
73+
6974
def shave_segments(path, n_shave_prefix_segments=1):
7075
"""
7176
Removes segments. Positive values shave the first segments, negative shave the last segments.
@@ -1324,7 +1329,7 @@ def download_from_original_stable_diffusion_ckpt(
13241329
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"
13251330

13261331
if config_url is not None:
1327-
original_config_file = BytesIO(requests.get(config_url).content)
1332+
original_config_file = BytesIO(requests.get(config_url, timeout=request_timeout).content)
13281333
else:
13291334
with open(original_config_file, "r") as f:
13301335
original_config_file = f.read()

src/diffusers/utils/loading_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from .import_utils import BACKENDS_MAPPING, is_imageio_available
1111

1212

13+
# Set global timeout
14+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
15+
16+
1317
def load_image(
1418
image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
1519
) -> PIL.Image.Image:
@@ -29,7 +33,7 @@ def load_image(
2933
"""
3034
if isinstance(image, str):
3135
if image.startswith("http://") or image.startswith("https://"):
32-
image = PIL.Image.open(requests.get(image, stream=True).raw)
36+
image = PIL.Image.open(requests.get(image, stream=True, timeout=request_timeout).raw)
3337
elif os.path.isfile(image):
3438
image = PIL.Image.open(image)
3539
else:

src/diffusers/utils/testing_utils.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@
4747
from .logging import get_logger
4848

4949

50+
# Set global timeout
51+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
52+
5053
global_rng = random.Random()
5154

5255
logger = get_logger(__name__)
@@ -594,7 +597,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
594597
# local_path can be passed to correct images of tests
595598
return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix()
596599
elif arry.startswith("http://") or arry.startswith("https://"):
597-
response = requests.get(arry)
600+
response = requests.get(arry, timeout=request_timeout)
598601
response.raise_for_status()
599602
arry = np.load(BytesIO(response.content))
600603
elif os.path.isfile(arry):
@@ -615,7 +618,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
615618

616619

617620
def load_pt(url: str, map_location: str):
618-
response = requests.get(url)
621+
response = requests.get(url, timeout=request_timeout)
619622
response.raise_for_status()
620623
arry = torch.load(BytesIO(response.content), map_location=map_location)
621624
return arry
@@ -634,7 +637,7 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
634637
"""
635638
if isinstance(image, str):
636639
if image.startswith("http://") or image.startswith("https://"):
637-
image = PIL.Image.open(requests.get(image, stream=True).raw)
640+
image = PIL.Image.open(requests.get(image, stream=True, timeout=request_timeout).raw)
638641
elif os.path.isfile(image):
639642
image = PIL.Image.open(image)
640643
else:

utils/fetch_latest_release_branch.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
# See the License for the specific language governing permissions and
1414
# limitations under the License.
1515

16+
import os
17+
1618
import requests
1719
from packaging.version import parse
1820

@@ -22,12 +24,16 @@
2224
REPO = "diffusers"
2325

2426

27+
# Set global timeout
28+
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))
29+
30+
2531
def fetch_all_branches(user, repo):
2632
branches = [] # List to store all branches
2733
page = 1 # Start from first page
2834
while True:
2935
# Make a request to the GitHub API for the branches
30-
response = requests.get(f"https://api.github.com/repos/{user}/{repo}/branches", params={"page": page})
36+
response = requests.get(f"https://api.github.com/repos/{user}/{repo}/branches", params={"page": page}, timeout=request_timeout)
3137

3238
# Check if the request was successful
3339
if response.status_code == 200:

0 commit comments

Comments
 (0)