Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions src/diffusers/loaders/single_file_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from huggingface_hub.utils import validate_hf_hub_args
from typing_extensions import Self

from .. import __version__
from ..quantizers import DiffusersAutoQuantizer
from ..utils import deprecate, is_accelerate_available, logging
from .single_file_utils import (
Expand Down Expand Up @@ -260,6 +261,11 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
device = kwargs.pop("device", None)
disable_mmap = kwargs.pop("disable_mmap", False)

user_agent = {"diffusers": __version__, "file_type": "single_file", "framework": "pytorch"}
# In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
if quantization_config is not None:
user_agent["quant"] = quantization_config.quant_method.value

if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
torch_dtype = torch.float32
logger.warning(
Expand All @@ -278,6 +284,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
local_files_only=local_files_only,
revision=revision,
disable_mmap=disable_mmap,
user_agent=user_agent,
)
if quantization_config is not None:
hf_quantizer = DiffusersAutoQuantizer.from_config(quantization_config)
Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/loaders/single_file_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,13 +405,16 @@ def load_single_file_checkpoint(
local_files_only=None,
revision=None,
disable_mmap=False,
user_agent=None,
):
if user_agent is None:
user_agent = {"file_type": "single_file", "framework": "pytorch"}

if os.path.isfile(pretrained_model_link_or_path):
pretrained_model_link_or_path = pretrained_model_link_or_path

else:
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
user_agent = {"file_type": "single_file", "framework": "pytorch"}
pretrained_model_link_or_path = _get_model_file(
repo_id,
weights_name=weights_name,
Expand Down
Loading