Skip to content

Commit 893e305

Browse files
committed
Add wavelet loss
1 parent 5a18a03 commit 893e305

File tree

3 files changed

+301
-6
lines changed

3 files changed

+301
-6
lines changed

flux_train_network.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -448,7 +448,7 @@ def call_dit(img, img_ids, t5_out, txt_ids, l_pooled, timesteps, guidance_vec, t
448448
)
449449
target[diff_output_pr_indices] = model_pred_prior.to(target.dtype)
450450

451-
return model_pred, target, timesteps, weighting
451+
return model_pred, noisy_model_input, target, sigmas, timesteps, weighting
452452

453453
def post_process_loss(self, loss, args, timesteps, noise_scheduler):
454454
return loss

library/custom_train_functions.py

Lines changed: 249 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,17 @@
33
import argparse
44
import random
55
import re
6+
from torch import Tensor
67
from torch.types import Number
78
from typing import List, Optional, Union
89
from .utils import setup_logging
910

11+
try:
12+
import pywt
13+
except:
14+
pass
15+
16+
1017
setup_logging()
1118
import logging
1219

@@ -98,9 +105,26 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n
98105
return loss
99106

100107

101-
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False):
102-
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size
103-
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000
108+
109+
def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None):
110+
# Check if we have SNR values available
111+
if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")):
112+
return loss
113+
114+
if hasattr(noise_scheduler, "get_snr_for_timestep") and not callable(noise_scheduler.get_snr_for_timestep):
115+
return loss
116+
117+
# Get SNR values with image_size consideration
118+
if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep):
119+
snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size)
120+
else:
121+
timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr))
122+
snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices])
123+
124+
# Cap the SNR to avoid numerical issues
125+
snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000)
126+
127+
# Apply weighting based on prediction type
104128
if v_prediction:
105129
weight = 1 / (snr_t + 1)
106130
else:
@@ -135,6 +159,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted
135159
action="store_true",
136160
help="debiased estimation loss / debiased estimation loss",
137161
)
162+
parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss")
163+
parser.add_argument("--wavelet_loss_alpha", type=float, default=0.015, help="Wavelet loss alpha")
164+
parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.")
165+
parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT")
166+
parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet")
167+
parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details)")
138168
if support_weighted_captions:
139169
parser.add_argument(
140170
"--weighted_captions",
@@ -503,6 +533,222 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor:
503533
return loss
504534

505535

536+
class WaveletLoss(torch.nn.Module):
    """Wavelet-domain loss between a prediction and a target tensor.

    Both inputs are decomposed into low-/high-frequency sub-bands with a
    separable 2-D wavelet transform and ``loss_fn`` is applied per sub-band
    (weighted per band on the SWT path).

    Wavelet families:

    db4 (Daubechies 4):
    - 8 coefficients in filter
    - Asymmetric shape
    - Good frequency localization
    - Widely used for general signal processing

    sym7 (Symlet 7):
    - 14 coefficients in filter
    - Nearly symmetric shape
    - Better balance between smoothness and detail preservation
    - Designed to overcome the asymmetry limitation of Daubechies wavelets

    The numbers (4 and 7) indicate the number of vanishing moments, which
    affects how well the wavelet can represent polynomial behavior in signals.

    Transforms:

    DWT: Discrete Wavelet Transform - decomposes a signal into wavelets at
    different scales with downsampling, which halves resolution per level.
    SWT: Stationary Wavelet Transform - like DWT but without downsampling,
    keeping the original resolution at every decomposition level. This makes
    SWT translation-invariant and better at preserving spatial detail, which
    matters for diffusion model training.
    """

    def __init__(self, wavelet='db4', level=3, transform="dwt", loss_fn=torch.nn.functional.mse_loss, device=torch.device("cpu")):
        """
        Args:
            wavelet: wavelet family name understood by PyWavelets ("db4" | "sym7" | ...).
            level: decomposition depth; the SWT path only distinguishes 1 vs 2.
            transform: "dwt" | "swt" (any value other than "dwt" selects SWT).
            loss_fn: elementwise loss applied per sub-band (default: MSE).
            device: device on which the decomposition filters are created.

        Raises:
            ImportError: if PyWavelets is not installed.
        """
        super().__init__()
        self.level = level
        self.wavelet = wavelet
        self.transform = transform

        self.loss_fn = loss_fn

        # Sub-band weights from "Training Generative Image Super-Resolution
        # Models by Wavelet-Domain Losses Enables Better Control of Artifacts":
        # lambda_LL = 0.1, lambda_LH = lambda_HL = 0.01, lambda_HH = 0.05
        self.ll_weight = 0.1
        self.lh_weight = 0.01
        self.hl_weight = 0.01
        self.hh_weight = 0.05

        # Second-level weights (used by the SWT path when level == 2).
        self.ll_weight2 = 0.1
        self.lh_weight2 = 0.01
        self.hl_weight2 = 0.01
        self.hh_weight2 = 0.05

        # `pywt` is imported at module scope inside try/except, so the name may
        # be undefined here. Turn that into an actionable ImportError instead of
        # the NameError the previous `assert pywt.wavedec2 is not None` raised
        # (which also advertised the wrong pip package name "PyWavelet").
        try:
            wav = pywt.Wavelet(wavelet)
        except NameError as err:
            raise ImportError(
                "PyWavelets is required for WaveletLoss. Install it with `pip install PyWavelets`."
            ) from err

        # Register the filter bank as buffers so .to()/.cuda() move them along.
        self.register_buffer('dec_lo', torch.tensor(wav.dec_lo, device=device))
        self.register_buffer('dec_hi', torch.tensor(wav.dec_hi, device=device))

    def dwt(self, x):
        """Single-level 2-D Discrete Wavelet Transform (with downsampling).

        Returns the (ll, lh, hl, hh) sub-bands, each at roughly half the input
        resolution.
        """
        F = torch.nn.functional

        batch, channels, height, width = x.shape
        # Fold channels into the batch so every plane is filtered independently
        # with a single single-channel kernel.
        x = x.view(batch * channels, 1, height, width)

        # Reflect-pad by half the filter length on every side.
        pad = self.dec_lo.size(0) // 2
        x_pad = F.pad(x, (pad,) * 4, mode='reflect')

        # Separable transform: filter rows (downsample vertically) ...
        lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1))
        hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1))

        # ... then columns (downsample horizontally) to get the four sub-bands.
        ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
        lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
        hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
        hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))

        # Restore the (batch, channels, h, w) layout.
        ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh

    def swt(self, x):
        """Single-level 2-D Stationary Wavelet Transform (no downsampling).

        NOTE(review): with even-length filters and pad = k//2 the output is
        (H+1, W+1) rather than (H, W); pred/target shapes still match so the
        loss is well-defined — confirm if exact SWT sizing ever matters.
        """
        F = torch.nn.functional

        batch, channels, height, width = x.shape
        # After this view the channel dim is 1, so no grouped/repeated kernels
        # are needed (the original repeat(x.size(1), ...) was a no-op).
        x = x.view(batch * channels, 1, height, width)

        pad = self.dec_lo.size(0) // 2

        # Filter rows.
        x_lo = F.conv2d(F.pad(x, (pad,) * 4, mode='reflect'), self.dec_lo.view(1, 1, -1, 1))
        x_hi = F.conv2d(F.pad(x, (pad,) * 4, mode='reflect'), self.dec_hi.view(1, 1, -1, 1))

        # Filter columns.
        ll = F.conv2d(x_lo, self.dec_lo.view(1, 1, 1, -1))
        lh = F.conv2d(x_lo, self.dec_hi.view(1, 1, 1, -1))
        hl = F.conv2d(x_hi, self.dec_lo.view(1, 1, 1, -1))
        hh = F.conv2d(x_hi, self.dec_hi.view(1, 1, 1, -1))

        # Restore the (batch, channels, h, w) layout.
        ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh

    def decompose_latent(self, latent):
        """Apply the SWT to ``latent`` and return its sub-bands as a dict.

        Keys: 'll', 'lh', 'hl', 'hh', 'combined_hf'; when level == 2 also
        'll2', 'lh2', 'hl2', 'hh2', with 'combined_hf' covering both levels.
        """
        ll_band, lh_band, hl_band, hh_band = self.swt(latent)

        combined_hf = torch.cat((lh_band, hl_band, hh_band), dim=1)

        result = {
            'll': ll_band,
            'lh': lh_band,
            'hl': hl_band,
            'hh': hh_band,
            'combined_hf': combined_hf
        }

        if self.level == 2:
            # Second-level decomposition of the LL band.
            ll_band2, lh_band2, hl_band2, hh_band2 = self.swt(ll_band)

            # High-frequency bands from both levels, grouped per orientation.
            combined_hf = torch.cat(
                (lh_band, lh_band2, hl_band, hl_band2, hh_band, hh_band2), dim=1
            )

            result.update({
                'll2': ll_band2,
                'lh2': lh_band2,
                'hl2': hl_band2,
                'hh2': hh_band2,
                'combined_hf': combined_hf
            })

        return result

    def swt_forward(self, pred, target):
        """SWT loss: per-band weighted ``loss_fn`` over one or two levels.

        Returns (loss, pred_combined_hf, target_combined_hf).
        """
        pred_bands = self.decompose_latent(pred)
        target_bands = self.decompose_latent(target)

        loss = 0

        # Weighted loss for level 1.
        loss += self.ll_weight * self.loss_fn(pred_bands['ll'], target_bands['ll'])
        loss += self.lh_weight * self.loss_fn(pred_bands['lh'], target_bands['lh'])
        loss += self.hl_weight * self.loss_fn(pred_bands['hl'], target_bands['hl'])
        loss += self.hh_weight * self.loss_fn(pred_bands['hh'], target_bands['hh'])

        # Weighted loss for level 2 if requested.
        if self.level == 2:
            loss += self.ll_weight2 * self.loss_fn(pred_bands['ll2'], target_bands['ll2'])
            loss += self.lh_weight2 * self.loss_fn(pred_bands['lh2'], target_bands['lh2'])
            loss += self.hl_weight2 * self.loss_fn(pred_bands['hl2'], target_bands['hl2'])
            loss += self.hh_weight2 * self.loss_fn(pred_bands['hh2'], target_bands['hh2'])

        return loss, pred_bands['combined_hf'], target_bands['combined_hf']

    def dwt_forward(self, pred, target):
        """DWT loss: unweighted detail losses per level plus a final
        approximation loss.

        Returns (loss, None, None) to mirror ``swt_forward``'s shape.
        """
        loss = 0

        for _ in range(self.level):
            p_ll, p_lh, p_hl, p_hh = self.dwt(pred)
            t_ll, t_hl_dummy, t_hl, t_hh = self.dwt(target)[0], None, None, None  # placeholder removed below

            # (Re-run target decomposition explicitly for clarity.)
            t_ll, t_lh, t_hl, t_hh = self.dwt(target)

            loss += self.loss_fn(p_lh, t_lh)
            loss += self.loss_fn(p_hl, t_hl)
            loss += self.loss_fn(p_hh, t_hh)

            # Recurse on the approximation coefficients.
            pred, target = p_ll, t_ll

        # Final approximation loss.
        loss += self.loss_fn(pred, target)

        return loss, None, None

    def forward(self, pred: Tensor, target: Tensor):
        """
        Calculate wavelet loss using the rectified flow pred and target

        Args:
            pred: Rectified prediction from model
            target: Rectified target after noisy latent
        """
        if self.transform == 'dwt':
            return self.dwt_forward(pred, target)
        else:
            return self.swt_forward(pred, target)
