@@ -261,5 +261,52 @@ def step(
261261
262262 return SCMSchedulerOutput (prev_sample = prev_sample , pred_original_sample = pred_x0 )
263263
264+ # ... (previous code from the SCMScheduler class) ...
265+
266+ def add_noise (
267+ self ,
268+ original_samples : torch .Tensor ,
269+ noise : torch .Tensor ,
270+ timesteps : torch .Tensor ,
271+ ) -> torch .Tensor :
272+ """
273+ Adds noise to the original samples according to the SCM forward process.
274+
275+ Args:
276+ original_samples (`torch.Tensor`):
277+ The original clean samples (x_0).
278+ noise (`torch.Tensor`):
279+ Random noise (epsilon) drawn from a standard normal distribution N(0,I),
280+ with the same shape as `original_samples`.
281+ timesteps (`torch.Tensor`):
282+ The timesteps (s) at which to noise the samples. These should be the
283+ angular timesteps used by this scheduler (e.g., values from `self.timesteps`).
284+ The shape should be broadcastable to `original_samples` (e.g., a 1D tensor
285+ of timesteps for a batch of samples, or a single timestep value).
286+
287+ Returns:
288+ `torch.Tensor`: The noisy samples (x_s).
289+ """
290+ if not hasattr (self .config , "sigma_data" ):
291+ raise ValueError ("SCMScheduler config must have `sigma_data` attribute." )
292+
293+ if timesteps .ndim == 1 :
294+ # Reshape timesteps to be broadcastable: (batch_size,) -> (batch_size, 1, 1, 1)
295+ # Assuming original_samples is (batch, channels, height, width)
296+ dims_to_add = original_samples .ndim - timesteps .ndim
297+ timesteps = timesteps .reshape (timesteps .shape + (1 ,) * dims_to_add )
298+
299+ # The forward process: x_s = cos(s) * x_0 + sin(s) * sigma_data * epsilon
300+ # Ensure timesteps, original_samples, and noise are on the same device
301+ timesteps = timesteps .to (original_samples .device )
302+
303+ cos_t = torch .cos (timesteps )
304+ sin_t = torch .sin (timesteps )
305+
306+ noisy_samples = cos_t * original_samples + sin_t * noise * self .config .sigma_data
307+
308+ return noisy_samples
309+
310+
264311 def __len__ (self ):
265312 return self .config .num_train_timesteps
0 commit comments