|
3 | 3 | import argparse |
4 | 4 | import random |
5 | 5 | import re |
| 6 | +from torch import Tensor |
6 | 7 | from torch.types import Number |
7 | 8 | from typing import List, Optional, Union |
8 | 9 | from .utils import setup_logging |
9 | 10 |
|
# Optional dependency: only WaveletLoss needs PyWavelets, so an import
# failure is deferred until the class is actually constructed.
# Catch only ImportError (a bare `except:` would also swallow
# SystemExit/KeyboardInterrupt) and bind the name to None so later code
# can test availability instead of dying with a NameError.
try:
    import pywt
except ImportError:
    pywt = None
| 16 | + |
10 | 17 | setup_logging() |
11 | 18 | import logging |
12 | 19 |
|
@@ -98,9 +105,26 @@ def add_v_prediction_like_loss(loss: torch.Tensor, timesteps: torch.IntTensor, n |
98 | 105 | return loss |
99 | 106 |
|
100 | 107 |
|
101 | | -def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False): |
102 | | - snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps]) # batch_size |
103 | | - snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) # if timestep is 0, snr_t is inf, so limit it to 1000 |
| 108 | + |
| 109 | +def apply_debiased_estimation(loss: torch.Tensor, timesteps: torch.IntTensor, noise_scheduler: DDPMScheduler, v_prediction=False, image_size=None): |
| 110 | + # Check if we have SNR values available |
| 111 | + if not (hasattr(noise_scheduler, "all_snr") or hasattr(noise_scheduler, "get_snr_for_timestep")): |
| 112 | + return loss |
| 113 | + |
| 114 | + if hasattr(noise_scheduler, "get_snr_for_timestep") and not callable(noise_scheduler.get_snr_for_timestep): |
| 115 | + return loss |
| 116 | + |
| 117 | + # Get SNR values with image_size consideration |
| 118 | + if hasattr(noise_scheduler, "get_snr_for_timestep") and callable(noise_scheduler.get_snr_for_timestep): |
| 119 | + snr_t: torch.Tensor = noise_scheduler.get_snr_for_timestep(timesteps, image_size) |
| 120 | + else: |
| 121 | + timesteps_indices = train_util.timesteps_to_indices(timesteps, len(noise_scheduler.all_snr)) |
| 122 | + snr_t = torch.stack([noise_scheduler.all_snr[t] for t in timesteps_indices]) |
| 123 | + |
| 124 | + # Cap the SNR to avoid numerical issues |
| 125 | + snr_t = torch.minimum(snr_t, torch.ones_like(snr_t) * 1000) |
| 126 | + |
| 127 | + # Apply weighting based on prediction type |
104 | 128 | if v_prediction: |
105 | 129 | weight = 1 / (snr_t + 1) |
106 | 130 | else: |
@@ -135,6 +159,12 @@ def add_custom_train_arguments(parser: argparse.ArgumentParser, support_weighted |
135 | 159 | action="store_true", |
136 | 160 | help="debiased estimation loss / debiased estimation loss", |
137 | 161 | ) |
| 162 | + parser.add_argument("--wavelet_loss", action="store_true", help="Activate wavelet loss") |
| 163 | + parser.add_argument("--wavelet_loss_alpha", type=float, default=0.015, help="Wavelet loss alpha") |
| 164 | + parser.add_argument("--wavelet_loss_type", help="Wavelet loss type l1, l2, huber, smooth_l1. Default to --loss_type value.") |
| 165 | + parser.add_argument("--wavelet_loss_transform", default="swt", help="Wavelet transform type of DWT or SWT") |
| 166 | + parser.add_argument("--wavelet_loss_wavelet", default="sym7", help="Wavelet") |
| 167 | + parser.add_argument("--wavelet_loss_level", type=int, default=1, help="Wavelet loss level 1 (main) or 2 (details)") |
138 | 168 | if support_weighted_captions: |
139 | 169 | parser.add_argument( |
140 | 170 | "--weighted_captions", |
@@ -503,6 +533,222 @@ def apply_masked_loss(loss, batch) -> torch.FloatTensor: |
503 | 533 | return loss |
504 | 534 |
|
505 | 535 |
|
class WaveletLoss(torch.nn.Module):
    """Wavelet-domain loss between a prediction and a target tensor.

    Both tensors are decomposed into low/high-frequency sub-bands (LL, LH,
    HL, HH) with either a DWT or an SWT, and a weighted per-band loss is
    accumulated. Band weights follow "Training Generative Image
    Super-Resolution Models by Wavelet-Domain Losses Enables Better Control
    of Artifacts": lambda_LL = 0.1, lambda_LH = lambda_HL = 0.01,
    lambda_HH = 0.05.
    """

    def __init__(self, wavelet='db4', level=3, transform="dwt", loss_fn=torch.nn.functional.mse_loss, device=torch.device("cpu")):
        """
        db4 (Daubechies 4) and sym7 (Symlet 7) are wavelet families with different characteristics:

        db4 (Daubechies 4):
        - 8 coefficients in filter
        - Asymmetric shape
        - Good frequency localization
        - Widely used for general signal processing

        sym7 (Symlet 7):
        - 14 coefficients in filter
        - Nearly symmetric shape
        - Better balance between smoothness and detail preservation
        - Designed to overcome the asymmetry limitation of Daubechies wavelets

        The numbers (4 and 7) indicate the number of vanishing moments, which affects
        how well the wavelet can represent polynomial behavior in signals.

        ---

        DWT: Discrete Wavelet Transform - decomposes with downsampling, which
        reduces resolution by half at each level.
        SWT: Stationary Wavelet Transform - no downsampling, so every band
        keeps the input resolution. This makes SWT translation-invariant and
        better at preserving spatial detail for diffusion-model training.

        Args:
            wavelet: PyWavelets wavelet name, e.g. "db4" | "sym7"
            level: decomposition depth (SWT path supports 1 or 2)
            transform: "dwt" | "swt"
            loss_fn: elementwise loss applied per band (default MSE)
            device: device on which the filter buffers are created
        """
        super().__init__()
        self.level = level
        self.wavelet = wavelet
        self.transform = transform
        self.loss_fn = loss_fn

        # Level-1 band weights (see class docstring for the source paper).
        self.ll_weight = 0.1
        self.lh_weight = 0.01
        self.hl_weight = 0.01
        self.hh_weight = 0.05

        # Level-2 band weights, used only when level == 2.
        self.ll_weight2 = 0.1
        self.lh_weight2 = 0.01
        self.hl_weight2 = 0.01
        self.hh_weight2 = 0.05

        # `pywt` is imported at module top under try/except; fail here with a
        # clear, -O-proof error instead of the NameError/AssertionError the
        # bare `assert pywt.wavedec2 ...` produced. (Package name: PyWavelets.)
        if globals().get("pywt") is None:
            raise ImportError(
                "PyWavelets is required for WaveletLoss. Install it with `pip install PyWavelets`."
            )

        # Build the decomposition filters once and register them as buffers so
        # they move with the module across devices/dtypes.
        wav = pywt.Wavelet(wavelet)
        self.register_buffer('dec_lo', torch.tensor(wav.dec_lo, dtype=torch.float32, device=device))
        self.register_buffer('dec_hi', torch.tensor(wav.dec_hi, dtype=torch.float32, device=device))

    def dwt(self, x):
        """Single-level 2D DWT with stride-2 downsampling.

        Returns (ll, lh, hl, hh); each band has roughly half the input's
        spatial resolution. Sizes are consistent between pred and target, which
        is all the loss needs.
        """
        F = torch.nn.functional
        batch, channels, height, width = x.shape
        # Fold channels into the batch so a single (1-in, 1-out) filter can be
        # applied to every channel independently.
        x = x.view(batch * channels, 1, height, width)

        # Symmetric reflect padding before the strided separable convolutions.
        x_pad = F.pad(x, (self.dec_lo.size(0) // 2,) * 4, mode='reflect')

        # Separable transform: filter rows (vertical kernel, stride 2 in H)...
        lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1), stride=(2, 1))
        hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1), stride=(2, 1))

        # ...then columns (horizontal kernel, stride 2 in W).
        ll = F.conv2d(lo, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
        lh = F.conv2d(lo, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))
        hl = F.conv2d(hi, self.dec_lo.view(1, 1, 1, -1), stride=(1, 2))
        hh = F.conv2d(hi, self.dec_hi.view(1, 1, 1, -1), stride=(1, 2))

        # Restore the (batch, channels, H', W') layout.
        ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh

    def swt(self, x):
        """Single-level 2D SWT (no downsampling).

        Every returned band has exactly the input's spatial size. This is
        required by decompose_latent, which concatenates level-1 and level-2
        bands along the channel axis.
        """
        F = torch.nn.functional
        batch, channels, height, width = x.shape
        x = x.view(batch * channels, 1, height, width)

        # Asymmetric padding (left/top = L//2, right/bottom = (L-1)//2) sums
        # to L-1 per dimension, so the unstrided convolutions below preserve
        # the input resolution for even- and odd-length filters alike.
        # (Symmetric L//2 padding produced H+1 x W+1 bands for even-length
        # filters, which broke the level-2 channel concatenation.)
        filt_len = self.dec_lo.size(0)
        pad = (filt_len // 2, (filt_len - 1) // 2) * 2
        x_pad = F.pad(x, pad, mode='reflect')

        # Separable transform: rows first (vertical kernel)...
        row_lo = F.conv2d(x_pad, self.dec_lo.view(1, 1, -1, 1))
        row_hi = F.conv2d(x_pad, self.dec_hi.view(1, 1, -1, 1))

        # ...then columns (horizontal kernel).
        ll = F.conv2d(row_lo, self.dec_lo.view(1, 1, 1, -1))
        lh = F.conv2d(row_lo, self.dec_hi.view(1, 1, 1, -1))
        hl = F.conv2d(row_hi, self.dec_lo.view(1, 1, 1, -1))
        hh = F.conv2d(row_hi, self.dec_hi.view(1, 1, 1, -1))

        ll = ll.view(batch, channels, ll.shape[2], ll.shape[3])
        lh = lh.view(batch, channels, lh.shape[2], lh.shape[3])
        hl = hl.view(batch, channels, hl.shape[2], hl.shape[3])
        hh = hh.view(batch, channels, hh.shape[2], hh.shape[3])

        return ll, lh, hl, hh

    def decompose_latent(self, latent):
        """Apply the SWT to a latent and collect the bands in a dict.

        Returns keys 'll', 'lh', 'hl', 'hh' and 'combined_hf' (all
        high-frequency bands concatenated on the channel axis). When
        self.level == 2, the LL band is decomposed once more and the
        second-level bands ('ll2', 'lh2', 'hl2', 'hh2') are added, with
        'combined_hf' spanning both levels.
        """
        ll_band, lh_band, hl_band, hh_band = self.swt(latent)

        combined_hf = torch.cat((lh_band, hl_band, hh_band), dim=1)

        result = {
            'll': ll_band,
            'lh': lh_band,
            'hl': hl_band,
            'hh': hh_band,
            'combined_hf': combined_hf,
        }

        if self.level == 2:
            # Second-level decomposition of the LL band. SWT preserves spatial
            # size, so level-1 and level-2 bands can be concatenated channelwise.
            ll_band2, lh_band2, hl_band2, hh_band2 = self.swt(ll_band)

            combined_lh = torch.cat((lh_band, lh_band2), dim=1)
            combined_hl = torch.cat((hl_band, hl_band2), dim=1)
            combined_hh = torch.cat((hh_band, hh_band2), dim=1)
            combined_hf = torch.cat((combined_lh, combined_hl, combined_hh), dim=1)

            result.update({
                'll2': ll_band2,
                'lh2': lh_band2,
                'hl2': hl_band2,
                'hh2': hh_band2,
                'combined_hf': combined_hf,
            })

        return result

    def swt_forward(self, pred, target):
        """Weighted SWT band loss; returns (loss, pred_hf, target_hf)."""
        pred_bands = self.decompose_latent(pred)
        target_bands = self.decompose_latent(target)

        loss = 0

        # Level-1 weighted band losses.
        loss += self.ll_weight * self.loss_fn(pred_bands['ll'], target_bands['ll'])
        loss += self.lh_weight * self.loss_fn(pred_bands['lh'], target_bands['lh'])
        loss += self.hl_weight * self.loss_fn(pred_bands['hl'], target_bands['hl'])
        loss += self.hh_weight * self.loss_fn(pred_bands['hh'], target_bands['hh'])

        # Level-2 weighted band losses, when enabled.
        if self.level == 2:
            loss += self.ll_weight2 * self.loss_fn(pred_bands['ll2'], target_bands['ll2'])
            loss += self.lh_weight2 * self.loss_fn(pred_bands['lh2'], target_bands['lh2'])
            loss += self.hl_weight2 * self.loss_fn(pred_bands['hl2'], target_bands['hl2'])
            loss += self.hh_weight2 * self.loss_fn(pred_bands['hh2'], target_bands['hh2'])

        return loss, pred_bands['combined_hf'], target_bands['combined_hf']

    def dwt_forward(self, pred, target):
        """Multi-level DWT loss (unweighted bands); returns (loss, None, None)."""
        loss = 0

        for _ in range(self.level):
            p_ll, p_lh, p_hl, p_hh = self.dwt(pred)
            t_ll, t_lh, t_hl, t_hh = self.dwt(target)

            # Detail bands at this level.
            loss += self.loss_fn(p_lh, t_lh)
            loss += self.loss_fn(p_hl, t_hl)
            loss += self.loss_fn(p_hh, t_hh)

            # Recurse into the approximation band.
            pred, target = p_ll, t_ll

        # Final approximation loss at the coarsest level.
        loss += self.loss_fn(pred, target)

        return loss, None, None

    def forward(self, pred: Tensor, target: Tensor):
        """
        Calculate wavelet loss using the rectified flow pred and target.

        Args:
            pred: Rectified prediction from model
            target: Rectified target after noisy latent

        Returns:
            (loss, pred_hf, target_hf); the HF tensors are None for DWT.
        """
        if self.transform == 'dwt':
            return self.dwt_forward(pred, target)
        else:
            return self.swt_forward(pred, target)
| 750 | + |
| 751 | + |
506 | 752 | """ |
507 | 753 | ########################################## |
508 | 754 | # Perlin Noise |
|
0 commit comments