Skip to content

Commit 7c9dc97

Browse files
committed
up
1 parent c12a61f commit 7c9dc97

9 files changed

+125
-716
lines changed

src/diffusers/pipelines/qwenimage/pipeline_qwen_utils.py

Lines changed: 99 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
1516
import math
1617
from typing import List, Optional, Union
1718

@@ -20,6 +21,104 @@
2021
from ...utils import deprecate
2122

2223

24+
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.15,
):
    """Linearly interpolate the timestep-shift value ``mu`` for a given image sequence length.

    The shift goes from ``base_shift`` at ``base_seq_len`` tokens to ``max_shift`` at
    ``max_seq_len`` tokens, extrapolating linearly outside that range.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
36+
37+
38+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    # Custom timesteps and custom sigmas are mutually exclusive.
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is None and sigmas is None:
        # Default path: let the scheduler derive its own schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    # Custom-schedule path: verify the scheduler's `set_timesteps` accepts the
    # relevant keyword before forwarding it.
    scheduler_params = inspect.signature(scheduler.set_timesteps).parameters
    if timesteps is not None:
        if "timesteps" not in scheduler_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in scheduler_params:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # The scheduler may resample the requested schedule; read the result back.
    retrieved_timesteps = scheduler.timesteps
    return retrieved_timesteps, len(retrieved_timesteps)
96+
97+
98+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """Extract latents from a VAE encoder output.

    If the output exposes a `latent_dist`, either sample from it (``sample_mode="sample"``)
    or take its mode (``sample_mode="argmax"``); otherwise fall back to a plain
    `latents` attribute. Raises `AttributeError` when neither is available.
    """
    _missing = object()
    latent_dist = getattr(encoder_output, "latent_dist", _missing)
    if latent_dist is not _missing:
        if sample_mode == "sample":
            return latent_dist.sample(generator)
        if sample_mode == "argmax":
            return latent_dist.mode()
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents
    raise AttributeError("Could not access latents of provided encoder_output")
110+
111+
112+
def calculate_dimensions(target_area, ratio):
    """Compute a width/height pair covering roughly `target_area` pixels at aspect `ratio`.

    Both dimensions are snapped to the nearest multiple of 32. The third element of the
    returned tuple is always `None` (kept for interface compatibility with callers).
    """
    exact_width = math.sqrt(target_area * ratio)
    exact_height = exact_width / ratio

    # Snap each side independently to the nearest multiple of 32.
    width = 32 * round(exact_width / 32)
    height = 32 * round(exact_height / 32)

    return width, height, None
120+
121+
23122
class QwenImageMixin:
24123
@property
25124
def guidance_scale(self):
@@ -340,13 +439,3 @@ def _get_qwen_prompt_embeds(
340439
prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
341440

342441
return prompt_embeds, encoder_attention_mask
343-
344-
345-
def calculate_dimensions(target_area, ratio):
346-
width = math.sqrt(target_area * ratio)
347-
height = width / ratio
348-
349-
width = round(width / 32) * 32
350-
height = round(height / 32) * 32
351-
352-
return width, height, None

src/diffusers/pipelines/qwenimage/pipeline_qwenimage.py

Lines changed: 1 addition & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import inspect
1615
from typing import Any, Callable, Dict, List, Optional, Union
1716

1817
import numpy as np
@@ -27,7 +26,7 @@
2726
from ...utils.torch_utils import randn_tensor
2827
from ..pipeline_utils import DiffusionPipeline
2928
from .pipeline_output import QwenImagePipelineOutput
30-
from .pipeline_qwen_utils import QwenImagePipelineMixin
29+
from .pipeline_qwen_utils import QwenImagePipelineMixin, calculate_shift, retrieve_timesteps
3130

3231

3332
if is_torch_xla_available():
@@ -57,80 +56,6 @@
5756
"""
5857

5958

