
Commit 57a1bc6

first dynamic block!
committed · 1 parent f72763c · commit 57a1bc6


2 files changed: +16 -7 lines changed


src/diffusers/modular_pipelines/qwenimage/encoders.py

Lines changed: 14 additions & 5 deletions

@@ -348,9 +348,14 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         return components, state


-class QwenImageVaeEncoderStep(ModularPipelineBlocks):
+class QwenImageVaeEncoderDynamicStep(ModularPipelineBlocks):
     model_name = "qwenimage"

+    def __init__(self, input_name: str = "image", output_name: str = "image_latents"):
+        self.input_name = input_name
+        self.output_name = output_name
+        super().__init__()
+
     @property
     def description(self) -> str:
         return "Vae Encoder step that encode the input image into a latent representation"

@@ -370,15 +375,15 @@ def expected_components(self) -> List[ComponentSpec]:
     @property
     def inputs(self) -> List[InputParam]:
         return [
-            InputParam("image", required=True, description="The image to encode, should already be resized using resize step"),
+            InputParam(self.input_name, required=True),
             InputParam("generator"),
         ]

     @property
     def intermediate_outputs(self) -> List[OutputParam]:
         return [
             OutputParam(
-                "image_latents",
+                self.output_name,
                 type_hint=torch.Tensor,
                 description="The latents representing the reference image",
             )

@@ -391,16 +396,20 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         device = components._execution_device
         dtype = components.vae.dtype

-        image = components.image_processor.preprocess(block_state.image)
+        image = getattr(block_state, self.input_name)
+
+        image = components.image_processor.preprocess(image)
         image = image.unsqueeze(2)
         image = image.to(device=device, dtype=dtype)


         # Encode image into latents
-        block_state.image_latents = encode_vae_image(
+        image_latents = encode_vae_image(
             image=image, vae=components.vae, generator=block_state.generator, latent_channels=components.num_channels_latents
         )

+        setattr(block_state, self.output_name, image_latents)
+
         self.set_block_state(state, block_state)

         return components, state
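
This is the pattern the commit message calls a "dynamic block": the state keys the step reads and writes are chosen at construction time and resolved with getattr/setattr instead of being hard-coded. Below is a minimal, self-contained sketch of the same idea; it is not the diffusers API, and DynamicEncodeStep, the SimpleNamespace state, and the fake latents() stand-in are all invented for illustration.

```python
from types import SimpleNamespace


class DynamicEncodeStep:
    """Toy block: reads block_state.<input_name>, writes block_state.<output_name>."""

    def __init__(self, input_name: str = "image", output_name: str = "image_latents"):
        self.input_name = input_name
        self.output_name = output_name

    def __call__(self, block_state):
        value = getattr(block_state, self.input_name)   # dynamic input lookup
        result = f"latents({value})"                    # stand-in for encode_vae_image(...)
        setattr(block_state, self.output_name, result)  # dynamic output write
        return block_state


# The same class serves two different state keys without subclassing.
state = SimpleNamespace(image="pixels", control_image="control pixels")
DynamicEncodeStep()(state)
DynamicEncodeStep(input_name="control_image", output_name="control_image_latents")(state)
print(state.image_latents, "|", state.control_image_latents)
# latents(pixels) | latents(control pixels)
```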

src/diffusers/modular_pipelines/qwenimage/modular_blocks.py

Lines changed: 2 additions & 2 deletions

@@ -14,7 +14,7 @@

 from ...utils import logging

-from .encoders import QwenImageTextEncoderStep, QwenImageEditTextEncoderStep, QwenImageVaeEncoderStep
+from .encoders import QwenImageTextEncoderStep, QwenImageEditTextEncoderStep, QwenImageVaeEncoderDynamicStep
 from .decoders import QwenImageDecodeStep
 from .denoise import QwenImageDenoiseStep, QwenImageEditDenoiseStep
 from .before_denoise import QwenImageInputStep, QwenImagePrepareLatentsStep, QwenImageSetTimestepsStep, QwenImagePrepareAdditionalInputsStep, QwenImagePrepareImageLatentsStep, QwenImageEditPrepareAdditionalInputsStep, QwenImageImageResizeStep

@@ -41,7 +41,7 @@
     [
         ("image_resize", QwenImageImageResizeStep),
         ("text_encoder", QwenImageEditTextEncoderStep),
-        ("vae_encoder", QwenImageVaeEncoderStep),
+        ("vae_encoder", QwenImageVaeEncoderDynamicStep(input_name="image", output_name="image_latents")),
         ("input", QwenImageInputStep),
         ("prepare_image_latents", QwenImagePrepareImageLatentsStep),
         ("prepare_latents", QwenImagePrepareLatentsStep),
