Skip to content

Commit 3b01d72

Browse files
authored
Modify FlowMatch Scale Noise (#8678)
* initial fix
* apply suggestion
* delete step_index line
1 parent e2a4a46 commit 3b01d72

File tree

2 files changed

+25
-4
lines changed

2 files changed

+25
-4
lines changed

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -852,7 +852,7 @@ def __call__(
852852
# 4. Prepare timesteps
853853
timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
854854
timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
855-
latent_timestep = timesteps[:1].repeat(batch_size * num_inference_steps)
855+
latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
856856

857857
# 5. Prepare latent variables
858858
if latents is None:

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 24 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -126,10 +126,31 @@ def scale_noise(
126126
`torch.FloatTensor`:
127127
A scaled input sample.
128128
"""
129-
if self.step_index is None:
130-
self._init_step_index(timestep)
129+
# Make sure sigmas have the same device and dtype as sample, and timesteps are on the same device as sample
130+
sigmas = self.sigmas.to(device=sample.device, dtype=sample.dtype)
131+
132+
if sample.device.type == "mps" and torch.is_floating_point(timestep):
133+
# mps does not support float64
134+
schedule_timesteps = self.timesteps.to(sample.device, dtype=torch.float32)
135+
timestep = timestep.to(sample.device, dtype=torch.float32)
136+
else:
137+
schedule_timesteps = self.timesteps.to(sample.device)
138+
timestep = timestep.to(sample.device)
139+
140+
# self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
141+
if self.begin_index is None:
142+
step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timestep]
143+
elif self.step_index is not None:
144+
# scale_noise is called after the first denoising step (for inpainting)
145+
step_indices = [self.step_index] * timestep.shape[0]
146+
else:
147+
# scale_noise is called before the first denoising step to create the initial latent (img2img)
148+
step_indices = [self.begin_index] * timestep.shape[0]
149+
150+
sigma = sigmas[step_indices].flatten()
151+
while len(sigma.shape) < len(sample.shape):
152+
sigma = sigma.unsqueeze(-1)
131153

132-
sigma = self.sigmas[self.step_index]
133154
sample = sigma * noise + (1.0 - sigma) * sample
134155

135156
return sample

0 commit comments

Comments
 (0)