Skip to content

Commit b9d6ede

Browse files
committed
fix for lcm
1 parent 1bea2d8 commit b9d6ede

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

src/diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1272,6 +1272,11 @@ def intermediate_inputs(self) -> List[InputParam]:
12721272
type_hint=int,
12731273
description="Number of prompts, the final batch size of model inputs should be batch_size * num_images_per_prompt. Can be generated in input step.",
12741274
),
1275+
InputParam(
1276+
"dtype",
1277+
type_hint=torch.dtype,
1278+
description="The dtype of the model inputs. Can be generated in input step.",
1279+
),
12751280
]
12761281

12771282
@property
@@ -1332,9 +1337,7 @@ def __call__(self, components: StableDiffusionXLModularPipeline, state: Pipeline
13321337
# Optionally get Guidance Scale Embedding for LCM
13331338
block_state.timestep_cond = None
13341339

1335-
guidance_scale_tensor = (
1336-
torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size).to(device=device)
1337-
)
1340+
guidance_scale_tensor = torch.tensor(block_state.embedded_guidance_scale - 1).repeat(final_batch_size)
13381341
block_state.timestep_cond = self.get_guidance_scale_embedding(
13391342
guidance_scale_tensor, embedding_dim=components.unet.config.time_cond_proj_dim
13401343
).to(device=device, dtype=dtype)

0 commit comments

Comments
 (0)