@@ -261,51 +261,6 @@ def step(
261261
262262 return SCMSchedulerOutput (prev_sample = prev_sample , pred_original_sample = pred_x0 )
263263
264- # ... (previous code from the SCMScheduler class) ...
265-
266- def add_noise (
267- self ,
268- original_samples : torch .Tensor ,
269- noise : torch .Tensor ,
270- timesteps : torch .Tensor ,
271- ) -> torch .Tensor :
272- """
273- Adds noise to the original samples according to the SCM forward process.
274-
275- Args:
276- original_samples (`torch.Tensor`):
277- The original clean samples (x_0).
278- noise (`torch.Tensor`):
279- Random noise (epsilon) drawn from a standard normal distribution N(0,I),
280- with the same shape as `original_samples`.
281- timesteps (`torch.Tensor`):
282- The timesteps (s) at which to noise the samples. These should be the
283- angular timesteps used by this scheduler (e.g., values from `self.timesteps`).
284- The shape should be broadcastable to `original_samples` (e.g., a 1D tensor
285- of timesteps for a batch of samples, or a single timestep value).
286-
287- Returns:
288- `torch.Tensor`: The noisy samples (x_s).
289- """
290- if not hasattr (self .config , "sigma_data" ):
291- raise ValueError ("SCMScheduler config must have `sigma_data` attribute." )
292-
293- if timesteps .ndim == 1 :
294- # Reshape timesteps to be broadcastable: (batch_size,) -> (batch_size, 1, 1, 1)
295- # Assuming original_samples is (batch, channels, height, width)
296- dims_to_add = original_samples .ndim - timesteps .ndim
297- timesteps = timesteps .reshape (timesteps .shape + (1 ,) * dims_to_add )
298-
299- # The forward process: x_s = cos(s) * x_0 + sin(s) * sigma_data * epsilon
300- # Ensure timesteps, original_samples, and noise are on the same device
301- timesteps = timesteps .to (original_samples .device )
302-
303- cos_t = torch .cos (timesteps )
304- sin_t = torch .sin (timesteps )
305-
306- noisy_samples = cos_t * original_samples + sin_t * noise * self .config .sigma_data
307-
308- return noisy_samples
309264
310265
311266 def __len__ (self ):
0 commit comments