
Commit 2c4645c

make fix-copies
1 parent 37e8a95 commit 2c4645c

File tree

3 files changed: +69 / -11 lines

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 24 additions & 11 deletions
@@ -71,9 +71,10 @@ def retrieve_timesteps(
     num_inference_steps: Optional[int] = None,
     device: Optional[Union[str, torch.device]] = None,
     timesteps: Optional[List[int]] = None,
+    sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

@@ -86,14 +87,18 @@ def retrieve_timesteps(
         device (`str` or `torch.device`, *optional*):
             The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
         timesteps (`List[int]`, *optional*):
-            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
-            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
-            must be `None`.
+            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
+            `num_inference_steps` and `sigmas` must be `None`.
+        sigmas (`List[float]`, *optional*):
+            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
+            `num_inference_steps` and `timesteps` must be `None`.

     Returns:
         `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
         second element is the number of inference steps.
     """
+    if timesteps is not None and sigmas is not None:
+        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
     if timesteps is not None:
         accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
         if not accepts_timesteps:

@@ -104,6 +109,16 @@ def retrieve_timesteps(
         scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
         timesteps = scheduler.timesteps
         num_inference_steps = len(timesteps)
+    elif sigmas is not None:
+        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
+        if not accept_sigmas:
+            raise ValueError(
+                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
+                f" sigmas schedules. Please check whether you are using the correct scheduler."
+            )
+        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
+        timesteps = scheduler.timesteps
+        num_inference_steps = len(timesteps)
     else:
         scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
         timesteps = scheduler.timesteps
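Note on the change above: `retrieve_timesteps` now accepts exactly one of `num_inference_steps`, `timesteps`, or `sigmas`, and it only forwards a custom schedule when the scheduler's `set_timesteps` signature supports it. The sketch below is illustrative only and is not part of this commit; it assumes a recent diffusers install in which `EulerDiscreteScheduler.set_timesteps` accepts a `sigmas` argument.

# Illustrative sketch (not part of this commit): calling the updated helper.
# Assumes EulerDiscreteScheduler.set_timesteps accepts `sigmas`, as in recent
# diffusers releases.
import numpy as np
from diffusers import EulerDiscreteScheduler
from diffusers.pipelines.allegro.pipeline_allegro import retrieve_timesteps

scheduler = EulerDiscreteScheduler()

# 1) Default spacing: only num_inference_steps is given.
timesteps, num_steps = retrieve_timesteps(scheduler, num_inference_steps=30, device="cpu")

# 2) Custom sigma schedule: `sigmas` overrides the scheduler's spacing strategy.
custom_sigmas = np.linspace(1.0, 1.0 / 30, 30).tolist()
timesteps, num_steps = retrieve_timesteps(scheduler, device="cpu", sigmas=custom_sigmas)

# 3) Supplying both custom schedules trips the new guard.
# retrieve_timesteps(scheduler, timesteps=[999, 500, 1], sigmas=custom_sigmas)  # ValueError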
@@ -458,14 +473,12 @@ def _clean_caption(self, caption):
         caption = re.sub("<person>", "person", caption)
         # urls:
         caption = re.sub(
-            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
-            # noqa
+            r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
             "",
             caption,
         )  # regex for urls
         caption = re.sub(
-            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",
-            # noqa
+            r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))",  # noqa
             "",
             caption,
         )  # regex for urls

@@ -488,13 +501,12 @@ def _clean_caption(self, caption):
         caption = re.sub(r"[\u3300-\u33ff]+", "", caption)
         caption = re.sub(r"[\u3400-\u4dbf]+", "", caption)
         caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption)
-        # caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
+        caption = re.sub(r"[\u4e00-\u9fff]+", "", caption)
         #######################################################

         # все виды тире / all types of dash --> "-"
         caption = re.sub(
-            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",
-            # noqa
+            r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+",  # noqa
             "-",
             caption,
         )

@@ -565,6 +577,7 @@ def _clean_caption(self, caption):
         caption = re.sub(r"^[\'\_,\-\:;]", r"", caption)
         caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption)
         caption = re.sub(r"^\.\S+$", "", caption)
+
         return caption.strip()

     def prepare_latents(
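The only behavioral change in `_clean_caption` is that the CJK filter, previously commented out, is now active; the remaining edits just move each `# noqa` onto the same line as its regex. A minimal, standalone illustration of that one substitution (not the full cleaning routine) follows.

import re

# The newly re-enabled filter strips CJK Unified Ideographs (U+4E00-U+9FFF).
caption = "a photo of a cat 一只猫 sitting on a mat"
print(re.sub(r"[\u4e00-\u9fff]+", "", caption))
# -> "a photo of a cat  sitting on a mat"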

src/diffusers/utils/dummy_pt_objects.py

Lines changed: 30 additions & 0 deletions
@@ -2,6 +2,21 @@
 from ..utils import DummyObject, requires_backends


+class AllegroTransformer3DModel(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AsymmetricAutoencoderKL(metaclass=DummyObject):
     _backends = ["torch"]

@@ -47,6 +62,21 @@ def from_pretrained(cls, *args, **kwargs):
         requires_backends(cls, ["torch"])


+class AutoencoderKLAllegro(metaclass=DummyObject):
+    _backends = ["torch"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch"])
+
+
 class AutoencoderKLCogVideoX(metaclass=DummyObject):
     _backends = ["torch"]
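These additions are generated by `make fix-copies`: for each object that needs an optional backend, a placeholder class is written so that importing diffusers without torch still works, and the failure is deferred until the class is actually used. A simplified, self-contained sketch of that pattern (not the library's exact implementation) is shown below.

# Simplified sketch of the dummy-object pattern (not diffusers' exact code).
# A metaclass stands in for the real class so the error surfaces when the class
# is used, not when the package is imported.
class DummyObject(type):
    def __getattr__(cls, name):
        raise ImportError(
            f"{cls.__name__} requires the following backends: {cls._backends}"
        )


def requires_backends(obj, backends):
    # The real helper checks whether each backend is importable; this stub
    # always fails, just to show the shape of the resulting error.
    name = obj.__name__ if isinstance(obj, type) else type(obj).__name__
    raise ImportError(f"{name} requires the following backends: {backends}")


class AllegroTransformer3DModel(metaclass=DummyObject):
    _backends = ["torch"]

    def __init__(self, *args, **kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch"])


# Both calls raise ImportError naming the "torch" backend:
# AllegroTransformer3DModel()
# AllegroTransformer3DModel.from_pretrained("some/checkpoint")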

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
@@ -2,6 +2,21 @@
 from ..utils import DummyObject, requires_backends


+class AllegroPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class AltDiffusionImg2ImgPipeline(metaclass=DummyObject):
     _backends = ["torch", "transformers"]
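From the user's side, the effect of this placeholder is that `from diffusers import AllegroPipeline` should still succeed when torch or transformers is missing, while any use of the class raises an ImportError naming the required backends. An illustrative snippet under that assumption (the checkpoint id is a placeholder, not a real repository):

# Illustrative only: behavior in an environment where torch/transformers are
# missing, so `AllegroPipeline` resolves to the placeholder defined above.
from diffusers import AllegroPipeline

try:
    pipe = AllegroPipeline.from_pretrained("some-org/allegro-checkpoint")  # placeholder repo id
except ImportError as err:
    print(err)  # message lists the required backends: torch, transformers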
