Skip to content

Commit 5452431

Browse files
committed
fix image encoding
1 parent a70f29d commit 5452431

File tree

2 files changed

+48
-1
lines changed

2 files changed

+48
-1
lines changed

src/diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -633,7 +633,7 @@ def prepare_latents(self,
633633
)
634634

635635
noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
636-
latents = self.scheduler.scale_noise(image_latents, timestep, noise)
636+
latents = self.scheduler.add_noise(image_latents, timestep, noise)
637637
return latents
638638

639639
@property

src/diffusers/schedulers/scheduling_scm.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -261,5 +261,52 @@ def step(
261261

262262
return SCMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
263263

264+
# Forward (noising) process: used by image-to-image pipelines to noise encoded image latents.
265+
266+
def add_noise(
    self,
    original_samples: torch.Tensor,
    noise: torch.Tensor,
    timesteps: torch.Tensor,
) -> torch.Tensor:
    """
    Adds noise to the original samples according to the SCM forward process:

        x_s = cos(s) * x_0 + sin(s) * sigma_data * epsilon

    Note the argument order: `noise` comes *before* `timesteps`. Passing them the other way round does not
    raise, it silently produces wrong samples.

    Args:
        original_samples (`torch.Tensor`):
            The original clean samples (x_0).
        noise (`torch.Tensor`):
            Random noise (epsilon) drawn from a standard normal distribution N(0,I), with the same shape as
            `original_samples`. It is moved to the device of `original_samples` if necessary.
        timesteps (`torch.Tensor`):
            The timesteps (s) at which to noise the samples. These should be the angular timesteps used by
            this scheduler (e.g., values from `self.timesteps`). A 1D tensor is treated as one timestep per
            batch element and is reshaped to broadcast over the remaining dimensions; a 0-dim tensor or a
            plain Python number is applied to every sample.

    Returns:
        `torch.Tensor`: The noisy samples (x_s). When `original_samples` is a floating point tensor, the
        result has the same dtype as `original_samples`.

    Raises:
        `ValueError`: If the scheduler config has no `sigma_data` attribute.
    """
    if not hasattr(self.config, "sigma_data"):
        raise ValueError("SCMScheduler config must have `sigma_data` attribute.")

    # Accept plain Python numbers as well as tensors, and keep every operand on the samples' device.
    timesteps = torch.as_tensor(timesteps, device=original_samples.device)
    noise = noise.to(original_samples.device)

    if timesteps.ndim == 1:
        # Per-sample timesteps: (batch_size,) -> (batch_size, 1, ..., 1) so they broadcast over the
        # non-batch dimensions of `original_samples`.
        dims_to_add = original_samples.ndim - timesteps.ndim
        timesteps = timesteps.reshape(timesteps.shape + (1,) * dims_to_add)

    cos_t = torch.cos(timesteps)
    sin_t = torch.sin(timesteps)

    noisy_samples = cos_t * original_samples + sin_t * noise * self.config.sigma_data

    # float32 timesteps would otherwise promote half-precision latents to float32; hand back the dtype the
    # caller gave us so downstream modules see consistent inputs.
    if original_samples.is_floating_point():
        noisy_samples = noisy_samples.to(original_samples.dtype)

    return noisy_samples
309+
310+
264311
def __len__(self):
    """Return the number of training timesteps stored in the scheduler config."""
    num_train_timesteps = self.config.num_train_timesteps
    return num_train_timesteps

0 commit comments

Comments
 (0)