750+
751+
506752
"""
507753
##########################################
508754
# Perlin Noise

train_network.py

Lines changed: 51 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@
4343
add_v_prediction_like_loss,
4444
apply_debiased_estimation,
4545
apply_masked_loss,
46+
WaveletLoss
4647
)
4748
from library.utils import setup_logging, add_logging_arguments
4849

@@ -321,7 +322,7 @@ def get_noise_pred_and_target(
321322
network.set_multiplier(1.0) # may be overwritten by "network_multipliers" in the next step
322323
target[diff_output_pr_indices] = noise_pred_prior.to(target.dtype)
323324

324-
return noise_pred, target, timesteps, None
325+
return noise_pred, noisy_latents, target, sigmas, timesteps, None
325326

326327
def post_process_loss(self, loss, args, timesteps: torch.IntTensor, noise_scheduler) -> torch.FloatTensor:
327328
if args.min_snr_gamma:
@@ -446,7 +447,7 @@ def process_batch(
446447
text_encoder_conds[i] = encoded_text_encoder_conds[i]
447448

448449
# sample noise, call unet, get target
449-
noise_pred, target, timesteps, weighting = self.get_noise_pred_and_target(
450+
noise_pred, noisy_latents, target, sigmas, timesteps, weighting = self.get_noise_pred_and_target(
450451
args,
451452
accelerator,
452453
noise_scheduler,
@@ -462,6 +463,18 @@ def process_batch(
462463

463464
huber_c = train_util.get_huber_threshold_if_needed(args, timesteps, noise_scheduler)
464465
loss = train_util.conditional_loss(noise_pred.float(), target.float(), args.loss_type, "none", huber_c)
466+
467+
if args.wavelet_loss:
    # BUG FIX: this was gated on `args.wavelet_loss_alpha`, which defaults to
    # 0.015 (always truthy) — but `self.wavelet_loss` is only constructed when
    # `--wavelet_loss` is set, so the old condition raised AttributeError for
    # every run that did not pass the flag.

    # Flow-based clean estimate derived from the target.
    flow_based_clean = noisy_latents - sigmas.view(-1, 1, 1, 1) * target

    # Model-based denoised estimate.
    model_denoised = noisy_latents - sigmas.view(-1, 1, 1, 1) * noise_pred

    # combined_hf tensors are currently unused; keep only the scalar loss.
    wav_loss, _pred_hf, _target_hf = self.wavelet_loss(model_denoised.float(), flow_based_clean.float())

    # Blend the wavelet term into the main loss with the configured alpha.
    loss = loss + args.wavelet_loss_alpha * wav_loss
477+
465478
if weighting is not None:
466479
loss = loss * weighting
467480
if args.masked_loss or ("alpha_masks" in batch and batch["alpha_masks"] is not None):
@@ -1040,6 +1053,12 @@ def load_model_hook(models, input_dir):
10401053
"ss_validate_every_n_epochs": args.validate_every_n_epochs,
10411054
"ss_validate_every_n_steps": args.validate_every_n_steps,
10421055
"ss_resize_interpolation": args.resize_interpolation,
1056+
"ss_wavelet_loss": args.wavelet_loss,
1057+
"ss_wavelet_loss_alpha": args.wavelet_loss_alpha,
1058+
"ss_wavelet_loss_type": args.wavelet_loss_type,
1059+
"ss_wavelet_loss_transform": args.wavelet_loss_transform,
1060+
"ss_wavelet_loss_wavelet": args.wavelet_loss_wavelet,
1061+
"ss_wavelet_loss_level": args.wavelet_loss_level,
10431062
}
10441063

10451064
self.update_metadata(metadata, args) # architecture specific metadata
@@ -1260,6 +1279,36 @@ def load_model_hook(models, input_dir):
12601279
val_step_loss_recorder = train_util.LossRecorder()
12611280
val_epoch_loss_recorder = train_util.LossRecorder()
12621281

1282+
if args.wavelet_loss:
    def wavelet_base_loss_fn(args):
        """Select the elementwise loss used inside WaveletLoss.

        Falls back to --loss_type when --wavelet_loss_type is not given.

        Raises:
            ValueError: for an unrecognized loss type (instead of silently
                returning None, as the previous implementation did).
        """
        loss_type = args.wavelet_loss_type if args.wavelet_loss_type is not None else args.loss_type

        if loss_type == "huber":
            def huber(pred, target, reduction="mean"):
                if args.huber_c is None:
                    raise NotImplementedError("huber_c not implemented correctly")
                b_size = pred.shape[0]
                huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device)
                huber_c = huber_c.view(-1, 1, 1, 1)
                loss = 2 * huber_c * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
                return loss.mean()

            return huber

        if loss_type == "smooth_l1":
            def smooth_l1(pred, target, reduction="mean"):
                if args.huber_c is None:
                    raise NotImplementedError("huber_c not implemented correctly")
                b_size = pred.shape[0]
                huber_c = torch.full((b_size,), args.huber_c * args.huber_scale, device=pred.device)
                huber_c = huber_c.view(-1, 1, 1, 1)
                loss = 2 * (torch.sqrt((pred - target) ** 2 + huber_c**2) - huber_c)
                return loss.mean()

            # BUG FIX: the inner function was defined but never returned, so
            # selecting smooth_l1 handed WaveletLoss a loss_fn of None.
            return smooth_l1

        if loss_type == "l2":
            return torch.nn.functional.mse_loss
        if loss_type == "l1":
            return torch.nn.functional.l1_loss

        raise ValueError(f"Unsupported wavelet loss type: {loss_type}")

    # BUG FIX: pass the configured transform. It was previously dropped, so
    # --wavelet_loss_transform (default "swt") was ignored and the class
    # default "dwt" was always used.
    self.wavelet_loss = WaveletLoss(
        wavelet=args.wavelet_loss_wavelet,
        level=args.wavelet_loss_level,
        transform=args.wavelet_loss_transform,
        loss_fn=wavelet_base_loss_fn(args),
        device=accelerator.device,
    )
1311+
12631312
del train_dataset_group
12641313
if val_dataset_group is not None:
12651314
del val_dataset_group

0 commit comments

Comments
 (0)