Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 35 additions & 2 deletions library/flux_train_utils.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -169,11 +169,21 @@ def sample_image_inference(

if seed is not None:
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
if accelerator.device.type == "cuda":
torch.cuda.manual_seed(seed)
elif accelerator.device.type == "xpu":
torch.xpu.manual_seed(seed)
elif accelerator.device.type == "mps":
torch.mps.manual_seed(seed)
else:
# True random sample image generation
torch.seed()
torch.cuda.seed()
if accelerator.device.type == "cuda":
torch.cuda.seed()
elif accelerator.device.type == "xpu":
torch.xpu.seed()
elif accelerator.device.type == "mps":
torch.mps.seed()

if negative_prompt is None:
negative_prompt = ""
Expand Down Expand Up @@ -474,6 +484,29 @@ def get_noisy_model_input_and_timesteps(
bsz, _, h, w = latents.shape
assert bsz > 0, "Batch size not large enough"
num_timesteps = noise_scheduler.config.num_train_timesteps
min_timestep = 0 if args.min_timestep is None else args.min_timestep
max_timestep = num_timesteps if args.max_timestep is None else args.max_timestep
if min_timestep > max_timestep:
min_timestep, max_timestep = max_timestep, min_timestep

if min_timestep == max_timestep:
# Deterministic timesteps (used by validation) need fully fixed noise.
timestep_value = float(max_timestep)
timesteps = torch.full((bsz,), timestep_value, device=device, dtype=dtype)
sigma_value = timestep_value / num_timesteps
sigmas = torch.full((bsz,), sigma_value, device=device, dtype=dtype)
sigmas = sigmas.view(-1, 1, 1, 1)
if args.ip_noise_gamma:
xi = torch.randn_like(latents, device=latents.device, dtype=dtype)
if args.ip_noise_gamma_random_strength:
ip_noise_gamma = torch.rand(1, device=latents.device, dtype=dtype) * args.ip_noise_gamma
else:
ip_noise_gamma = args.ip_noise_gamma
noisy_model_input = (1.0 - sigmas) * latents + sigmas * (noise + ip_noise_gamma * xi)
else:
noisy_model_input = (1.0 - sigmas) * latents + sigmas * noise
return noisy_model_input.to(dtype), timesteps.to(dtype), sigmas

if args.timestep_sampling == "uniform" or args.timestep_sampling == "sigmoid":
# Simple random sigma-based noise sampling
if args.timestep_sampling == "sigmoid":
Expand Down