Skip to content

Commit 6bf38c8

Browse files
authored
Apply suggestions from code review
1 parent 675ae14 commit 6bf38c8

File tree

1 file changed

+3
-14
lines changed

1 file changed

+3
-14
lines changed

src/diffusers/modular_pipelines/qwenimage/before_denoise.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,6 @@ def inputs(self) -> List[InputParam]:
241241
type_hint=torch.Tensor,
242242
description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
243243
),
244-
InputParam(
245-
name="batch_size",
246-
required=True,
247-
type_hint=int,
248-
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in expand textinput step.",
249-
),
250-
InputParam(name="num_images_per_prompt", required=True),
251244
]
252245

253246
@property
@@ -262,30 +255,26 @@ def intermediate_outputs(self) -> List[OutputParam]:
262255

263256
@staticmethod
264257
def check_inputs(image_latents, latents, batch_size):
265-
if image_latents.shape[0] != batch_size:
258+
if image_latents.shape[0] != latents.shape[0]:
266259
raise ValueError(
267-
f"`image_latents` must have have batch size {batch_size}, but got {image_latents.shape[0]}"
260+
f"`image_latents` must have have same batch size as `latents`, but got {image_latents.shape[0]} and {latents.shape[0]}"
268261
)
269262

270263
if image_latents.ndim != 3:
271264
raise ValueError(f"`image_latents` must have 3 dimensions (patchified), but got {image_latents.ndim}")
272265

273-
if latents.shape[0] != batch_size:
274-
raise ValueError(f"`latents` must have have batch size {batch_size}, but got {latents.shape[0]}")
275266

276267
@torch.no_grad()
277268
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
278269
block_state = self.get_block_state(state)
279-
final_batch_size = block_state.batch_size * block_state.num_images_per_prompt
280270

281271
self.check_inputs(
282272
image_latents=block_state.image_latents,
283273
latents=block_state.latents,
284-
batch_size=final_batch_size,
285274
)
286275

287276
# prepare latent timestep
288-
latent_timestep = block_state.timesteps[:1].repeat(final_batch_size)
277+
latent_timestep = block_state.timesteps[:1].repeat(block_state.latents.shape[0])
289278

290279
# make copy of initial_noise
291280
block_state.initial_noise = block_state.latents

0 commit comments

Comments
 (0)