Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion examples/instruct_pix2pix/train_instruct_pix2pix.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]


# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))


def log_validation(
pipeline,
args,
Expand Down Expand Up @@ -418,7 +422,7 @@ def convert_to_np(image, resolution):


def download_image(url):
image = PIL.Image.open(requests.get(url, stream=True).raw)
image = PIL.Image.open(requests.get(url, stream=True, timeout=request_timeout).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,10 @@
WANDB_TABLE_COL_NAMES = ["original_image", "edited_image", "edit_prompt"]


# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))


def save_model_card(
repo_id: str,
images: list = None,
Expand Down Expand Up @@ -475,7 +479,7 @@ def convert_to_np(image, resolution):


def download_image(url):
image = PIL.Image.open(requests.get(url, stream=True).raw)
image = PIL.Image.open(requests.get(url, stream=True, timeout=request_timeout).raw)
image = PIL.ImageOps.exif_transpose(image)
image = image.convert("RGB")
return image
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
"""Conversion script for stable diffusion checkpoints which _only_ contain a controlnet."""

import argparse
import os
import re
from contextlib import nullcontext
from io import BytesIO
Expand Down Expand Up @@ -68,6 +69,10 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))


def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
Expand Down Expand Up @@ -1435,7 +1440,7 @@ def download_from_original_stable_diffusion_ckpt(
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"

if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
original_config_file = BytesIO(requests.get(config_url, timeout=request_timeout).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()
Expand Down
5 changes: 4 additions & 1 deletion scripts/convert_dance_diffusion_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,9 @@
},
}

# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))


def alpha_sigma_to_t(alpha, sigma):
"""Returns a timestep, given the scaling factors for the clean image and for
Expand Down Expand Up @@ -74,7 +77,7 @@ def __init__(self, global_args):

def download(model_name):
url = MODELS_MAP[model_name]["url"]
r = requests.get(url, stream=True)
r = requests.get(url, stream=True, timeout=request_timeout)

local_filename = f"./{model_name}.ckpt"
with open(local_filename, "wb") as fp:
Expand Down
7 changes: 6 additions & 1 deletion scripts/convert_vae_pt_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import argparse
import io
import os

import requests
import torch
Expand All @@ -15,6 +16,10 @@
)


# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))


def custom_convert_ldm_vae_checkpoint(checkpoint, config):
vae_state_dict = checkpoint

Expand Down Expand Up @@ -122,7 +127,7 @@ def vae_pt_to_vae_diffuser(
):
# Only support V1
r = requests.get(
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml"
" https://raw.githubusercontent.com/CompVis/stable-diffusion/main/configs/stable-diffusion/v1-inference.yaml", timeout=request_timeout
)
io_obj = io.BytesIO(r.content)

Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@

logger = logging.get_logger(__name__) # pylint: disable=invalid-name

# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))

CHECKPOINT_KEY_NAMES = {
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
Expand Down Expand Up @@ -443,7 +446,7 @@ def fetch_original_config(original_config_file, local_files_only=False):
"Please provide a valid local file path."
)

original_config_file = BytesIO(requests.get(original_config_file).content)
original_config_file = BytesIO(requests.get(original_config_file, timeout=request_timeout).content)

else:
raise ValueError("Invalid `original_config_file` provided. Please set it to a valid file path or URL.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
# limitations under the License.
"""Conversion script for the Stable Diffusion checkpoints."""

import os
import re
from contextlib import nullcontext
from io import BytesIO
Expand Down Expand Up @@ -66,6 +67,10 @@
logger = logging.get_logger(__name__) # pylint: disable=invalid-name


# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))


def shave_segments(path, n_shave_prefix_segments=1):
"""
Removes segments. Positive values shave the first segments, negative shave the last segments.
Expand Down Expand Up @@ -1324,7 +1329,7 @@ def download_from_original_stable_diffusion_ckpt(
config_url = "https://raw.githubusercontent.com/Stability-AI/stablediffusion/main/configs/stable-diffusion/x4-upscaling.yaml"

if config_url is not None:
original_config_file = BytesIO(requests.get(config_url).content)
original_config_file = BytesIO(requests.get(config_url, timeout=request_timeout).content)
else:
with open(original_config_file, "r") as f:
original_config_file = f.read()
Expand Down
6 changes: 5 additions & 1 deletion src/diffusers/utils/loading_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
from .import_utils import BACKENDS_MAPPING, is_imageio_available


# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))


def load_image(
image: Union[str, PIL.Image.Image], convert_method: Optional[Callable[[PIL.Image.Image], PIL.Image.Image]] = None
) -> PIL.Image.Image:
Expand All @@ -29,7 +33,7 @@ def load_image(
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True).raw)
image = PIL.Image.open(requests.get(image, stream=True, timeout=request_timeout).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
Expand Down
9 changes: 6 additions & 3 deletions src/diffusers/utils/testing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@
from .logging import get_logger


# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))

global_rng = random.Random()

logger = get_logger(__name__)
Expand Down Expand Up @@ -594,7 +597,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -
# local_path can be passed to correct images of tests
return Path(local_path, arry.split("/")[-5], arry.split("/")[-2], arry.split("/")[-1]).as_posix()
elif arry.startswith("http://") or arry.startswith("https://"):
response = requests.get(arry)
response = requests.get(arry, timeout=request_timeout)
response.raise_for_status()
arry = np.load(BytesIO(response.content))
elif os.path.isfile(arry):
Expand All @@ -615,7 +618,7 @@ def load_numpy(arry: Union[str, np.ndarray], local_path: Optional[str] = None) -


def load_pt(url: str, map_location: str):
response = requests.get(url)
response = requests.get(url, timeout=request_timeout)
response.raise_for_status()
arry = torch.load(BytesIO(response.content), map_location=map_location)
return arry
Expand All @@ -634,7 +637,7 @@ def load_image(image: Union[str, PIL.Image.Image]) -> PIL.Image.Image:
"""
if isinstance(image, str):
if image.startswith("http://") or image.startswith("https://"):
image = PIL.Image.open(requests.get(image, stream=True).raw)
image = PIL.Image.open(requests.get(image, stream=True, timeout=request_timeout).raw)
elif os.path.isfile(image):
image = PIL.Image.open(image)
else:
Expand Down
8 changes: 7 additions & 1 deletion utils/fetch_latest_release_branch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import requests
from packaging.version import parse

Expand All @@ -22,12 +24,16 @@
REPO = "diffusers"


# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))


def fetch_all_branches(user, repo):
branches = [] # List to store all branches
page = 1 # Start from first page
while True:
# Make a request to the GitHub API for the branches
response = requests.get(f"https://api.github.com/repos/{user}/{repo}/branches", params={"page": page})
response = requests.get(f"https://api.github.com/repos/{user}/{repo}/branches", params={"page": page}, timeout=request_timeout)

# Check if the request was successful
if response.status_code == 200:
Expand Down
8 changes: 6 additions & 2 deletions utils/notify_slack_about_release.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,13 @@
SLACK_WEBHOOK_URL = os.getenv("SLACK_WEBHOOK_URL")


# Set global timeout
request_timeout = int(os.environ.get("DIFFUSERS_REQUEST_TIMEOUT", 60))


def check_pypi_for_latest_release(library_name):
"""Check PyPI for the latest release of the library."""
response = requests.get(f"https://pypi.org/pypi/{library_name}/json")
response = requests.get(f"https://pypi.org/pypi/{library_name}/json", timeout=request_timeout)
if response.status_code == 200:
data = response.json()
return data["info"]["version"]
Expand All @@ -38,7 +42,7 @@ def check_pypi_for_latest_release(library_name):
def get_github_release_info(github_repo):
"""Fetch the latest release info from GitHub."""
url = f"https://api.github.com/repos/{github_repo}/releases/latest"
response = requests.get(url)
response = requests.get(url, timeout=request_timeout)

if response.status_code == 200:
data = response.json()
Expand Down