Skip to content

Commit 09b805f

Browse files
committed
fix
1 parent 77c097d commit 09b805f

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

src/diffusers/modular_pipelines/wan/before_denoise.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,7 @@ def retrieve_timesteps(
9292
else:
9393
scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
9494
timesteps = scheduler.timesteps
95+
return timesteps, num_inference_steps
9596

9697

9798
class WanInputStep(PipelineBlock):
@@ -304,7 +305,7 @@ def check_inputs(components, block_state):
304305
f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
305306
)
306307

307-
@classmethod
308+
@staticmethod
308309
# Copied from diffusers.pipelines.wan.pipeline_wan.WanPipeline.prepare_latents with self->comp
309310
def prepare_latents(
310311
comp,
@@ -349,8 +350,9 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
349350

350351
self.check_inputs(components, block_state)
351352

352-
block_state.height = block_state.height or components.default_sample_size * components.vae_scale_factor
353-
block_state.width = block_state.width or components.default_sample_size * components.vae_scale_factor
353+
block_state.height = block_state.height or components.default_height
354+
block_state.width = block_state.width or components.default_width
355+
block_state.num_frames = block_state.num_frames or components.default_num_frames
354356
block_state.num_channels_latents = components.num_channels_latents
355357
block_state.latents = self.prepare_latents(
356358
components,

src/diffusers/modular_pipelines/wan/modular_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,11 @@ class WanModularPipeline(
3939

4040
@property
4141
def default_height(self):
42-
return self.default_sample_height * self.vae_scale_factor
42+
return self.default_sample_height * self.vae_scale_factor_spatial
4343

4444
@property
4545
def default_width(self):
46-
return self.default_sample_width * self.vae_scale_factor
46+
return self.default_sample_width * self.vae_scale_factor_spatial
4747

4848
@property
4949
def default_num_frames(self):

0 commit comments

Comments
 (0)