Skip to content

Commit c56dc90

Browse files
authored
Merge pull request #1992 from rockerBOO/flux-ip-noise-gamma
Add IP noise gamma for Flux
2 parents ee0f754 + 89f0d27 commit c56dc90

File tree

2 files changed

+254
-32
lines changed

2 files changed

+254
-32
lines changed

library/flux_train_utils.py

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -366,8 +366,6 @@ def get_sigmas(noise_scheduler, timesteps, device, n_dim=4, dtype=torch.float32)
366366
step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
367367

368368
sigma = sigmas[step_indices].flatten()
369-
while len(sigma.shape) < n_dim:
370-
sigma = sigma.unsqueeze(-1)
371369
return sigma
372370

373371

@@ -410,42 +408,34 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
410408

411409

412410
def get_noisy_model_input_and_timesteps(
413-
args, noise_scheduler, latents, noise, device, dtype
411+
args, noise_scheduler, latents: torch.Tensor, noise: torch.Tensor, device, dtype
414412
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
415413
bsz, _, h, w = latents.shape
416-
sigmas = None
417-
414+
assert bsz > 0, "Batch size not large enough"
415+
num_timesteps = noise_scheduler.config.num_train_timesteps
418416
if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
419-
# Simple random t-based noise sampling
417+
# Simple random sigma-based noise sampling
420418
if args.timestep_sampling == "sigmoid":
421419
# https://github.com/XLabs-AI/x-flux/tree/main
422-
t = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
420+
sigmas = torch.sigmoid(args.sigmoid_scale * torch.randn((bsz,), device=device))
423421
else:
424-
t = torch.rand((bsz,), device=device)
422+
sigmas = torch.rand((bsz,), device=device)
425423

426-
timesteps = t * 1000.0
427-
t = t.view(-1, 1, 1, 1)
428-
noisy_model_input = (1 - t) * latents + t * noise
424+
timesteps = sigmas * num_timesteps
429425
elif args.timestep_sampling == "shift":
430426
shift = args.discrete_flow_shift
431-
logits_norm = torch.randn(bsz, device=device)
432-
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
433-
timesteps = logits_norm.sigmoid()
434-
timesteps = (timesteps * shift) / (1 + (shift - 1) * timesteps)
435-
436-
t = timesteps.view(-1, 1, 1, 1)
437-
timesteps = timesteps * 1000.0
438-
noisy_model_input = (1 - t) * latents + t * noise
427+
sigmas = torch.randn(bsz, device=device)
428+
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
429+
sigmas = sigmas.sigmoid()
430+
sigmas = (sigmas * shift) / (1 + (shift - 1) * sigmas)
431+
timesteps = sigmas * num_timesteps
439432
elif args.timestep_sampling == "flux_shift":
440-
logits_norm = torch.randn(bsz, device=device)
441-
logits_norm = logits_norm * args.sigmoid_scale # larger scale for more uniform sampling
442-
timesteps = logits_norm.sigmoid()
443-
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2))
444-
timesteps = time_shift(mu, 1.0, timesteps)
445-
446-
t = timesteps.view(-1, 1, 1, 1)
447-
timesteps = timesteps * 1000.0
448-
noisy_model_input = (1 - t) * latents + t * noise
433+
sigmas = torch.randn(bsz, device=device)
434+
sigmas = sigmas * args.sigmoid_scale # larger scale for more uniform sampling
435+
sigmas = sigmas.sigmoid()
436+
mu = get_lin_function(y1=0.5, y2=1.15)((h // 2) * (w // 2)) # we are pre-packed so must adjust for packed size
437+
sigmas = time_shift(mu, 1.0, sigmas)
438+
timesteps = sigmas * num_timesteps
449439
else:
450440
# Sample a random timestep for each image
451441
# for weighting schemes where we sample timesteps non-uniformly
@@ -456,12 +446,24 @@ def get_noisy_model_input_and_timesteps(
456446
logit_std=args.logit_std,
457447
mode_scale=args.mode_scale,
458448
)
459-
indices = (u * noise_scheduler.config.num_train_timesteps).long()
449+
indices = (u * num_timesteps).long()
460450
timesteps = noise_scheduler.timesteps[indices].to(device=device)
461-
462-
# Add noise according to flow matching.
463451
sigmas = get_sigmas(noise_scheduler, timesteps, device, n_dim=latents.ndim, dtype=dtype)
464-
noisy_model_input = sigmas * noise + (1.0 - sigmas) * latents
452+
453+
# Broadcast sigmas to latent shape
454+
sigmas = sigmas.view(-1, 1, 1, 1)
455+
456+
# Add noise to the latents according to the noise magnitude at each timestep
457+
# (this is the forward diffusion process)
458+
if args.ip_noise_gamma:
459+
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
460+
if args.ip_noise_gamma_random_strength:
461+
ip_noise_gamma = (torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma)
462+
else:
463+
ip_noise_gamma = args.ip_noise_gamma
464+
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
465+
else:
466+
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
465467

466468
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas
467469

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,220 @@
1+
import pytest
2+
import torch
3+
from unittest.mock import MagicMock, patch
4+
from library.flux_train_utils import (
5+
get_noisy_model_input_and_timesteps,
6+
)
7+
8+
# Mock classes and functions
9+
class MockNoiseScheduler:
    """Minimal stand-in for a diffusers-style noise scheduler.

    Exposes only what the function under test reads:
    ``config.num_train_timesteps`` and an ascending integer ``timesteps``
    tensor of that length.
    """

    def __init__(self, num_train_timesteps=1000):
        config = MagicMock()
        config.num_train_timesteps = num_train_timesteps
        self.config = config
        self.timesteps = torch.arange(num_train_timesteps, dtype=torch.long)
14+
15+
16+
# Create fixtures for commonly used objects
17+
@pytest.fixture
def args():
    """Default training-args namespace shared by every test.

    Individual tests override single fields (e.g. ``timestep_sampling``)
    as needed; the values here select the "uniform" path with IP noise off.
    """
    mock_args = MagicMock()
    defaults = {
        "timestep_sampling": "uniform",
        "weighting_scheme": "uniform",
        "logit_mean": 0.0,
        "logit_std": 1.0,
        "mode_scale": 1.0,
        "sigmoid_scale": 1.0,
        "discrete_flow_shift": 3.1582,
        "ip_noise_gamma": None,
        "ip_noise_gamma_random_strength": False,
    }
    for field, value in defaults.items():
        setattr(mock_args, field, value)
    return mock_args
30+
31+
32+
@pytest.fixture
def noise_scheduler():
    """Scheduler mock configured with the default 1000 training timesteps."""
    scheduler = MockNoiseScheduler(num_train_timesteps=1000)
    return scheduler
35+
36+
37+
@pytest.fixture
def latents():
    """Random latent batch: (batch=2, channels=4, 8x8 spatial)."""
    shape = (2, 4, 8, 8)
    return torch.randn(*shape)
40+
41+
42+
@pytest.fixture
def noise():
    """Random noise batch matching the ``latents`` fixture's shape."""
    shape = (2, 4, 8, 8)
    return torch.randn(*shape)
45+
46+
47+
@pytest.fixture
def device():
    """Always run on CPU so the suite is deterministic and hardware-independent."""
    return "cpu"
51+
52+
53+
# Mock the required functions
54+
# Autouse fixture wrapping every test in pass-through patches of torch's
# samplers. Each ``side_effect`` is bound to the *real* function before the
# patch is applied, so behavior is unchanged; the MagicMock wrappers merely
# make the calls observable (call counts/args) if a test wants to assert on
# them. NOTE(review): no test currently inspects these mocks — presumably
# kept as scaffolding for future assertions; confirm before removing.
@pytest.fixture(autouse=True)
def mock_functions():
    with (
        patch("torch.sigmoid", side_effect=torch.sigmoid),
        patch("torch.rand", side_effect=torch.rand),
        patch("torch.randn", side_effect=torch.randn),
    ):
        yield
62+
63+
64+
# Test different timestep sampling methods
65+
def test_uniform_sampling(args, noise_scheduler, latents, noise, device):
    """Uniform timestep sampling yields correctly shaped and typed outputs."""
    args.timestep_sampling = "uniform"
    dtype = torch.float32

    noisy, ts, sig = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)

    batch = latents.shape[0]
    assert noisy.shape == latents.shape
    assert ts.shape == (batch,)
    assert sig.shape == (batch, 1, 1, 1)
    assert noisy.dtype == dtype
    assert ts.dtype == dtype
76+
77+
78+
def test_sigmoid_sampling(args, noise_scheduler, latents, noise, device):
    """Sigmoid timestep sampling yields correctly shaped outputs."""
    args.timestep_sampling = "sigmoid"
    args.sigmoid_scale = 1.0

    noisy, ts, sig = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, torch.float32)

    batch = latents.shape[0]
    assert noisy.shape == latents.shape
    assert ts.shape == (batch,)
    assert sig.shape == (batch, 1, 1, 1)