60-
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
61-
def calculate_shift(
62-
image_seq_len,
63-
base_seq_len: int = 256,
64-
max_seq_len: int = 4096,
65-
base_shift: float = 0.5,
66-
max_shift: float = 1.15,
67-
):
68-
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
69-
b = base_shift - m * base_seq_len
70-
mu = image_seq_len * m + b
71-
return mu
72-
73-
74-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
75-
def retrieve_timesteps(
76-
scheduler,
77-
num_inference_steps: Optional[int] = None,
78-
device: Optional[Union[str, torch.device]] = None,
79-
timesteps: Optional[List[int]] = None,
80-
sigmas: Optional[List[float]] = None,
81-
**kwargs,
82-
):
83-
r"""
84-
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
85-
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
86-
87-
Args:
88-
scheduler (`SchedulerMixin`):
89-
The scheduler to get timesteps from.
90-
num_inference_steps (`int`):
91-
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
92-
must be `None`.
93-
device (`str` or `torch.device`, *optional*):
94-
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
95-
timesteps (`List[int]`, *optional*):
96-
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
97-
`num_inference_steps` and `sigmas` must be `None`.
98-
sigmas (`List[float]`, *optional*):
99-
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
100-
`num_inference_steps` and `timesteps` must be `None`.
101-
102-
Returns:
103-
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
104-
second element is the number of inference steps.
105-
"""
106-
if timesteps is not None and sigmas is not None:
107-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
108-
if timesteps is not None:
109-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
110-
if not accepts_timesteps:
111-
raise ValueError(
112-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
113-
f" timestep schedules. Please check whether you are using the correct scheduler."
114-
)
115-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
116-
timesteps = scheduler.timesteps
117-
num_inference_steps = len(timesteps)
118-
elif sigmas is not None:
119-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
120-
if not accept_sigmas:
121-
raise ValueError(
122-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
123-
f" sigmas schedules. Please check whether you are using the correct scheduler."
124-
)
125-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
126-
timesteps = scheduler.timesteps
127-
num_inference_steps = len(timesteps)
128-
else:
129-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
130-
timesteps = scheduler.timesteps
131-
return timesteps, num_inference_steps
132-
133-
13459
class QwenImagePipeline(DiffusionPipeline, QwenImagePipelineMixin, QwenImageLoraLoaderMixin):
13560
r"""
13661
The QwenImage pipeline for text-to-image generation.

src/diffusers/pipelines/qwenimage/pipeline_qwenimage_controlnet.py

Lines changed: 1 addition & 90 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import inspect
1615
from typing import Any, Callable, Dict, List, Optional, Union
1716

1817
import numpy as np
@@ -28,7 +27,7 @@
2827
from ...utils.torch_utils import randn_tensor
2928
from ..pipeline_utils import DiffusionPipeline
3029
from .pipeline_output import QwenImagePipelineOutput
31-
from .pipeline_qwen_utils import QwenImagePipelineMixin
30+
from .pipeline_qwen_utils import QwenImagePipelineMixin, calculate_shift, retrieve_latents, retrieve_timesteps
3231

3332

3433
if is_torch_xla_available():
@@ -102,94 +101,6 @@
102101
"""
103102

104103

105-
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
106-
def calculate_shift(
107-
image_seq_len,
108-
base_seq_len: int = 256,
109-
max_seq_len: int = 4096,
110-
base_shift: float = 0.5,
111-
max_shift: float = 1.15,
112-
):
113-
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
114-
b = base_shift - m * base_seq_len
115-
mu = image_seq_len * m + b
116-
return mu
117-
118-
119-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
120-
def retrieve_latents(
121-
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
122-
):
123-
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
124-
return encoder_output.latent_dist.sample(generator)
125-
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
126-
return encoder_output.latent_dist.mode()
127-
elif hasattr(encoder_output, "latents"):
128-
return encoder_output.latents
129-
else:
130-
raise AttributeError("Could not access latents of provided encoder_output")
131-
132-
133-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
134-
def retrieve_timesteps(
135-
scheduler,
136-
num_inference_steps: Optional[int] = None,
137-
device: Optional[Union[str, torch.device]] = None,
138-
timesteps: Optional[List[int]] = None,
139-
sigmas: Optional[List[float]] = None,
140-
**kwargs,
141-
):
142-
r"""
143-
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
144-
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
145-
146-
Args:
147-
scheduler (`SchedulerMixin`):
148-
The scheduler to get timesteps from.
149-
num_inference_steps (`int`):
150-
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
151-
must be `None`.
152-
device (`str` or `torch.device`, *optional*):
153-
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
154-
timesteps (`List[int]`, *optional*):
155-
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
156-
`num_inference_steps` and `sigmas` must be `None`.
157-
sigmas (`List[float]`, *optional*):
158-
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
159-
`num_inference_steps` and `timesteps` must be `None`.
160-
161-
Returns:
162-
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
163-
second element is the number of inference steps.
164-
"""
165-
if timesteps is not None and sigmas is not None:
166-
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
167-
if timesteps is not None:
168-
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
169-
if not accepts_timesteps:
170-
raise ValueError(
171-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
172-
f" timestep schedules. Please check whether you are using the correct scheduler."
173-
)
174-
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
175-
timesteps = scheduler.timesteps
176-
num_inference_steps = len(timesteps)
177-
elif sigmas is not None:
178-
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
179-
if not accept_sigmas:
180-
raise ValueError(
181-
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
182-
f" sigmas schedules. Please check whether you are using the correct scheduler."
183-
)
184-
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
185-
timesteps = scheduler.timesteps
186-
num_inference_steps = len(timesteps)
187-
else:
188-
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
189-
timesteps = scheduler.timesteps
190-
return timesteps, num_inference_steps
191-
192-
193104
class QwenImageControlNetPipeline(DiffusionPipeline, QwenImagePipelineMixin, QwenImageLoraLoaderMixin):
194105
r"""
195106
The QwenImage pipeline for text-to-image generation.

0 commit comments

Comments
 (0)