Skip to content

Commit dc6a4d4

Browse files
committed
more
1 parent 8946974 commit dc6a4d4

File tree

4 files changed

+330
-424
lines changed

4 files changed

+330
-424
lines changed

src/diffusers/modular_pipelines/modular_pipeline_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,10 @@ class ComponentSpec:
9191
type_hint: Optional[Type] = None
9292
description: Optional[str] = None
9393
config: Optional[FrozenDict] = None
94-
# YiYi Notes: should we change it to pretrained_model_name_or_path for consistency? a bit long for a field name
94+
# YiYi TODO: currently required is only used to mark optional components that the block can run without, in the future:
95+
# 1. the spec for an optional component should has lower priority when combined in sequential/auto blocks
96+
# 2. should not need to define default_creation_method for optional components
97+
required: bool = True
9598
repo: Optional[Union[str, List[str]]] = field(default=None, metadata={"loading": True})
9699
subfolder: Optional[str] = field(default="", metadata={"loading": True})
97100
variant: Optional[str] = field(default=None, metadata={"loading": True})

src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -418,21 +418,21 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
418418
device = components._execution_device
419419

420420
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
421-
components.scheduler,
422-
block_state.num_inference_steps,
423-
block_state.device,
424-
block_state.timesteps,
425-
block_state.sigmas,
421+
scheduler=components.scheduler,
422+
num_inference_steps=block_state.num_inference_steps,
423+
device=device,
424+
timesteps=block_state.timesteps,
425+
sigmas=block_state.sigmas,
426426
)
427427

428428
def denoising_value_valid(dnv):
429429
return isinstance(dnv, float) and 0 < dnv < 1
430430

431431
block_state.timesteps, block_state.num_inference_steps = self.get_timesteps(
432-
components,
433-
block_state.num_inference_steps,
434-
block_state.strength,
435-
device,
432+
components=components,
433+
num_inference_steps=block_state.num_inference_steps,
434+
strength=block_state.strength,
435+
device=device,
436436
denoising_start=block_state.denoising_start
437437
if denoising_value_valid(block_state.denoising_start)
438438
else None,
@@ -498,14 +498,14 @@ def intermediate_outputs(self) -> List[OutputParam]:
498498
def __call__(self, components: StableDiffusionXLModularPipeline, state: PipelineState) -> PipelineState:
499499
block_state = self.get_block_state(state)
500500

501-
block_state.device = components._execution_device
501+
device = components._execution_device
502502

503503
block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
504-
components.scheduler,
505-
block_state.num_inference_steps,
506-
block_state.device,
507-
block_state.timesteps,
508-
block_state.sigmas,
504+
scheduler=components.scheduler,
505+
num_inference_steps=block_state.num_inference_steps,
506+
device=device,
507+
timesteps=block_state.timesteps,
508+
sigmas=block_state.sigmas,
509509
)
510510

511511
if (
@@ -581,7 +581,7 @@ def intermediate_inputs(self) -> List[str]:
581581
description="The latents representing the reference image for image-to-image/inpainting generation. Can be generated in vae_encode step.",
582582
),
583583
InputParam(
584-
"mask",
584+
"processed_mask_image",
585585
required=True,
586586
type_hint=torch.Tensor,
587587
description="The mask for the inpainting generation. Can be generated in vae_encode step.",
@@ -591,7 +591,7 @@ def intermediate_inputs(self) -> List[str]:
591591
type_hint=torch.Tensor,
592592
description="The masked image latents for the inpainting generation (only for inpainting-specific unet). Can be generated in vae_encode step.",
593593
),
594-
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
594+
InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs, can be generated in input step."),
595595
]
596596

597597
@property

0 commit comments

Comments
 (0)