88+
89+
90+
def test_shift_sampling(args, noise_scheduler, latents, noise, device):
    """Discrete-flow-shift sampling yields correctly shaped outputs."""
    args.timestep_sampling = "shift"
    args.sigmoid_scale = 1.0
    args.discrete_flow_shift = 3.1582

    noisy, ts, sig = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, torch.float32)

    batch = latents.shape[0]
    assert noisy.shape == latents.shape
    assert ts.shape == (batch,)
    assert sig.shape == (batch, 1, 1, 1)
101+
102+
103+
def test_flux_shift_sampling(args, noise_scheduler, latents, noise, device):
    """Resolution-dependent flux shift sampling yields correctly shaped outputs."""
    args.timestep_sampling = "flux_shift"
    args.sigmoid_scale = 1.0

    noisy, ts, sig = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, torch.float32)

    batch = latents.shape[0]
    assert noisy.shape == latents.shape
    assert ts.shape == (batch,)
    assert sig.shape == (batch, 1, 1, 1)
113+
114+
115+
def test_weighting_scheme(args, noise_scheduler, latents, noise, device):
    """Any unrecognized sampling mode falls through to the weighting-scheme path."""
    density = torch.tensor([0.3, 0.7], device=device)
    fixed_sigmas = torch.tensor([[0.3], [0.7]], device=device).view(-1, 1, 1, 1)
    with (
        patch("library.flux_train_utils.compute_density_for_timestep_sampling", return_value=density),
        patch("library.flux_train_utils.get_sigmas", return_value=fixed_sigmas),
    ):
        args.timestep_sampling = "other"  # not a known mode -> weighting-scheme branch
        args.weighting_scheme = "uniform"
        args.logit_mean = 0.0
        args.logit_std = 1.0
        args.mode_scale = 1.0
        dtype = torch.float32

        noisy, ts, sig = get_noisy_model_input_and_timesteps(
            args, noise_scheduler, latents, noise, device, dtype
        )

        batch = latents.shape[0]
        assert noisy.shape == latents.shape
        assert ts.shape == (batch,)
        assert sig.shape == (batch, 1, 1, 1)
