Skip to content

Commit 3bd289f

Browse files
committed
initial
1 parent 03c3f69 commit 3bd289f

File tree

4 files changed

+483
-0
lines changed

4 files changed

+483
-0
lines changed
Lines changed: 289 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,289 @@
1+
2+
3+
4+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.calculate_shift
5+
def calculate_shift(
6+
image_seq_len,
7+
base_seq_len: int = 256,
8+
max_seq_len: int = 4096,
9+
base_shift: float = 0.5,
10+
max_shift: float = 1.15,
11+
):
12+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
13+
b = base_shift - m * base_seq_len
14+
mu = image_seq_len * m + b
15+
return mu
16+
17+
18+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
19+
def retrieve_timesteps(
20+
scheduler,
21+
num_inference_steps: Optional[int] = None,
22+
device: Optional[Union[str, torch.device]] = None,
23+
timesteps: Optional[List[int]] = None,
24+
sigmas: Optional[List[float]] = None,
25+
**kwargs,
26+
):
27+
r"""
28+
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
29+
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
30+
31+
Args:
32+
scheduler (`SchedulerMixin`):
33+
The scheduler to get timesteps from.
34+
num_inference_steps (`int`):
35+
The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
36+
must be `None`.
37+
device (`str` or `torch.device`, *optional*):
38+
The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
39+
timesteps (`List[int]`, *optional*):
40+
Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
41+
`num_inference_steps` and `sigmas` must be `None`.
42+
sigmas (`List[float]`, *optional*):
43+
Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
44+
`num_inference_steps` and `timesteps` must be `None`.
45+
46+
Returns:
47+
`Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
48+
second element is the number of inference steps.
49+
"""
50+
if timesteps is not None and sigmas is not None:
51+
raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
52+
if timesteps is not None:
53+
accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
54+
if not accepts_timesteps:
55+
raise ValueError(
56+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
57+
f" timestep schedules. Please check whether you are using the correct scheduler."
58+
)
59+
scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
60+
timesteps = scheduler.timesteps
61+
num_inference_steps = len(timesteps)
62+
elif sigmas is not None:
63+
accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
64+
if not accept_sigmas:
65+
raise ValueError(
66+
f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
67+
f" sigmas schedules. Please check whether you are using the correct scheduler."
68+
)
69+
scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
70+
timesteps = scheduler.timesteps
71+
num_inference_steps = len(timesteps)
72+
else:
73+
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
74+
timesteps = scheduler.timesteps
75+
return timesteps, num_inference_steps
76+
77+
78+
79+
def pack_latents(latents, batch_size, num_channels_latents, height, width):
80+
latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
81+
latents = latents.permute(0, 2, 4, 1, 3, 5)
82+
latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
83+
84+
return latents
85+
86+
def unpack_latents(latents, height, width, vae_scale_factor):
87+
batch_size, num_patches, channels = latents.shape
88+
89+
# VAE applies 8x compression on images but we must also account for packing which requires
90+
# latent height and width to be divisible by 2.
91+
height = 2 * (int(height) // (vae_scale_factor * 2))
92+
width = 2 * (int(width) // (vae_scale_factor * 2))
93+
94+
latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
95+
latents = latents.permute(0, 3, 1, 4, 2, 5)
96+
97+
latents = latents.reshape(batch_size, channels // (2 * 2), 1, height, width)
98+
99+
return latents
100+
101+
class QwenImagePrepareLatentsStep(PipelineBlock):
102+
103+
model_name = "qwenimage"
104+
105+
@property
106+
def description(self) -> str:
107+
return "Prepare latents step that prepares the latents for the text-to-image generation process"
108+
109+
@property
110+
def inputs(self) -> List[InputParam]:
111+
return [
112+
InputParam(name="height"),
113+
InputParam(name="width"),
114+
InputParam(name="latents"),
115+
InputParam(name="num_images_per_prompt", default=1),
116+
]
117+
118+
@property
119+
def intermediate_inputs(self) -> List[InputParam]:
120+
return [
121+
InputParam(
122+
name="batch_size",
123+
required=True,
124+
type_hint=int,
125+
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
126+
),
127+
InputParam(name="generator"),
128+
InputParam(name="dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
129+
]
130+
131+
@property
132+
def intermediate_outputs(self) -> List[OutputParam]:
133+
return [
134+
OutputParam(name="latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"),
135+
]
136+
137+
138+
def check_inputs(self, height, width, components):
139+
140+
if height is not None and height % (components.vae_scale_factor * 2) != 0:
141+
raise ValueError(f"Height must be divisible by {components.vae_scale_factor * 2} but is {height}")
142+
143+
if width is not None and width % (components.vae_scale_factor * 2) != 0:
144+
raise ValueError(f"Width must be divisible by {components.vae_scale_factor * 2} but is {width}")
145+
146+
# Copied from diffusers.pipelines.qwenimage.pipeline_qwenimage.QwenImagePipeline.prepare_latents with self->components
147+
def prepare_latents(
148+
components,
149+
batch_size,
150+
num_channels_latents,
151+
height,
152+
width,
153+
dtype,
154+
device,
155+
generator,
156+
):
157+
# VAE applies 8x compression on images but we must also account for packing which requires
158+
# latent height and width to be divisible by 2.
159+
height = 2 * (int(height) // (components.vae_scale_factor * 2))
160+
width = 2 * (int(width) // (components.vae_scale_factor * 2))
161+
162+
shape = (batch_size, 1, num_channels_latents, height, width)
163+
164+
if isinstance(generator, list) and len(generator) != batch_size:
165+
raise ValueError(
166+
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
167+
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
168+
)
169+
170+
latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
171+
latents = pack_latents(latents, batch_size, num_channels_latents, height, width)
172+
173+
return latents
174+
175+
176+
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
177+
178+
block_state = self.get_block_state(state)
179+
180+
device = components._execution_device
181+
dtype = block_state.dtype
182+
183+
height = block_state.height or components.default_height
184+
width = block_state.width or components.default_width
185+
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
186+
187+
latents = self.prepare_latents(
188+
components=components,
189+
batch_size=final_batch_size,
190+
num_channels_latents=components.num_channels_latents,
191+
height=height,
192+
width=width,
193+
dtype=dtype,
194+
device=device,
195+
generator=block_state.generator)
196+
197+
self.set_block_state(state, block_state)
198+
199+
return components, state
200+
201+
202+
203+
class QwenImageSetTimestepsStep(PipelineBlock):
204+
205+
model_name = "qwenimage"
206+
207+
@property
208+
def description(self) -> str:
209+
return "Step that sets the the scheduler's timesteps for inference"
210+
211+
@property
212+
def expected_components(self) -> List[ComponentSpec]:
213+
return [
214+
ComponentSpec(name="scheduler", FlowMatchEulerDiscreteScheduler),
215+
]
216+
217+
@property
218+
def inputs(self) -> List[InputParam]:
219+
return [
220+
InputParam(name="num_inference_steps", default=50),
221+
InputParam(name="sigmas"),
222+
]
223+
224+
@property
225+
def intermediate_inputs(self) -> List[InputParam]:
226+
return [
227+
InputParam(name="latents", required=True, type_hint=torch.Tensor, description="The latents to use for the denoising process"),
228+
]
229+
230+
@property
231+
def intermediate_outputs(self) -> List[OutputParam]:
232+
return [
233+
OutputParam(name="timesteps", type_hint=torch.Tensor, description="The timesteps to use for the denoising process"),
234+
OutputParam(name="num_inference_steps", type_hint=int, description="The number of inference steps to use for the denoising process"),
235+
]
236+
237+
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
238+
block_state = self.get_block_state(state)
239+
240+
device = components._execution_device
241+
242+
sigmas = np.linspace(1.0, 1 / block_state.num_inference_steps, block_state.num_inference_steps) if block_state.sigmas is None else block_state.sigmas
243+
244+
mu = calculate_shift(
245+
image_seq_len=block_state.latents.shape[1],
246+
base_seq_len= components.scheduler.config.get("base_image_seq_len", 256),
247+
max_seq_len= components.scheduler.config.get("max_image_seq_len", 4096),
248+
base_shift= components.scheduler.config.get("base_shift", 0.5),
249+
max_shift= components.scheduler.config.get("max_shift", 1.15),
250+
)
251+
timesteps, num_inference_steps = retrieve_timesteps(
252+
scheduler=components.scheduler,
253+
num_inference_steps=block_state.num_inference_steps,
254+
device,
255+
sigmas=sigmas,
256+
mu=mu,
257+
)
258+
259+
self.set_block_state(state, block_state)
260+
261+
return components, state
262+
263+
264+
class QwenImagePrepareAdditionalConditioningStep(PipelineBlock):
265+
266+
model_name = "qwenimage"
267+
268+
@property
269+
def description(self) -> str:
270+
return "Step that prepares the additional conditioning for the text-to-image generation process"
271+
272+
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
273+
274+
block_state = self.get_block_state(state)
275+
276+
height = block_state.height or components.default_height
277+
width = block_state.width or components.default_width
278+
279+
block_state.img_shapes = [(1, height // components.vae_scale_factor // 2, width // components.vae_scale_factor // 2)] * block_state.final_batch_size
280+
image_seq_len = block_state.latents.shape[1]
281+
txt_seq_lens = block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
282+
negative_txt_seq_lens = (
283+
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() if block_state.negative_prompt_embeds_mask is not None else None
284+
)
285+
286+
287+
self.set_block_state(state, block_state)
288+
289+
return components, state

0 commit comments

Comments
 (0)