Skip to content

Commit 51d689e

Browse files
committed
Bug: Fix wrong params, docs, and handles noise=None
1 parent 745199a commit 51d689e

File tree

3 files changed

+22
-6
lines changed

3 files changed

+22
-6
lines changed

src/diffusers/schedulers/scheduling_flow_match_euler_discrete.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ def set_shift(self, shift: float):
171171
def scale_noise(
172172
self,
173173
sample: torch.FloatTensor,
174-
timestep: Union[float, torch.FloatTensor],
174+
timestep: torch.FloatTensor,
175175
noise: Optional[torch.FloatTensor] = None,
176176
) -> torch.FloatTensor:
177177
"""
@@ -180,8 +180,10 @@ def scale_noise(
180180
Args:
181181
sample (`torch.FloatTensor`):
182182
The input sample.
183-
timestep (`int`, *optional*):
183+
timestep (`torch.FloatTensor`):
184184
The current timestep in the diffusion chain.
185+
noise (`torch.FloatTensor`, *optional*):
186+
The noise tensor.
185187
186188
Returns:
187189
`torch.FloatTensor`:
@@ -212,6 +214,9 @@ def scale_noise(
212214
while len(sigma.shape) < len(sample.shape):
213215
sigma = sigma.unsqueeze(-1)
214216

217+
if noise is None:
218+
noise = torch.randn_like(sample)
219+
215220
sample = sigma * noise + (1.0 - sigma) * sample
216221

217222
return sample

src/diffusers/schedulers/scheduling_flow_match_heun_discrete.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def set_begin_index(self, begin_index: int = 0):
110110
def scale_noise(
111111
self,
112112
sample: torch.FloatTensor,
113-
timestep: Union[float, torch.FloatTensor],
113+
timestep: torch.FloatTensor,
114114
noise: Optional[torch.FloatTensor] = None,
115115
) -> torch.FloatTensor:
116116
"""
@@ -119,8 +119,10 @@ def scale_noise(
119119
Args:
120120
sample (`torch.FloatTensor`):
121121
The input sample.
122-
timestep (`int`, *optional*):
122+
timestep (`torch.FloatTensor`):
123123
The current timestep in the diffusion chain.
124+
noise (`torch.FloatTensor`, *optional*):
125+
The noise tensor.
124126
125127
Returns:
126128
`torch.FloatTensor`:
@@ -130,6 +132,10 @@ def scale_noise(
130132
self._init_step_index(timestep)
131133

132134
sigma = self.sigmas[self.step_index]
135+
136+
if noise is None:
137+
noise = torch.randn_like(sample)
138+
133139
sample = sigma * noise + (1.0 - sigma) * sample
134140

135141
return sample

src/diffusers/schedulers/scheduling_flow_match_lcm.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ def set_scale_factors(self, scale_factors: list, upscale_mode):
192192
def scale_noise(
193193
self,
194194
sample: torch.FloatTensor,
195-
timestep: Union[float, torch.FloatTensor],
195+
timestep: torch.FloatTensor,
196196
noise: Optional[torch.FloatTensor] = None,
197197
) -> torch.FloatTensor:
198198
"""
@@ -201,8 +201,10 @@ def scale_noise(
201201
Args:
202202
sample (`torch.FloatTensor`):
203203
The input sample.
204-
timestep (`int`, *optional*):
204+
timestep (`torch.FloatTensor`):
205205
The current timestep in the diffusion chain.
206+
noise (`torch.FloatTensor`, *optional*):
207+
The noise tensor.
206208
207209
Returns:
208210
`torch.FloatTensor`:
@@ -233,6 +235,9 @@ def scale_noise(
233235
while len(sigma.shape) < len(sample.shape):
234236
sigma = sigma.unsqueeze(-1)
235237

238+
if noise is None:
239+
noise = torch.randn_like(sample)
240+
236241
sample = sigma * noise + (1.0 - sigma) * sample
237242

238243
return sample

0 commit comments

Comments
 (0)