136+
137+
138+
# Test IP noise options
139+
def test_with_ip_noise(args, noise_scheduler, latents, noise, device):
    """Fixed-strength input-perturbation noise keeps output shapes intact."""
    args.ip_noise_gamma = 0.5
    args.ip_noise_gamma_random_strength = False

    noisy, ts, sig = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, torch.float32)

    batch = latents.shape[0]
    assert noisy.shape == latents.shape
    assert ts.shape == (batch,)
    assert sig.shape == (batch, 1, 1, 1)
149+
150+
151+
def test_with_random_ip_noise(args, noise_scheduler, latents, noise, device):
    """Randomized IP-noise strength keeps output shapes intact."""
    args.ip_noise_gamma = 0.1
    args.ip_noise_gamma_random_strength = True

    noisy, ts, sig = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, torch.float32)

    batch = latents.shape[0]
    assert noisy.shape == latents.shape
    assert ts.shape == (batch,)
    assert sig.shape == (batch, 1, 1, 1)
161+
162+
163+
# Test different data types
164+
def test_float16_dtype(args, noise_scheduler, latents, noise, device):
    """Outputs are cast to the requested half-precision dtype."""
    half = torch.float16

    noisy, ts, _ = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, half)

    assert noisy.dtype == half
    assert ts.dtype == half
171+
172+
173+
# Test different batch sizes
174+
def test_different_batch_size(args, noise_scheduler, device):
    """Output shapes track a non-default batch size (5)."""
    batch = 5
    latents = torch.randn(batch, 4, 8, 8)
    noise = torch.randn(batch, 4, 8, 8)

    noisy, ts, sig = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, torch.float32)

    assert noisy.shape == latents.shape
    assert ts.shape == (batch,)
    assert sig.shape == (batch, 1, 1, 1)
184+
185+
186+
# Test different image sizes
187+
def test_different_image_size(args, noise_scheduler, device):
    """Output shapes track a larger (16x16) spatial size."""
    latents = torch.randn(2, 4, 16, 16)
    noise = torch.randn(2, 4, 16, 16)

    noisy, ts, sig = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, torch.float32)

    assert noisy.shape == latents.shape
    assert ts.shape == (2,)
    assert sig.shape == (2, 1, 1, 1)
197+
198+
199+
# Test edge cases
200+
def test_zero_batch_size(args, noise_scheduler, device):
    """A zero-size batch must trip the function's batch-size assertion.

    Fixture construction is kept *outside* the ``pytest.raises`` block:
    in the original, any unrelated error raised while building the tensors
    would also satisfy ``raises(AssertionError)``-adjacent expectations and
    could mask a regression. Only the call under test may raise here.
    """
    latents = torch.randn(0, 4, 8, 8)
    noise = torch.randn(0, 4, 8, 8)
    dtype = torch.float32

    with pytest.raises(AssertionError):  # bsz == 0 violates the precondition
        get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, dtype)
207+
208+
209+
def test_different_timestep_count(args, device):
    """A scheduler with fewer train timesteps bounds the sampled timesteps."""
    noise_scheduler = MockNoiseScheduler(num_train_timesteps=500)
    latents = torch.randn(2, 4, 8, 8)
    noise = torch.randn(2, 4, 8, 8)

    noisy, ts, _ = get_noisy_model_input_and_timesteps(args, noise_scheduler, latents, noise, device, torch.float32)

    assert noisy.shape == latents.shape
    assert ts.shape == (2,)
    # Timesteps are scaled by num_train_timesteps, so all stay below 500.
    assert torch.all(ts < 500)

0 commit comments

Comments
 (0)