Skip to content

Commit c80f572

Browse files
committed
revert unnecessary changes to scheduler
1 parent b247c5f commit c80f572

File tree

1 file changed

+0
-45
lines changed

1 file changed

+0
-45
lines changed

src/diffusers/schedulers/scheduling_scm.py

Lines changed: 0 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -261,51 +261,6 @@ def step(
261261

262262
return SCMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_x0)
263263

264-
# ... (previous code from the SCMScheduler class) ...
265-
266-
def add_noise(
267-
self,
268-
original_samples: torch.Tensor,
269-
noise: torch.Tensor,
270-
timesteps: torch.Tensor,
271-
) -> torch.Tensor:
272-
"""
273-
Adds noise to the original samples according to the SCM forward process.
274-
275-
Args:
276-
original_samples (`torch.Tensor`):
277-
The original clean samples (x_0).
278-
noise (`torch.Tensor`):
279-
Random noise (epsilon) drawn from a standard normal distribution N(0,I),
280-
with the same shape as `original_samples`.
281-
timesteps (`torch.Tensor`):
282-
The timesteps (s) at which to noise the samples. These should be the
283-
angular timesteps used by this scheduler (e.g., values from `self.timesteps`).
284-
The shape should be broadcastable to `original_samples` (e.g., a 1D tensor
285-
of timesteps for a batch of samples, or a single timestep value).
286-
287-
Returns:
288-
`torch.Tensor`: The noisy samples (x_s).
289-
"""
290-
if not hasattr(self.config, "sigma_data"):
291-
raise ValueError("SCMScheduler config must have `sigma_data` attribute.")
292-
293-
if timesteps.ndim == 1:
294-
# Reshape timesteps to be broadcastable: (batch_size,) -> (batch_size, 1, 1, 1)
295-
# Assuming original_samples is (batch, channels, height, width)
296-
dims_to_add = original_samples.ndim - timesteps.ndim
297-
timesteps = timesteps.reshape(timesteps.shape + (1,) * dims_to_add)
298-
299-
# The forward process: x_s = cos(s) * x_0 + sin(s) * sigma_data * epsilon
300-
# Ensure timesteps, original_samples, and noise are on the same device
301-
timesteps = timesteps.to(original_samples.device)
302-
303-
cos_t = torch.cos(timesteps)
304-
sin_t = torch.sin(timesteps)
305-
306-
noisy_samples = cos_t * original_samples + sin_t * noise * self.config.sigma_data
307-
308-
return noisy_samples
309264

310265

311266
def __len__(self):

0 commit comments

Comments
 (0)