@@ -126,10 +126,31 @@ def scale_noise(
126
126
`torch.FloatTensor`:
127
127
A scaled input sample.
128
128
"""
129
- if self .step_index is None :
130
- self ._init_step_index (timestep )
129
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
130
+ sigmas = self .sigmas .to (device = sample .device , dtype = sample .dtype )
131
+
132
+ if sample .device .type == "mps" and torch .is_floating_point (timestep ):
133
+ # mps does not support float64
134
+ schedule_timesteps = self .timesteps .to (sample .device , dtype = torch .float32 )
135
+ timestep = timestep .to (sample .device , dtype = torch .float32 )
136
+ else :
137
+ schedule_timesteps = self .timesteps .to (sample .device )
138
+ timestep = timestep .to (sample .device )
139
+
140
+ # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
141
+ if self .begin_index is None :
142
+ step_indices = [self .index_for_timestep (t , schedule_timesteps ) for t in timestep ]
143
+ elif self .step_index is not None :
144
+ # add_noise is called after first denoising step (for inpainting)
145
+ step_indices = [self .step_index ] * timestep .shape [0 ]
146
+ else :
147
+ # add noise is called before first denoising step to create initial latent(img2img)
148
+ step_indices = [self .begin_index ] * timestep .shape [0 ]
149
+
150
+ sigma = sigmas [step_indices ].flatten ()
151
+ while len (sigma .shape ) < len (sample .shape ):
152
+ sigma = sigma .unsqueeze (- 1 )
131
153
132
- sigma = self .sigmas [self .step_index ]
133
154
sample = sigma * noise + (1.0 - sigma ) * sample
134
155
135
156
return sample
0 commit comments