diff --git a/MANIFEST.in b/MANIFEST.in
index 2f29c8b18..14ac7a048 100644
--- a/MANIFEST.in
+++ b/MANIFEST.in
@@ -1,5 +1,6 @@
include basicsr/ops/dcn/src/*.cu basicsr/ops/dcn/src/*.cpp
include basicsr/ops/fused_act/src/*.cu basicsr/ops/fused_act/src/*.cpp
include basicsr/ops/upfirdn2d/src/*.cu basicsr/ops/upfirdn2d/src/*.cpp
+include basicsr/metrics/niqe_pris_params.npz
include VERSION
include requirements.txt
diff --git a/README.md b/README.md
index d49f2f056..0e48e36fd 100644
--- a/README.md
+++ b/README.md
@@ -65,21 +65,21 @@ Other recommended projects:
We provide simple pipelines to train, test, and run inference on models for a quick start.
These pipelines/commands do not cover every case; more details are given in the following sections.
-| GAN | | | | | |
-| :--- | :---: | :---: | :--- | :---: | :---: |
-| StyleGAN2 | [Train](docs/HOWTOs.md#How-to-train-StyleGAN2) | [Inference](docs/HOWTOs.md#How-to-inference-StyleGAN2) | | | |
-| **Face Restoration** | | | | | |
-| DFDNet | - | [Inference](docs/HOWTOs.md#How-to-inference-DFDNet) | | | |
-| **Super Resolution** | | | | | |
-| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*|
-| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*|
-| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr)|
-| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
-| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
-| **Deblurring** | | | | | |
-| DeblurGANv2 | - | *TODO* | | | |
-| **Denoise** | | | | | |
-| RIDNet | - | *TODO* | CBDNet | - | *TODO*|
+| GAN | | | | | |
+| :------------------- | :--------------------------------------------: | :----------------------------------------------------: | :------- | :--------------------------------------------: | :----------------------------------------------------: |
+| StyleGAN2 | [Train](docs/HOWTOs.md#How-to-train-StyleGAN2) | [Inference](docs/HOWTOs.md#How-to-inference-StyleGAN2) | | | |
+| **Face Restoration** | | | | | |
+| DFDNet | - | [Inference](docs/HOWTOs.md#How-to-inference-DFDNet) | | | |
+| **Super Resolution** | | | | | |
+| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO* |
+| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO* |
+| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr) |
+| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
+| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
+| **Deblurring** | | | | | |
+| DeblurGANv2 | - | *TODO* | | | |
+| **Denoise** | | | | | |
+| RIDNet | - | *TODO* | CBDNet | - | *TODO* |
## :wrench: Dependencies and Installation
@@ -114,7 +114,7 @@ Please see [project boards](https://github.com/xinntao/BasicSR/projects).
Please see [DesignConvention.md](docs/DesignConvention.md) for the designs and conventions of the BasicSR codebase.
The figure below shows the overall framework. More descriptions for each component:
-**[Datasets.md](docs/Datasets.md)** | **[Models.md](docs/Models.md)** | **[Config.md](Config.md)** | **[Logging.md](docs/Logging.md)**
+**[Datasets.md](docs/Datasets.md)** | **[Models.md](docs/Models.md)** | **[Config.md](docs/Config.md)** | **[Logging.md](docs/Logging.md)**

@@ -144,7 +144,12 @@ The following is a BibTeX reference. The BibTeX entry requires the `url` LaTeX p
If you have any questions, please email `xintao.wang@outlook.com`.
+
+
+- **QQ Group**: Scan the QR code on the left, or search for the QQ group number 320960100. Answer to join the group: 互帮互助共同进步
+- **WeChat Group**: Since the WeChat group already has more than 200 members, it is invitation-only. To join, first add Liangbin's personal WeChat (QR code on the right) and he will invite you to the group when he is free.
+
-
+
diff --git a/README_CN.md b/README_CN.md
index 7b3424f74..e34173721 100644
--- a/README_CN.md
+++ b/README_CN.md
@@ -62,21 +62,21 @@ BasicSR (**Basic** **S**uper **R**estoration) 是一个基于 PyTorch 的开源
我们提供了简单的流程来快速上手 训练/测试/推理 模型. 这些命令并不能涵盖所有用法, 更多的细节参见下面的部分.
-| GAN | | | | | |
-| :--- | :---: | :---: | :--- | :---: | :---: |
-| StyleGAN2 | [训练](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [测试](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | | | |
-| **Face Restoration** | | | | | |
-| DFDNet | - | [测试](docs/HOWTOs_CN.md#如何测试-DFDNet) | | | |
-| **Super Resolution** | | | | | |
-| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO*|
-| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO*|
-| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr)|
-| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
-| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
-| **Deblurring** | | | | | |
-| DeblurGANv2 | - | *TODO* | | | |
-| **Denoise** | | | | | |
-| RIDNet | - | *TODO* | CBDNet | - | *TODO*|
+| GAN | | | | | |
+| :------------------- | :------------------------------------------: | :------------------------------------------: | :------- | :--------------------------------------------: | :----------------------------------------------------: |
+| StyleGAN2 | [训练](docs/HOWTOs_CN.md#如何训练-StyleGAN2) | [测试](docs/HOWTOs_CN.md#如何测试-StyleGAN2) | | | |
+| **Face Restoration** | | | | | |
+| DFDNet | - | [测试](docs/HOWTOs_CN.md#如何测试-DFDNet) | | | |
+| **Super Resolution** | | | | | |
+| ESRGAN | *TODO* | *TODO* | SRGAN | *TODO* | *TODO* |
+| EDSR | *TODO* | *TODO* | SRResNet | *TODO* | *TODO* |
+| RCAN | *TODO* | *TODO* | SwinIR | [Train](docs/HOWTOs.md#how-to-train-swinir-sr) | [Inference](docs/HOWTOs.md#how-to-inference-swinir-sr) |
+| EDVR | *TODO* | *TODO* | DUF | - | *TODO* |
+| BasicVSR | *TODO* | *TODO* | TOF | - | *TODO* |
+| **Deblurring** | | | | | |
+| DeblurGANv2 | - | *TODO* | | | |
+| **Denoise** | | | | | |
+| RIDNet | - | *TODO* | CBDNet | - | *TODO* |
## :wrench: 依赖和安装
@@ -112,7 +112,7 @@ For detailed instructions refer to [INSTALL.md](INSTALL.md).
参见 [DesignConvention_CN.md](docs/DesignConvention_CN.md).
下图概括了整体的框架. 每个模块更多的描述参见:
-**[Datasets_CN.md](docs/Datasets_CN.md)** | **[Models_CN.md](docs/Models_CN.md)** | **[Config_CN.md](Config_CN.md)** | **[Logging_CN.md](docs/Logging_CN.md)**
+**[Datasets_CN.md](docs/Datasets_CN.md)** | **[Models_CN.md](docs/Models_CN.md)** | **[Config_CN.md](docs/Config_CN.md)** | **[Logging_CN.md](docs/Logging_CN.md)**

@@ -142,7 +142,12 @@ For detailed instructions refer to [INSTALL.md](INSTALL.md).
若有任何问题, 请电邮 `xintao.wang@outlook.com`.
+
+
+- **QQ群**: 扫描左边二维码 或者 搜索QQ群号: 320960100 入群答案:互帮互助共同进步
+- **微信群**: 因为微信群超过200人,需要邀请才可以进群;要进微信群的小伙伴可以先添加 Liangbin 的个人微信 (右边二维码),他会在空闲的时候拉大家入群~
+
-
+
diff --git a/VERSION b/VERSION
index c42fd3d8b..a349a55be 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.3.4.4
+1.3.4.7
diff --git a/basicsr/archs/discriminator_arch.py b/basicsr/archs/discriminator_arch.py
index bc6603e0f..2e33bd3b2 100644
--- a/basicsr/archs/discriminator_arch.py
+++ b/basicsr/archs/discriminator_arch.py
@@ -1,7 +1,7 @@
from torch import nn as nn
-
from basicsr.utils.registry import ARCH_REGISTRY
-
+from torch.nn import functional as F
+from torch.nn.utils import spectral_norm
@ARCH_REGISTRY.register()
class VGGStyleDiscriminator128(nn.Module):
@@ -147,3 +147,59 @@ def forward(self, x):
feat = self.lrelu(self.linear1(feat))
out = self.linear2(feat)
return out
+
+
+@ARCH_REGISTRY.register()
+class UNetDiscriminatorSN(nn.Module):
+ """Defines a U-Net discriminator with spectral normalization (SN)"""
+
+ def __init__(self, num_in_ch, num_feat=64, skip_connection=True):
+ super(UNetDiscriminatorSN, self).__init__()
+ self.skip_connection = skip_connection
+ norm = spectral_norm
+
+ self.conv0 = nn.Conv2d(num_in_ch, num_feat, kernel_size=3, stride=1, padding=1)
+
+ self.conv1 = norm(nn.Conv2d(num_feat, num_feat * 2, 4, 2, 1, bias=False))
+ self.conv2 = norm(nn.Conv2d(num_feat * 2, num_feat * 4, 4, 2, 1, bias=False))
+ self.conv3 = norm(nn.Conv2d(num_feat * 4, num_feat * 8, 4, 2, 1, bias=False))
+ # upsample
+ self.conv4 = norm(nn.Conv2d(num_feat * 8, num_feat * 4, 3, 1, 1, bias=False))
+ self.conv5 = norm(nn.Conv2d(num_feat * 4, num_feat * 2, 3, 1, 1, bias=False))
+ self.conv6 = norm(nn.Conv2d(num_feat * 2, num_feat, 3, 1, 1, bias=False))
+
+ # extra
+ self.conv7 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+ self.conv8 = norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1, bias=False))
+
+ self.conv9 = nn.Conv2d(num_feat, 1, 3, 1, 1)
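+        # note: conv1-conv3 downsample by 2 each and the three bilinear upsamples restore the resolution,
+        # so for inputs whose sides are divisible by 8, conv9 outputs a 1-channel realness map at the input size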
+
+ def forward(self, x):
+ x0 = F.leaky_relu(self.conv0(x), negative_slope=0.2, inplace=True)
+ x1 = F.leaky_relu(self.conv1(x0), negative_slope=0.2, inplace=True)
+ x2 = F.leaky_relu(self.conv2(x1), negative_slope=0.2, inplace=True)
+ x3 = F.leaky_relu(self.conv3(x2), negative_slope=0.2, inplace=True)
+
+ # upsample
+ x3 = F.interpolate(x3, scale_factor=2, mode='bilinear', align_corners=False)
+ x4 = F.leaky_relu(self.conv4(x3), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x4 = x4 + x2
+ x4 = F.interpolate(x4, scale_factor=2, mode='bilinear', align_corners=False)
+ x5 = F.leaky_relu(self.conv5(x4), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x5 = x5 + x1
+ x5 = F.interpolate(x5, scale_factor=2, mode='bilinear', align_corners=False)
+ x6 = F.leaky_relu(self.conv6(x5), negative_slope=0.2, inplace=True)
+
+ if self.skip_connection:
+ x6 = x6 + x0
+
+ # extra
+ out = F.leaky_relu(self.conv7(x6), negative_slope=0.2, inplace=True)
+ out = F.leaky_relu(self.conv8(out), negative_slope=0.2, inplace=True)
+ out = self.conv9(out)
+
+ return out
\ No newline at end of file
diff --git a/basicsr/archs/ecbsr_arch.py b/basicsr/archs/ecbsr_arch.py
index a05c0a027..9ecb1fdb9 100644
--- a/basicsr/archs/ecbsr_arch.py
+++ b/basicsr/archs/ecbsr_arch.py
@@ -227,6 +227,8 @@ class ECBSR(nn.Module):
def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_type, scale):
super(ECBSR, self).__init__()
+ self.num_in_ch = num_in_ch
+ self.scale = scale
backbone = []
backbone += [ECB(num_in_ch, num_channel, depth_multiplier=2.0, act_type=act_type, with_idt=with_idt)]
@@ -240,6 +242,10 @@ def __init__(self, num_in_ch, num_out_ch, num_block, num_channel, with_idt, act_
self.upsampler = nn.PixelShuffle(scale)
def forward(self, x):
- y = self.backbone(x) + x # will repeat the input in the channel dimension (repeat scale * scale times)
+        if self.num_in_ch > 1:
+            # repeat the input in the channel dimension (scale * scale times) so that the residual
+            # matches the backbone output that is fed to PixelShuffle
+            shortcut = torch.repeat_interleave(x, self.scale * self.scale, dim=1)
+        else:
+            shortcut = x  # a single-channel input is broadcast over the channel dimension by the addition
+        y = self.backbone(x) + shortcut
y = self.upsampler(y)
return y
diff --git a/basicsr/archs/focalir_arch.py b/basicsr/archs/focalir_arch.py
new file mode 100644
index 000000000..c3892dd5b
--- /dev/null
+++ b/basicsr/archs/focalir_arch.py
@@ -0,0 +1,1470 @@
+######
+# FocalIR
+# This code is based on Focal Transformer and SwinIR
+# This model is supported by BasicSR
+######
+# --------------------------------------------------------
+# Focal Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Jianwei Yang (jianwyan@microsoft.com)
+# Based on Swin Transformer written by Zhe Liu
+# --------------------------------------------------------
+# -----------------------------------------------------------------------------------
+# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
+# Originally Written by Ze Liu, Modified by Jingyun Liang.
+# -----------------------------------------------------------------------------------
+
+
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+
+from basicsr.archs.arch_util import to_2tuple, trunc_normal_
+from basicsr.utils.registry import ARCH_REGISTRY
+from thop import profile as hp
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+ """
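+    # For example, with drop_prob=0.1 during training, roughly 10% of the samples in a batch have this
+    # branch zeroed out, while the surviving samples are rescaled by 1 / 0.9 so the expectation is unchanged.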
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_partition_noreshape(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (B, num_windows_h, num_windows_w, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous()
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
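+    # e.g. window_partition of a (2, 8, 8, 96) tensor with window_size=4 gives (8, 4, 4, 96)
+    # (4 windows per image); window_reverse with H=W=8 maps it back to (2, 8, 8, 96)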
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+def get_roll_masks(H, W, window_size, shift_size):
+ #####################################
+ # move to top-left
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, H - window_size),
+ slice(H - window_size, H - shift_size),
+ slice(H - shift_size, H))
+ w_slices = (slice(0, W - window_size),
+ slice(W - window_size, W - shift_size),
+ slice(W - shift_size, W))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, window_size * window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask_tl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ ####################################
+ # move to top right
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, H - window_size),
+ slice(H - window_size, H - shift_size),
+ slice(H - shift_size, H))
+ w_slices = (slice(0, shift_size),
+ slice(shift_size, window_size),
+ slice(window_size, W))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, window_size * window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask_tr = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ ####################################
+ # move to bottom left
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, shift_size),
+ slice(shift_size, window_size),
+ slice(window_size, H))
+ w_slices = (slice(0, W - window_size),
+ slice(W - window_size, W - shift_size),
+ slice(W - shift_size, W))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, window_size * window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask_bl = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ ####################################
+ # move to bottom right
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, shift_size),
+ slice(shift_size, window_size),
+ slice(window_size, H))
+ w_slices = (slice(0, shift_size),
+ slice(shift_size, window_size),
+ slice(window_size, W))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, window_size * window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask_br = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ # append all
+ attn_mask_all = torch.cat((attn_mask_tl, attn_mask_tr, attn_mask_bl, attn_mask_br), -1)
+ return attn_mask_all
+
+
+def get_relative_position_index(q_windows, k_windows):
+ """
+ Args:
+ q_windows: tuple (query_window_height, query_window_width)
+ k_windows: tuple (key_window_height, key_window_width)
+
+ Returns:
+ relative_position_index: query_window_height*query_window_width, key_window_height*key_window_width
+ """
+ # get pair-wise relative position index for each token inside the window
+ coords_h_q = torch.arange(q_windows[0])
+ coords_w_q = torch.arange(q_windows[1])
+ coords_q = torch.stack(torch.meshgrid([coords_h_q, coords_w_q])) # 2, Wh_q, Ww_q
+
+ coords_h_k = torch.arange(k_windows[0])
+ coords_w_k = torch.arange(k_windows[1])
+ coords_k = torch.stack(torch.meshgrid([coords_h_k, coords_w_k])) # 2, Wh, Ww
+
+ coords_flatten_q = torch.flatten(coords_q, 1) # 2, Wh_q*Ww_q
+ coords_flatten_k = torch.flatten(coords_k, 1) # 2, Wh_k*Ww_k
+
+ relative_coords = coords_flatten_q[:, :, None] - coords_flatten_k[:, None, :] # 2, Wh_q*Ww_q, Wh_k*Ww_k
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh_q*Ww_q, Wh_k*Ww_k, 2
+ relative_coords[:, :, 0] += k_windows[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += k_windows[1] - 1
+ relative_coords[:, :, 0] *= (q_windows[1] + k_windows[1]) - 1
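+    # fold the two shifted (row, column) offsets into a single flat index into the bias table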
+ relative_position_index = relative_coords.sum(-1) # Wh_q*Ww_q, Wh_k*Ww_k
+ return relative_position_index
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+
+ Args:
+ dim (int): Number of input channels.
+ expand_size (int): The expand size at focal level 1.
+ window_size (tuple[int]): The height and width of the window.
+ focal_window (int): Focal region size.
+ focal_level (int): Focal attention level.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pool_method (str): window pooling method. Default: none
+ """
+
+ def __init__(self, dim, expand_size, window_size, focal_window, focal_level, num_heads,
+ qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0., pool_method="none"):
+
+ super().__init__()
+ self.dim = dim
+ self.expand_size = expand_size
+ self.window_size = window_size # Wh, Ww
+ self.pool_method = pool_method
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+
+ # define a parameter table of relative position bias for each window
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ if self.expand_size > 0 and focal_level > 0:
+ # define a parameter table of position bias between window and its fine-grained surroundings
+ self.window_size_of_key = self.window_size[0] * self.window_size[1] if self.expand_size == 0 else \
+ (4 * self.window_size[0] * self.window_size[1] - 4 * (self.window_size[0] - self.expand_size) * (
+ self.window_size[0] - self.expand_size))
+ self.relative_position_bias_table_to_neighbors = nn.Parameter(
+ torch.zeros(1, num_heads, self.window_size[0] * self.window_size[1],
+ self.window_size_of_key)) # Wh*Ww, nH, nSurrounding
+ trunc_normal_(self.relative_position_bias_table_to_neighbors, std=.02)
+
+ # get mask for rolled k and rolled v
+            mask_tl = torch.ones(self.window_size[0], self.window_size[1])
+            mask_tl[:-self.expand_size, :-self.expand_size] = 0
+            mask_tr = torch.ones(self.window_size[0], self.window_size[1])
+            mask_tr[:-self.expand_size, self.expand_size:] = 0
+            mask_bl = torch.ones(self.window_size[0], self.window_size[1])
+            mask_bl[self.expand_size:, :-self.expand_size] = 0
+            mask_br = torch.ones(self.window_size[0], self.window_size[1])
+            mask_br[self.expand_size:, self.expand_size:] = 0
+ mask_rolled = torch.stack((mask_tl, mask_tr, mask_bl, mask_br), 0).flatten(0)
+ self.register_buffer("valid_ind_rolled", mask_rolled.nonzero().view(-1))
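+            # valid_ind_rolled keeps only the rolled key/value positions that lie outside the central window,
+            # i.e. the expanded surrounding tokens attended to at the finest focal level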
+
+ if pool_method != "none" and focal_level > 1:
+ self.relative_position_bias_table_to_windows = nn.ParameterList()
+ self.unfolds = nn.ModuleList()
+
+ # build relative position bias between local patch and pooled windows
+ for k in range(focal_level - 1):
+ stride = 2 ** k
+ kernel_size = 2 * (self.focal_window // 2) + 2 ** k + (2 ** k - 1)
+ # define unfolding operations
+ self.unfolds += [nn.Unfold(
+ kernel_size=(kernel_size, kernel_size),
+ stride=stride, padding=kernel_size // 2)
+ ]
+
+ # define relative position bias table
+ relative_position_bias_table_to_windows = nn.Parameter(
+ torch.zeros(
+ self.num_heads,
+ (self.window_size[0] + self.focal_window + 2 ** k - 2) * (
+ self.window_size[1] + self.focal_window + 2 ** k - 2),
+ )
+ )
+ trunc_normal_(relative_position_bias_table_to_windows, std=.02)
+ self.relative_position_bias_table_to_windows.append(relative_position_bias_table_to_windows)
+
+ # define relative position bias index
+ relative_position_index_k = get_relative_position_index(self.window_size,
+ to_2tuple(self.focal_window + 2 ** k - 1))
+ self.register_buffer("relative_position_index_{}".format(k), relative_position_index_k)
+
+ # define unfolding index for focal_level > 0
+ if k > 0:
+                    mask = torch.zeros(kernel_size, kernel_size)
+ mask[(2 ** k) - 1:, (2 ** k) - 1:] = 1
+ self.register_buffer("valid_ind_unfold_{}".format(k), mask.flatten(0).nonzero().view(-1))
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x_all, mask_all=None):
+ """
+ Args:
+ x_all (list[Tensors]): input features at different granularity
+ mask_all (list[Tensors/None]): masks for input features at different granularity
+ """
+        x = x_all[0]
+
+ B, nH, nW, C = x.shape
+ qkv = self.qkv(x).reshape(B, nH, nW, 3, C).permute(3, 0, 1, 2, 4).contiguous()
+ q, k, v = qkv[0], qkv[1], qkv[2] # B, nH, nW, C
+
+        # partition q, k and v into windows
+ (q_windows, k_windows, v_windows) = map(
+ lambda t: window_partition(t, self.window_size[0]).view(
+ -1, self.window_size[0] * self.window_size[0], self.num_heads, C // self.num_heads
+ ).transpose(1, 2),
+ (q, k, v)
+ )
+
+ if self.expand_size > 0 and self.focal_level > 0:
+ (k_tl, v_tl) = map(
+ lambda t: torch.roll(t, shifts=(-self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
+ )
+ (k_tr, v_tr) = map(
+ lambda t: torch.roll(t, shifts=(-self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
+ )
+ (k_bl, v_bl) = map(
+ lambda t: torch.roll(t, shifts=(self.expand_size, -self.expand_size), dims=(1, 2)), (k, v)
+ )
+ (k_br, v_br) = map(
+ lambda t: torch.roll(t, shifts=(self.expand_size, self.expand_size), dims=(1, 2)), (k, v)
+ )
+
+ (k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows) = map(
+ lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0],
+ self.num_heads, C // self.num_heads),
+ (k_tl, k_tr, k_bl, k_br)
+ )
+ (v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows) = map(
+ lambda t: window_partition(t, self.window_size[0]).view(-1, self.window_size[0] * self.window_size[0],
+ self.num_heads, C // self.num_heads),
+ (v_tl, v_tr, v_bl, v_br)
+ )
+ k_rolled = torch.cat((k_tl_windows, k_tr_windows, k_bl_windows, k_br_windows), 1).transpose(1, 2)
+ v_rolled = torch.cat((v_tl_windows, v_tr_windows, v_bl_windows, v_br_windows), 1).transpose(1, 2)
+
+ # mask out tokens in current window
+ k_rolled = k_rolled[:, :, self.valid_ind_rolled]
+ v_rolled = v_rolled[:, :, self.valid_ind_rolled]
+ k_rolled = torch.cat((k_windows, k_rolled), 2)
+ v_rolled = torch.cat((v_windows, v_rolled), 2)
+ else:
+            k_rolled = k_windows
+            v_rolled = v_windows
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ k_pooled = []
+ v_pooled = []
+ for k in range(self.focal_level - 1):
+ stride = 2 ** k
+ x_window_pooled = x_all[k + 1] # B, nWh, nWw, C
+ nWh, nWw = x_window_pooled.shape[1:3]
+
+ # generate mask for pooled windows
+ mask = x_window_pooled.new(nWh, nWw).fill_(1)
+                unfolded_mask = self.unfolds[k](mask.unsqueeze(0).unsqueeze(1)).view(
+                    1, 1, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1)
+                unfolded_mask = unfolded_mask.permute(0, 4, 2, 3, 1).contiguous().view(
+                    nWh * nWw // stride // stride, -1, 1)
+
+ if k > 0:
+ valid_ind_unfold_k = getattr(self, "valid_ind_unfold_{}".format(k))
+ unfolded_mask = unfolded_mask[:, valid_ind_unfold_k]
+
+ x_window_masks = unfolded_mask.flatten(1).unsqueeze(0)
+ x_window_masks = x_window_masks.masked_fill(x_window_masks == 0, float(-100.0)).masked_fill(
+ x_window_masks > 0, float(0.0))
+ mask_all[k + 1] = x_window_masks
+
+ # generate k and v for pooled windows
+ qkv_pooled = self.qkv(x_window_pooled).reshape(B, nWh, nWw, 3, C).permute(3, 0, 4, 1, 2).contiguous()
+ k_pooled_k, v_pooled_k = qkv_pooled[1], qkv_pooled[2] # B, C, nWh, nWw
+
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: self.unfolds[k](t).view(
+ B, C, self.unfolds[k].kernel_size[0], self.unfolds[k].kernel_size[1], -1).permute(0, 4, 2, 3,
+ 1).contiguous(). \
+ view(-1, self.unfolds[k].kernel_size[0] * self.unfolds[k].kernel_size[1], self.num_heads,
+ C // self.num_heads).transpose(1, 2),
+ (k_pooled_k, v_pooled_k) # (B x (nH*nW)) x nHeads x (unfold_wsize x unfold_wsize) x head_dim
+ )
+
+ if k > 0:
+ (k_pooled_k, v_pooled_k) = map(
+ lambda t: t[:, :, valid_ind_unfold_k], (k_pooled_k, v_pooled_k)
+ )
+
+ k_pooled += [k_pooled_k]
+ v_pooled += [v_pooled_k]
+ k_all = torch.cat([k_rolled] + k_pooled, 2)
+ v_all = torch.cat([v_rolled] + v_pooled, 2)
+ else:
+ k_all = k_rolled
+ v_all = v_rolled
+
+ N = k_all.shape[-2]
+ q_windows = q_windows * self.scale
+ attn = (q_windows @ k_all.transpose(-2,
+ -1)) # B*nW, nHead, window_size*window_size, focal_window_size*focal_window_size
+
+ window_area = self.window_size[0] * self.window_size[1]
+ window_area_rolled = k_rolled.shape[2]
+
+ # add relative position bias for tokens inside window
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn[:, :, :window_area, :window_area] = attn[:, :, :window_area,
+ :window_area] + relative_position_bias.unsqueeze(0)
+
+        # add relative position bias between window tokens and their fine-grained surroundings
+ if self.expand_size > 0 and self.focal_level > 0:
+ attn[:, :, :window_area, window_area:window_area_rolled] = attn[:, :, :window_area,
+ window_area:window_area_rolled] + self.relative_position_bias_table_to_neighbors
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ # add relative position bias for different windows in an image
+ offset = window_area_rolled
+ for k in range(self.focal_level - 1):
+ # add relative position bias
+ relative_position_index_k = getattr(self, 'relative_position_index_{}'.format(k))
+ relative_position_bias_to_windows = self.relative_position_bias_table_to_windows[k][:,
+ relative_position_index_k.view(-1)].view(
+ -1, self.window_size[0] * self.window_size[1], (self.focal_window + 2 ** k - 1) ** 2,
+ ) # nH, NWh*NWw,focal_region*focal_region
+ attn[:, :, :window_area, offset:(offset + (self.focal_window + 2 ** k - 1) ** 2)] = \
+ attn[:, :, :window_area, offset:(offset + (
+ self.focal_window + 2 ** k - 1) ** 2)] + relative_position_bias_to_windows.unsqueeze(0)
+ # add attentional mask
+ if mask_all[k + 1] is not None:
+ attn[:, :, :window_area, offset:(offset + (self.focal_window + 2 ** k - 1) ** 2)] = \
+ attn[:, :, :window_area, offset:(offset + (self.focal_window + 2 ** k - 1) ** 2)] + \
+ mask_all[k + 1][:, :, None, None, :].repeat(attn.shape[0] // mask_all[k + 1].shape[1], 1, 1, 1,
+ 1).view(-1, 1, 1, mask_all[k + 1].shape[-1])
+
+ offset += (self.focal_window + 2 ** k - 1) ** 2
+
+ if mask_all[0] is not None:
+ nW = mask_all[0].shape[0]
+ attn = attn.view(attn.shape[0] // nW, nW, self.num_heads, window_area, N)
+ attn[:, :, :, :, :window_area] = attn[:, :, :, :, :window_area] + mask_all[0][None, :, None, :, :]
+ attn = attn.view(-1, self.num_heads, window_area, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v_all).transpose(1, 2).reshape(attn.shape[0], window_area, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, N, window_size, unfold_size):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ if self.pool_method != "none" and self.focal_level > 1:
+ flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
+ if self.expand_size > 0 and self.focal_level > 0:
+ flops += self.num_heads * N * (self.dim // self.num_heads) * (
+ (window_size + 2 * self.expand_size) ** 2 - window_size ** 2)
+
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ if self.pool_method != "none" and self.focal_level > 1:
+ flops += self.num_heads * N * (self.dim // self.num_heads) * (unfold_size * unfold_size)
+ if self.expand_size > 0 and self.focal_level > 0:
+ flops += self.num_heads * N * (self.dim // self.num_heads) * (
+ (window_size + 2 * self.expand_size) ** 2 - window_size ** 2)
+
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class FocalTransformerBlock(nn.Module):
+ r""" Focal Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ expand_size (int): expand size at first focal level (finest level).
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pool_method (str): window pooling method. Default: none, options: [none|fc|conv]
+ focal_level (int): number of focal levels. Default: 1.
+ focal_window (int): region size of focal attention. Default: 1
+ use_layerscale (bool): whether use layer scale for training stability. Default: False
+ layerscale_value (float): scaling value for layer scale. Default: 1e-4
+ """
+
+ def __init__(self, dim, input_resolution, num_heads, window_size=7, expand_size=0, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm, pool_method="none",
+ focal_level=1, focal_window=1, use_layerscale=False, layerscale_value=1e-4):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.expand_size = expand_size
+ self.mlp_ratio = mlp_ratio
+ self.pool_method = pool_method
+ self.focal_level = focal_level
+ self.focal_window = focal_window
+ self.use_layerscale = use_layerscale
+
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.expand_size = 0
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, "shift_size must be in [0, window_size)"
+
+ self.window_size_glo = self.window_size
+
+ self.pool_layers = nn.ModuleList()
+ if self.pool_method != "none":
+ for k in range(self.focal_level - 1):
+ window_size_glo = math.floor(self.window_size_glo / (2 ** k))
+ if self.pool_method == "fc":
+ self.pool_layers.append(nn.Linear(window_size_glo * window_size_glo, 1))
+ self.pool_layers[-1].weight.data.fill_(1. / (window_size_glo * window_size_glo))
+ self.pool_layers[-1].bias.data.fill_(0)
+ elif self.pool_method == "conv":
+ self.pool_layers.append(
+ nn.Conv2d(dim, dim, kernel_size=window_size_glo, stride=window_size_glo, groups=dim))
+
+ self.norm1 = norm_layer(dim)
+
+ self.attn = WindowAttention(
+ dim, expand_size=self.expand_size, window_size=to_2tuple(self.window_size),
+ focal_window=focal_window, focal_level=focal_level, num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop, pool_method=pool_method)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ # calculate attention mask for SW-MSA
+ H, W = self.input_resolution
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ else:
+ attn_mask = None
+ self.register_buffer("attn_mask", attn_mask)
+
+ if self.use_layerscale:
+ self.gamma_1 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+ self.gamma_2 = nn.Parameter(layerscale_value * torch.ones((dim)), requires_grad=True)
+
+ def forward(self, x, x_size):
+ H, W = x_size
+ B, _, C = x.shape
+ # assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ if pad_r > 0 or pad_b > 0:
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+
+ B, H, W, C = x.shape
+
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ x_windows_all = [shifted_x]
+ x_window_masks_all = [self.attn_mask]
+
+ if self.focal_level > 1 and self.pool_method != "none":
+ # if we add coarser granularity and the pool method is not none
+ for k in range(self.focal_level - 1):
+ window_size_glo = math.floor(self.window_size_glo / (2 ** k))
+ pooled_h = math.ceil(H / self.window_size) * (2 ** k)
+ pooled_w = math.ceil(W / self.window_size) * (2 ** k)
+ H_pool = pooled_h * window_size_glo
+ W_pool = pooled_w * window_size_glo
+
+ x_level_k = shifted_x
+ # trim or pad shifted_x depending on the required size
+ if H > H_pool:
+ trim_t = (H - H_pool) // 2
+ trim_b = H - H_pool - trim_t
+ x_level_k = x_level_k[:, trim_t:-trim_b]
+ elif H < H_pool:
+ pad_t = (H_pool - H) // 2
+ pad_b = H_pool - H - pad_t
+ x_level_k = F.pad(x_level_k, (0, 0, 0, 0, pad_t, pad_b))
+
+ if W > W_pool:
+ trim_l = (W - W_pool) // 2
+ trim_r = W - W_pool - trim_l
+ x_level_k = x_level_k[:, :, trim_l:-trim_r]
+ elif W < W_pool:
+ pad_l = (W_pool - W) // 2
+ pad_r = W_pool - W - pad_l
+ x_level_k = F.pad(x_level_k, (0, 0, pad_l, pad_r))
+
+ x_windows_noreshape = window_partition_noreshape(x_level_k.contiguous(),
+ window_size_glo) # B, nw, nw, window_size, window_size, C
+ nWh, nWw = x_windows_noreshape.shape[1:3]
+ if self.pool_method == "mean":
+ x_windows_pooled = x_windows_noreshape.mean([3, 4]) # B, nWh, nWw, C
+ elif self.pool_method == "max":
+ x_windows_pooled = x_windows_noreshape.max(-2)[0].max(-2)[0].view(B, nWh, nWw,
+ C) # B, nWh, nWw, C
+ elif self.pool_method == "fc":
+ x_windows_noreshape = x_windows_noreshape.view(B, nWh, nWw, window_size_glo * window_size_glo,
+ C).transpose(3, 4) # B, nWh, nWw, C, wsize**2
+ x_windows_pooled = self.pool_layers[k](x_windows_noreshape).flatten(
+ -2) # B, nWh, nWw, C
+ elif self.pool_method == "conv":
+ x_windows_noreshape = x_windows_noreshape.view(-1, window_size_glo, window_size_glo, C).permute(0,
+ 3,
+ 1,
+ 2).contiguous() # B * nw * nw, C, wsize, wsize
+ x_windows_pooled = self.pool_layers[k](x_windows_noreshape).view(B, nWh, nWw,
+ C) # B, nWh, nWw, C
+
+ x_windows_all += [x_windows_pooled]
+ x_window_masks_all += [None]
+
+ attn_windows = self.attn(x_windows_all, mask_all=x_window_masks_all) # nW*B, window_size*window_size, C
+
+ attn_windows = attn_windows[:, :self.window_size ** 2]
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+        # crop the padded feature map back to the original input size and flatten the spatial dimensions
+        # (the SwinIR-style code cropped to self.input_resolution; here we crop to the actual x_size instead)
+        x = x[:, :x_size[0], :x_size[1]].contiguous().view(B, -1, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x if (not self.use_layerscale) else (self.gamma_1 * x))
+ x = x + self.drop_path(
+ self.mlp(self.norm2(x)) if (not self.use_layerscale) else (self.gamma_2 * self.mlp(self.norm2(x))))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size, self.window_size, self.focal_window)
+
+ if self.pool_method != "none" and self.focal_level > 1:
+ for k in range(self.focal_level - 1):
+ window_size_glo = math.floor(self.window_size_glo / (2 ** k))
+ nW_glo = nW * (2 ** k)
+ # (sub)-window pooling
+ flops += nW_glo * self.dim * window_size_glo * window_size_glo
+ # qkv for global levels
+ # NOTE: in our implementation, we pass the pooled window embedding to qkv embedding layer,
+                # but theoretically, we only need to compute k and v.
+ flops += nW_glo * self.dim * 3 * self.dim
+
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ img_size (tuple[int]): Resolution of input feature.
+ in_chans (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, img_size, in_chans=3, norm_layer=nn.LayerNorm, **kwargs):
+ super().__init__()
+ self.input_resolution = img_size
+ self.dim = in_chans
+ self.reduction = nn.Linear(4 * in_chans, 2 * in_chans, bias=False)
+ self.norm = norm_layer(4 * in_chans)
+
+ def forward(self, x):
+ """
+ x: B, C, H, W
+ """
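+        # shape example: (B, C, H, W) with even H and W  ->  (B, H/2 * W/2, 2 * C)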
+ B, C, H, W = x.shape
+
+ x = x.permute(0, 2, 3, 1).contiguous()
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = H * W * self.dim
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Focal Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ expand_size (int): expand size for focal level 1.
+ expand_layer (str): expand layer. Default: all
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.0.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pool_method (str): Window pooling method. Default: none.
+ focal_level (int): Number of focal levels. Default: 1.
+ focal_window (int): region size at each focal level. Default: 1.
+ use_conv_embed (bool): whether use overlapped convolutional patch embedding layer. Default: False
+ use_shift (bool): Whether use window shift as in Swin Transformer. Default: False
+ use_pre_norm (bool): Whether use pre-norm before patch embedding projection for stability. Default: False
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ use_layerscale (bool): Whether use layer scale for stability. Default: False.
+ layerscale_value (float): Layerscale value. Default: 1e-4.
+ """
+
+ def __init__(self, dim, input_resolution, depth, num_heads, window_size, expand_size, expand_layer="all",
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
+ drop_path=0., norm_layer=nn.LayerNorm, pool_method="none",
+ focal_level=1, focal_window=1, use_shift=False,
+ downsample=None, use_checkpoint=False, use_layerscale=False, layerscale_value=1e-4):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ if expand_layer == "even":
+ expand_factor = 0
+ elif expand_layer == "odd":
+ expand_factor = 1
+ elif expand_layer == "all":
+ expand_factor = -1
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ FocalTransformerBlock(dim=dim, input_resolution=input_resolution,
+ num_heads=num_heads, window_size=window_size,
+ shift_size=(0 if (i % 2 == 0) else window_size // 2) if use_shift else 0,
+ expand_size=0 if (i % 2 == expand_factor) else expand_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ pool_method=pool_method,
+ focal_level=focal_level,
+ focal_window=focal_window,
+ use_layerscale=use_layerscale,
+ layerscale_value=layerscale_value)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(img_size=input_resolution, embed_dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ for blk in self.blocks:
+ if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, x_size)
+ else:
+ x = blk(x, x_size)
+
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class RFTB(nn.Module):
+ """Residual Focal Transformer Block (RFTB).
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ expand_size,
+ expand_layer="all",
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ pool_method="none",
+ focal_level=1,
+ focal_window=1,
+ use_conv_embed=False,
+ use_shift=False,
+ use_pre_norm=False,
+ downsample=None,
+ use_checkpoint=False,
+ use_layerscale=False,
+ layerscale_value=1e-4,
+ img_size=224,
+ patch_size=4,
+ resi_connection='1conv'):
+ super(RFTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(
+ dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ expand_size=expand_size,
+ expand_layer=expand_layer,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ pool_method=pool_method,
+ focal_level=focal_level,
+ focal_window=focal_window,
+ use_shift=use_shift,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint,
+ use_layerscale=use_layerscale,
+ layerscale_value=layerscale_value)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(
+ nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=dim, embed_dim=dim, norm_layer=None)
+
+ def forward(self, x, x_size):
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ h, w = self.input_resolution
+ flops += h * w * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+        use_conv_embed (bool): Whether to use an overlapped convolutional embedding layer. Default: False.
+        norm_layer (nn.Module, optional): Normalization layer. Default: None
+        use_pre_norm (bool): Whether to use pre-normalization before projection. Default: False
+        is_stem (bool): Whether the current patch embedding is the stem. Default: False
+ """
+
+ def __init__(self, img_size=(224, 224), patch_size=4, in_chans=3, embed_dim=96,
+ norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ x = x.flatten(2).transpose(1, 2)
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+        x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1])  # b c h w
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
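+    # For example, Upsample(scale=4, num_feat=64) stacks two (Conv2d(64, 256, 3, 1, 1), PixelShuffle(2)) pairs,
+    # while Upsample(scale=3, num_feat=64) uses a single Conv2d(64, 576, 3, 1, 1) followed by PixelShuffle(3).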
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+
+class UpsampleOneStep(nn.Sequential):
+    """UpsampleOneStep module (unlike Upsample, it always uses only one conv followed by one pixelshuffle).
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ h, w = self.input_resolution
+ flops = h * w * self.num_feat * 3 * 9
+ return flops
+
+
+@ARCH_REGISTRY.register()
+class FocalIR(nn.Module):
+ r""" Focal Transformer: Focal Self-attention for Local-Global Interactions in Vision Transformer
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ num_classes (int): Number of classes for classification head. Default: 1000
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Focal Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        use_shift (bool): Whether to use the window shift proposed by Swin Transformer. We observe that using shift or not makes no difference to our Focal Transformer. Default: False
+ focal_stages (list): Which stages to perform focal attention. Default: [0, 1, 2, 3], means all stages
+ focal_levels (list): How many focal levels at all stages. Note that this excludes the finest-grain level. Default: [1, 1, 1, 1]
+ focal_windows (list): The focal window size at all stages. Default: [7, 5, 3, 1]
+ expand_stages (list): Which stages to expand the finest grain window. Default: [0, 1, 2, 3], means all stages
+ expand_sizes (list): The expand size for the finest grain level. Default: [3, 3, 3, 3]
+        expand_layer (str): Which layers to expand the window for at the finest grain level. This can save computation and memory cost without loss of performance. Default: "all"
+        use_conv_embed (bool): Whether to use convolutional embedding. We noted that using convolutional embedding usually improves the performance, but we do not use it by default. Default: False
+ use_layerscale (bool): Whether use layerscale proposed in CaiT. Default: False
+ layerscale_value (float): Value for layer scale. Default: 1e-4
+        use_pre_norm (bool): Whether to use pre-norm in the patch merging/embedding layer to control the feature magnitude. Default: False
+ """
+
+ def __init__(self,
+ img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False,
+ use_shift=False,
+ focal_stages=[0, 1, 2, 3],
+ focal_levels=[1, 1, 1, 1],
+ focal_windows=[7, 5, 3, 1],
+ focal_pool="fc",
+ expand_stages=[0, 1, 2, 3],
+ expand_sizes=[3, 3, 3, 3],
+ expand_layer="all",
+ use_layerscale=False,
+ layerscale_value=1e-4,
+ upscale=2,
+ img_range=1.,
+ upsampler='pixelshuffle',
+ resi_connection='1conv',
+ **kwargs):
+ super(FocalIR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
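+        # normalize inputs with the dataset RGB mean (these values are the DIV2K mean commonly used in SR models), scaled by img_range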
+ if in_chans == 3:
+ rgb_mean = (0.4488, 0.4371, 0.4040)
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+ self.window_size = window_size
+
+ # ------------------------- 1, shallow feature extraction ------------------------- #
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ # ------------------------- 2, deep feature extraction ------------------------- #
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ # self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into patches using either non-overlapped embedding or overlapped embedding
+ self.patch_embed = PatchEmbed(
+ img_size=to_2tuple(img_size), patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=embed_dim,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Focal Transformer blocks (RFTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RFTB(
+ dim=embed_dim,
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ pool_method=focal_pool if i_layer in focal_stages else "none",
+ downsample=None,
+ focal_level=focal_levels[i_layer],
+ focal_window=focal_windows[i_layer],
+ expand_size=expand_sizes[i_layer],
+ expand_layer=expand_layer,
+ use_shift=use_shift,
+ use_checkpoint=use_checkpoint,
+ use_layerscale=use_layerscale,
+ layerscale_value=layerscale_value,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection)
+ self.layers.append(layer)
+
+ self.norm = norm_layer(self.num_features)
+ # self.avgpool = nn.AdaptiveAvgPool1d(1)
+ # self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+ assert self.upscale == 4, 'only support x4 now.'
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
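+        # ViT-style initialization: truncated normal for linear weights, zeros/ones for LayerNorm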
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table', 'relative_position_bias_table_to_neighbors',
+ 'relative_position_bias_table_to_windows'}
+
+ def check_image_size(self, x):
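+        # reflect-pad on the bottom/right so that height and width become multiples of window_size;
+        # the padded area is cropped off again at the end of forward()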
+ _, _, h, w = x.size()
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ return x
+
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+
+ for layer in self.layers:
+ x = layer(x, x_size)
+ x = self.norm(x) # B L C
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ H, W = x.shape[2:]
+ x = self.check_image_size(x)
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
+ return x[:, :, :H * self.upscale, :W * self.upscale]
+
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+ for i, layer in enumerate(self.layers):
+ flops += layer.flops()
+ flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
+ return flops
+
+
+def profile(model, inputs):
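+    # quick wrapper around torch.profiler to report where inference time is spent (top 15 ops by CUDA time)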
+ from torch.profiler import profile, record_function, ProfilerActivity
+ with profile(activities=[
+ ProfilerActivity.CPU, ProfilerActivity.CUDA], with_stack=True, record_shapes=True) as prof:
+ with record_function("model_inference"):
+ model(inputs)
+
+ print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))
+
+
+if __name__ == '__main__':
+ img_hsize = 320
+ img_wsize = 180
+ x = torch.rand(1, 3, img_hsize, img_wsize).cuda()
+ model = FocalIR(img_size=(img_hsize, img_wsize), upscale=4, in_chans=3, embed_dim=60, depths=[6, 6, 6, 6], drop_path_rate=0.2,
+ focal_levels=[2, 2, 2, 2], expand_sizes=[3, 3, 3, 3], expand_layer="all",num_heads=[6, 6, 6, 6],
+ focal_windows=[7, 5, 3, 1], mlp_ratio=2, upsampler='pixelshuffle', window_size=4, resi_connection='1conv', use_shift=False).cuda()
+
+ model.eval()
+
+ #flops = model.flops()
+ #print(f"number of GFLOPs: {flops / 1e9}")
+
+ #n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
+ #print(f"number of params: {n_parameters}")
+
+ flops, params = hp(model, inputs=(x,))
+
+ print("FLOPs=", str(flops / 1e9) + '{}'.format("G"))
+ print("params=", str(params / 1e6) + '{}'.format("M"))
+
+ #profile(model, x)
diff --git a/basicsr/archs/rdswinir_arch.py b/basicsr/archs/rdswinir_arch.py
new file mode 100644
index 000000000..58eed35e8
--- /dev/null
+++ b/basicsr/archs/rdswinir_arch.py
@@ -0,0 +1,1016 @@
+# Modified from https://github.com/JingyunLiang/SwinIR
+# SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
+# Originally Written by Ze Liu, Modified by Jingyun Liang.
+
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as checkpoint
+
+from basicsr.utils.registry import ARCH_REGISTRY
+from basicsr.archs.arch_util import to_2tuple, trunc_normal_
+
+
+def drop_path(x, drop_prob: float = 0., training: bool = False):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+ """
+ if drop_prob == 0. or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
+ random_tensor.floor_() # binarize
+ output = x.div(keep_prob) * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
+
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
+ """
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
+
+
+class Mlp(nn.Module):
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (b, h, w, c)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*b, window_size, window_size, c)
+ """
+ b, h, w, c = x.shape
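+    # reshape into a grid of windows, then fold the window grid into the batch dimension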
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
+ return windows
+
+
+def window_reverse(windows, window_size, h, w):
+ """
+ Args:
+ windows: (num_windows*b, window_size, window_size, c)
+ window_size (int): Window size
+ h (int): Height of image
+ w (int): Width of image
+
+ Returns:
+ x: (b, h, w, c)
+ """
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
+ x = windows.view(b, h // window_size, w // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer('relative_position_index', relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*b, n, c)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ b_, n, c = x.shape
+ qkv = self.qkv(x).reshape(b_, n, 3, self.num_heads, c // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nw = mask.shape[0]
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, n, n)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'
+
+ def flops(self, n):
+ # calculate flops for 1 window with token length of n
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += n * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * n * (self.dim // self.num_heads) * n
+ # x = (attn @ v)
+ flops += self.num_heads * n * n * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += n * self.dim * self.dim
+ return flops
+
+
+class SwinTransformerBlock(nn.Module):
+ r""" Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ if min(self.input_resolution) <= self.window_size:
+ # if window size is larger than input resolution, we don't partition windows
+ self.shift_size = 0
+ self.window_size = min(self.input_resolution)
+        assert 0 <= self.shift_size < self.window_size, 'shift_size must be in the range [0, window_size)'
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=to_2tuple(self.window_size),
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ attn_drop=attn_drop,
+ proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ if self.shift_size > 0:
+ attn_mask = self.calculate_mask(self.input_resolution)
+ else:
+ attn_mask = None
+
+ self.register_buffer('attn_mask', attn_mask)
+
+ def calculate_mask(self, x_size):
+ # calculate attention mask for SW-MSA
+ h, w = x_size
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
+ h_slices = (slice(0, -self.window_size), slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size), slice(-self.window_size,
+ -self.shift_size), slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nw, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
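+        # positions belonging to different shifted regions receive a large negative bias so that softmax ignores them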
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ return attn_mask
+
+ def forward(self, x, x_size):
+ h, w = x_size
+ b, _, c = x.shape
+ # assert seq_len == h * w, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(b, h, w, c)
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ else:
+ shifted_x = x
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nw*b, window_size, window_size, c
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, c) # nw*b, window_size*window_size, c
+
+        # W-MSA/SW-MSA (the attention mask is recomputed when the test image size differs from the training resolution)
+ if self.input_resolution == x_size:
+ attn_windows = self.attn(x_windows, mask=self.attn_mask) # nw*b, window_size*window_size, c
+ else:
+ attn_windows = self.attn(x_windows, mask=self.calculate_mask(x_size).to(x.device))
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+ x = x.view(b, h * w, c)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
+ f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
+
+ def flops(self):
+ flops = 0
+ h, w = self.input_resolution
+ # norm1
+ flops += self.dim * h * w
+ # W-MSA/SW-MSA
+ nw = h * w / self.window_size / self.window_size
+ flops += nw * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * h * w * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * h * w
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r""" Patch Merging Layer.
+
+ Args:
+ input_resolution (tuple[int]): Resolution of input feature.
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.input_resolution = input_resolution
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x):
+ """
+ x: b, h*w, c
+ """
+ h, w = self.input_resolution
+ b, seq_len, c = x.shape
+ assert seq_len == h * w, 'input feature has wrong size'
+        assert h % 2 == 0 and w % 2 == 0, f'x size ({h}*{w}) is not even.'
+
+ x = x.view(b, h, w, c)
+
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+ return f'input_resolution={self.input_resolution}, dim={self.dim}'
+
+ def flops(self):
+ h, w = self.input_resolution
+ flops = h * w * self.dim
+ flops += (h // 2) * (w // 2) * 4 * self.dim * 2 * self.dim
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ img_size=224,
+ patch_size=4):
+
+ super().__init__()
+ self.dim = dim
+ self.input_resolution = input_resolution
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer) for i in range(depth)
+ ])
+        # learnable fusion weights (kept as a plain list; currently unused in this layer's forward pass).
+        # Note: the original version moved each weight to "cuda" at construction time, which broke CPU-only use.
+        self.fuse_weight = [nn.Parameter(torch.ones(1)) for _ in range(depth)]
+
+        # Convolutional feature fuser designed for the dense connections below; unfortunately it did not help either
+ # self.conv_body = [nn.Conv2d(dim * (i + 1), dim, 3, 1, 1) for i in range(depth + 1)]
+ # for c in self.conv_body:
+ # c.to("cuda")
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, x_size):
+ # temp = [x]
+ # batch_size = x.shape[0]
+ count = 0
+ for blk in self.blocks:
+ if self.use_checkpoint:
+                x = checkpoint.checkpoint(blk, x, x_size)
+ else:
+ x = blk(x, x_size)
+ '''
+            ###### Dense connections; unfortunately they did not help... ######
+            if len(temp) > 1:
+                x = torch.cat((temp[:]))
+                # print(x.shape)
+                # fuse
+                self.patch_unembed(x, x_size)
+                x = x.view(batch_size, self.dim * count, x_size[0], x_size[1])
+                x = self.conv_body[count - 1](x)
+                # convert back to patch embedding
+                # x = self.patch_embed(x, x_size)
+                x = x.flatten(2).transpose(1, 2)
+            else:
+                x = temp[0]
+                x = blk(x, x_size)
+            temp.append(x)
+            count += 1
+        # fuse the features accumulated in temp
+        x = torch.cat((temp[:]))
+        # print(x.shape)
+        # fuse
+        self.patch_unembed(x, x_size)
+        x = x.view(batch_size, self.dim * count, x_size[0], x_size[1])
+        x = self.conv_body[count - 1](x)
+        # convert back to patch embedding
+        # x = self.patch_embed(x, x_size)
+        x = x.flatten(2).transpose(1, 2)
+ '''
+ if self.downsample is not None:
+ x = self.downsample(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return f'dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}'
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+
+class RSTB(nn.Module):
+ """Residual Swin Transformer Block (RSTB).
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ img_size: Input image size.
+ patch_size: Patch size.
+ resi_connection: The convolutional block before residual connection.
+ """
+
+ def __init__(self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ img_size=224,
+ patch_size=4,
+ resi_connection='1conv'):
+ super(RSTB, self).__init__()
+
+ self.dim = dim
+ self.input_resolution = input_resolution
+
+ self.residual_group = BasicLayer(
+ dim=dim,
+ input_resolution=input_resolution,
+ depth=depth,
+ num_heads=num_heads,
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path,
+ norm_layer=norm_layer,
+ downsample=downsample,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size)
+
+ if resi_connection == '1conv':
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv = nn.Sequential(
+ nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1))
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+
+ def forward(self, x, x_size):
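+        # token-space transformer body -> feature map -> conv -> back to tokens, with a residual connection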
+ return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
+
+ def flops(self):
+ flops = 0
+ flops += self.residual_group.flops()
+ h, w = self.input_resolution
+ flops += h * w * self.dim * self.dim * 9
+ flops += self.patch_embed.flops()
+ flops += self.patch_unembed.flops()
+
+ return flops
+
+
+class PatchEmbed(nn.Module):
+ r""" Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def flops(self):
+ flops = 0
+ h, w = self.img_size
+ if self.norm is not None:
+ flops += h * w * self.embed_dim
+ return flops
+
+
+class PatchUnEmbed(nn.Module):
+ r""" Image to Patch Unembedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ def forward(self, x, x_size):
+ x = x.transpose(1, 2).view(x.shape[0], self.embed_dim, x_size[0], x_size[1]) # b Ph*Pw c
+ return x
+
+ def flops(self):
+ flops = 0
+ return flops
+
+
+class Upsample(nn.Sequential):
+ """Upsample module.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+ """
+
+ def __init__(self, scale, num_feat):
+ m = []
+ if (scale & (scale - 1)) == 0: # scale = 2^n
+ for _ in range(int(math.log(scale, 2))):
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(2))
+ elif scale == 3:
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
+ m.append(nn.PixelShuffle(3))
+ else:
+ raise ValueError(f'scale {scale} is not supported. ' 'Supported scales: 2^n and 3.')
+ super(Upsample, self).__init__(*m)
+
+
+class UpsampleOneStep(nn.Sequential):
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
+ Used in lightweight SR to save parameters.
+
+ Args:
+ scale (int): Scale factor. Supported scales: 2^n and 3.
+ num_feat (int): Channel number of intermediate features.
+
+ """
+
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
+ self.num_feat = num_feat
+ self.input_resolution = input_resolution
+ m = []
+ m.append(nn.Conv2d(num_feat, (scale ** 2) * num_out_ch, 3, 1, 1))
+ m.append(nn.PixelShuffle(scale))
+ super(UpsampleOneStep, self).__init__(*m)
+
+ def flops(self):
+ h, w = self.input_resolution
+ flops = h * w * self.num_feat * 3 * 9
+ return flops
+
+
+@ARCH_REGISTRY.register()
+class RDSwinIR(nn.Module):
+ r""" SwinIR
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 64
+ patch_size (int | tuple(int)): Patch size. Default: 1
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+        upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compression artifact reduction
+        img_range: Image range. 1. or 255.
+        upsampler: The reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
+ """
+
+ def __init__(self,
+ img_size=64,
+ patch_size=1,
+ in_chans=3,
+ embed_dim=96,
+ depths=(6, 6, 6, 6),
+ num_heads=(6, 6, 6, 6),
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False,
+ upscale=2,
+ img_range=1.,
+ upsampler='',
+ resi_connection='1conv',
+ **kwargs):
+ super(RDSwinIR, self).__init__()
+ num_in_ch = in_chans
+ num_out_ch = in_chans
+ num_feat = 64
+ self.img_range = img_range
+ if in_chans == 3:
+ rgb_mean = (0.4488, 0.4371, 0.4040)
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
+ else:
+ self.mean = torch.zeros(1, 1, 1, 1)
+ self.upscale = upscale
+ self.upsampler = upsampler
+
+ # ------------------------- 1, shallow feature extraction ------------------------- #
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
+
+ # ------------------------- 2, deep feature extraction ------------------------- #
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = embed_dim
+ self.mlp_ratio = mlp_ratio
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=embed_dim,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # merge non-overlapping patches into image
+ self.patch_unembed = PatchUnEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=embed_dim,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build Residual Swin Transformer blocks (RSTB)
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = RSTB(
+ dim=embed_dim,
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ norm_layer=norm_layer,
+ downsample=None,
+ use_checkpoint=use_checkpoint,
+ img_size=img_size,
+ patch_size=patch_size,
+ resi_connection=resi_connection)
+ self.layers.append(layer)
+ self.norm = norm_layer(self.num_features)
+        # learnable per-layer fusion weights for the residual connections in forward_features,
+        # registered via ParameterList so they are trained, saved, and moved with the model
+        # (the original version kept plain Parameters in a list hard-coded to "cuda")
+        self.fuse_weight = nn.ParameterList([nn.Parameter(torch.ones(1)) for _ in range(self.num_layers)])
+
+ # build the last conv layer in deep feature extraction
+ if resi_connection == '1conv':
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
+ elif resi_connection == '3conv':
+ # to save parameters and memory
+ self.conv_after_body = nn.Sequential(
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ self.upsample = Upsample(upscale, num_feat)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR (to save parameters)
+ self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
+ (patches_resolution[0], patches_resolution[1]))
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR (less artifacts)
+ assert self.upscale == 4, 'only support x4 now.'
+ self.conv_before_upsample = nn.Sequential(
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {'absolute_pos_embed'}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {'relative_position_bias_table'}
+
+ def forward_features(self, x):
+ x_size = (x.shape[2], x.shape[3])
+ x = self.patch_embed(x)
+ if self.ape:
+ x = x + self.absolute_pos_embed
+ x = self.pos_drop(x)
+        for i, layer in enumerate(self.layers):
+            # each RSTB output is fused with its input through a learnable residual weight
+            x = layer(x, x_size) + self.fuse_weight[i] * x
+
+ x = self.norm(x) # b seq_len c
+ x = self.patch_unembed(x, x_size)
+
+ return x
+
+ def forward(self, x):
+ self.mean = self.mean.type_as(x)
+ x = (x - self.mean) * self.img_range
+
+ if self.upsampler == 'pixelshuffle':
+ # for classical SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.conv_last(self.upsample(x))
+ elif self.upsampler == 'pixelshuffledirect':
+ # for lightweight SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.upsample(x)
+ elif self.upsampler == 'nearest+conv':
+ # for real-world SR
+ x = self.conv_first(x)
+ x = self.conv_after_body(self.forward_features(x)) + x
+ x = self.conv_before_upsample(x)
+ x = self.lrelu(self.conv_up1(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
+ else:
+ # for image denoising and JPEG compression artifact reduction
+ x_first = self.conv_first(x)
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
+ x = x + self.conv_last(res)
+
+ x = x / self.img_range + self.mean
+
+ return x
+
+ def flops(self):
+ flops = 0
+ h, w = self.patches_resolution
+ flops += h * w * 3 * self.embed_dim * 9
+ flops += self.patch_embed.flops()
+ for layer in self.layers:
+ flops += layer.flops()
+ flops += h * w * 3 * self.embed_dim * self.embed_dim
+ flops += self.upsample.flops()
+ return flops
+
+
+if __name__ == '__main__':
+ upscale = 4
+ window_size = 8
+ height = (1024 // upscale // window_size + 1) * window_size
+ width = (720 // upscale // window_size + 1) * window_size
+ model = RDSwinIR(
+ upscale=4,
+ img_size=(height, width),
+ window_size=window_size,
+ img_range=1.,
+ depths=[6, 6, 6, 6, 6, 6],
+ embed_dim=180,
+ num_heads=[6, 6, 6, 6, 6, 6],
+ mlp_ratio=2,
+ upsampler='pixelshuffledirect')
+ print(model)
+ print(height, width, model.flops() / 1e9)
+
+ x = torch.randn((1, 3, height, width))
+ x = model(x)
+ print(x.shape)
diff --git a/basicsr/archs/swinir_arch.py b/basicsr/archs/swinir_arch.py
index f3e9e2c54..c688a600b 100644
--- a/basicsr/archs/swinir_arch.py
+++ b/basicsr/archs/swinir_arch.py
@@ -6,9 +6,10 @@
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
+from thop import profile as hp
from basicsr.utils.registry import ARCH_REGISTRY
-from .arch_util import to_2tuple, trunc_normal_
+from basicsr.archs.arch_util import to_2tuple, trunc_normal_
def drop_path(x, drop_prob: float = 0., training: bool = False):
@@ -935,11 +936,13 @@ def flops(self):
if __name__ == '__main__':
upscale = 4
- window_size = 8
- height = (1024 // upscale // window_size + 1) * window_size
- width = (720 // upscale // window_size + 1) * window_size
+ window_size = 4
+ # height = (1024 // upscale // window_size + 1) * window_size
+ # width = (720 // upscale // window_size + 1) * window_size
+ height = 320
+ width = 180
model = SwinIR(
- upscale=2,
+ upscale=4,
img_size=(height, width),
window_size=window_size,
img_range=1.,
@@ -948,9 +951,13 @@ def flops(self):
num_heads=[6, 6, 6, 6],
mlp_ratio=2,
upsampler='pixelshuffledirect')
- print(model)
- print(height, width, model.flops() / 1e9)
-
+ # print(model)
+ # print(height, width, model.flops() / 1e9)
+ model.eval()
x = torch.randn((1, 3, height, width))
- x = model(x)
- print(x.shape)
+ # x = model(x)
+ # print(x.shape)
+ flops, params = hp(model, inputs=(x,))
+
+ print("FLOPs=", str(flops / 1e9) + '{}'.format("G"))
+ print("params=", str(params / 1e6) + '{}'.format("M"))
diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py
index ee15d5a83..bfd99fad9 100644
--- a/basicsr/data/paired_image_dataset.py
+++ b/basicsr/data/paired_image_dataset.py
@@ -88,7 +88,7 @@ def __getitem__(self, index):
# flip, rotation
img_gt, img_lq = augment([img_gt, img_lq], self.opt['use_flip'], self.opt['use_rot'])
- if self.opt['color'] == 'y':
+ if 'color' in self.opt and self.opt['color'] == 'y':
img_gt = rgb2ycbcr(img_gt, y_only=True)[..., None]
img_lq = rgb2ycbcr(img_lq, y_only=True)[..., None]
diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py
index 4fb044a93..1f1e355fe 100644
--- a/basicsr/metrics/__init__.py
+++ b/basicsr/metrics/__init__.py
@@ -3,8 +3,9 @@
from basicsr.utils.registry import METRIC_REGISTRY
from .niqe import calculate_niqe
from .psnr_ssim import calculate_psnr, calculate_ssim
+from .lpips import calculate_lpips
-__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe']
+__all__ = ['calculate_psnr', 'calculate_ssim', 'calculate_niqe', 'calculate_lpips']
def calculate_metric(data, opt):
diff --git a/basicsr/metrics/lpips.py b/basicsr/metrics/lpips.py
new file mode 100644
index 000000000..8d5a73014
--- /dev/null
+++ b/basicsr/metrics/lpips.py
@@ -0,0 +1,65 @@
+from torchvision.transforms.functional import normalize
+from basicsr.utils import img2tensor
+import lpips
+import numpy as np
+
+from basicsr.metrics.metric_util import reorder_image, to_y_channel
+from basicsr.utils.registry import METRIC_REGISTRY
+import torch
+
+@METRIC_REGISTRY.register()
+def calculate_lpips(img, img2, crop_border, input_order='HWC', test_y_channel=False, **kwargs):
+
+ """Calculate LPIPS.
+ Ref: https://github.com/xinntao/BasicSR/pull/367
+ Args:
+ img (ndarray): Images with range [0, 255].
+ img2 (ndarray): Images with range [0, 255].
+ crop_border (int): Cropped pixels in each edge of an image. These
+            pixels are not involved in the LPIPS calculation.
+ input_order (str): Whether the input order is 'HWC' or 'CHW'.
+ Default: 'HWC'.
+ test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
+
+ Returns:
+ float: LPIPS result.
+ """
+    assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+ if input_order not in ['HWC', 'CHW']:
+ raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are ' '"HWC" and "CHW"')
+ img = reorder_image(img, input_order=input_order)
+ img2 = reorder_image(img2, input_order=input_order)
+ img = img.astype(np.float64)
+ img2 = img2.astype(np.float64)
+
+ if crop_border != 0:
+ img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
+ img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
+
+ if test_y_channel:
+ img = to_y_channel(img)
+ img2 = to_y_channel(img2)
+
+ # start calculating LPIPS metrics
+ DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ loss_fn_vgg = lpips.LPIPS(net='vgg', verbose=False).to(DEVICE) # RGB, normalized to [-1,1]
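+    # note: the VGG-based LPIPS model is constructed on every call; heavy validation loops may want to cache it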
+
+ mean = [0.5, 0.5, 0.5]
+ std = [0.5, 0.5, 0.5]
+
+ img_gt = img2 / 255.
+ img_restored = img / 255.
+
+ img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True)
+ # norm to [-1, 1]
+ normalize(img_gt, mean, std, inplace=True)
+ normalize(img_restored, mean, std, inplace=True)
+
+ # calculate lpips
+ img_gt = img_gt.to(DEVICE)
+ img_restored = img_restored.to(DEVICE)
+ loss_fn_vgg.eval()
+ lpips_val = loss_fn_vgg(img_restored.unsqueeze(0), img_gt.unsqueeze(0))
+
+ return lpips_val.detach().cpu().numpy().mean()
+
diff --git a/basicsr/models/focalir_model.py b/basicsr/models/focalir_model.py
new file mode 100644
index 000000000..b0d268f2d
--- /dev/null
+++ b/basicsr/models/focalir_model.py
@@ -0,0 +1,33 @@
+import torch
+from torch.nn import functional as F
+
+from basicsr.utils.registry import MODEL_REGISTRY
+from .sr_model import SRModel
+
+
+@MODEL_REGISTRY.register()
+class FocalIRModel(SRModel):
+
+ def test(self):
+ # pad to multiplication of window_size
+ window_size = self.opt['network_g']['window_size']
+ scale = self.opt.get('scale', 1)
+ mod_pad_h, mod_pad_w = 0, 0
+ _, _, h, w = self.lq.size()
+ if h % window_size != 0:
+ mod_pad_h = window_size - h % window_size
+ if w % window_size != 0:
+ mod_pad_w = window_size - w % window_size
+ img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ with torch.no_grad():
+ self.output = self.net_g_ema(img)
+ else:
+ self.net_g.eval()
+ with torch.no_grad():
+ self.output = self.net_g(img)
+ self.net_g.train()
+
+ _, _, h, w = self.output.size()
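+        # crop away the upscaled padding so the output matches the original LQ size times the scale factor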
+ self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py
index fdbb88678..54c80bd6d 100644
--- a/basicsr/models/sr_model.py
+++ b/basicsr/models/sr_model.py
@@ -138,10 +138,11 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
with_metrics = self.opt['val'].get('metrics') is not None
use_pbar = self.opt['val'].get('pbar', False)
- if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run
- self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
- # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
- self._initialize_best_metric_results(dataset_name)
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {metric: 0 for metric in self.opt['val']['metrics'].keys()}
+ # initialize the best metric results for each dataset_name (supporting multiple validation datasets)
+ self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
if with_metrics:
self.metric_results = {metric: 0 for metric in self.metric_results}
diff --git a/basicsr/models/swinirgan_model.py b/basicsr/models/swinirgan_model.py
new file mode 100644
index 000000000..0b2c0de34
--- /dev/null
+++ b/basicsr/models/swinirgan_model.py
@@ -0,0 +1,107 @@
+import torch
+from collections import OrderedDict
+from torch.nn import functional as F
+from basicsr.utils.registry import MODEL_REGISTRY
+from .srgan_model import SRGANModel
+
+
+@MODEL_REGISTRY.register()
+class SwinIRGANModel(SRGANModel):
+ """SwinIRGAN model for single image super-resolution."""
+
+ def test(self):
+ # pad to multiplication of window_size
+ window_size = self.opt['network_g']['window_size']
+ scale = self.opt.get('scale', 1)
+ mod_pad_h, mod_pad_w = 0, 0
+ _, _, h, w = self.lq.size()
+ if h % window_size != 0:
+ mod_pad_h = window_size - h % window_size
+ if w % window_size != 0:
+ mod_pad_w = window_size - w % window_size
+ img = F.pad(self.lq, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
+ if hasattr(self, 'net_g_ema'):
+ self.net_g_ema.eval()
+ with torch.no_grad():
+ self.output = self.net_g_ema(img)
+ else:
+ self.net_g.eval()
+ with torch.no_grad():
+ self.output = self.net_g(img)
+ self.net_g.train()
+
+ _, _, h, w = self.output.size()
+ self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
+
+ def optimize_parameters(self, current_iter):
+ # optimize net_g
+ for p in self.net_d.parameters():
+ p.requires_grad = False
+
+ self.optimizer_g.zero_grad()
+ self.output = self.net_g(self.lq)
+
+ l_g_total = 0
+ loss_dict = OrderedDict()
+ if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ # pixel loss
+ if self.cri_pix:
+ l_g_pix = self.cri_pix(self.output, self.gt)
+ l_g_total += l_g_pix
+ loss_dict['l_g_pix'] = l_g_pix
+ # perceptual loss
+ if self.cri_perceptual:
+ l_g_percep, l_g_style = self.cri_perceptual(self.output, self.gt)
+ if l_g_percep is not None:
+ l_g_total += l_g_percep
+ loss_dict['l_g_percep'] = l_g_percep
+ if l_g_style is not None:
+ l_g_total += l_g_style
+ loss_dict['l_g_style'] = l_g_style
+ # gan loss (relativistic gan)
+ real_d_pred = self.net_d(self.gt).detach()
+ fake_g_pred = self.net_d(self.output)
+ l_g_real = self.cri_gan(real_d_pred - torch.mean(fake_g_pred), False, is_disc=False)
+ l_g_fake = self.cri_gan(fake_g_pred - torch.mean(real_d_pred), True, is_disc=False)
+ l_g_gan = (l_g_real + l_g_fake) / 2
+
+ l_g_total += l_g_gan
+ loss_dict['l_g_gan'] = l_g_gan
+
+ l_g_total.backward()
+ self.optimizer_g.step()
+
+ # optimize net_d
+ for p in self.net_d.parameters():
+ p.requires_grad = True
+
+ self.optimizer_d.zero_grad()
+ # gan loss (relativistic gan)
+
+ # In order to avoid the error in distributed training:
+ # "Error detected in CudnnBatchNormBackward: RuntimeError: one of
+ # the variables needed for gradient computation has been modified by
+ # an inplace operation",
+ # we separate the backwards for real and fake, and also detach the
+ # tensor for calculating mean.
+
+ # real
+ fake_d_pred = self.net_d(self.output).detach()
+ real_d_pred = self.net_d(self.gt)
+ l_d_real = self.cri_gan(real_d_pred - torch.mean(fake_d_pred), True, is_disc=True) * 0.5
+ l_d_real.backward()
+ # fake
+ fake_d_pred = self.net_d(self.output.detach())
+ l_d_fake = self.cri_gan(fake_d_pred - torch.mean(real_d_pred.detach()), False, is_disc=True) * 0.5
+ l_d_fake.backward()
+ self.optimizer_d.step()
+
+ loss_dict['l_d_real'] = l_d_real
+ loss_dict['l_d_fake'] = l_d_fake
+ loss_dict['out_d_real'] = torch.mean(real_d_pred.detach())
+ loss_dict['out_d_fake'] = torch.mean(fake_d_pred.detach())
+
+ self.log_dict = self.reduce_loss_dict(loss_dict)
+
+ if self.ema_decay > 0:
+ self.model_ema(decay=self.ema_decay)
diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py
index 2240df45c..9f7993a15 100644
--- a/basicsr/models/video_base_model.py
+++ b/basicsr/models/video_base_model.py
@@ -24,14 +24,15 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
# 'folder1': tensor (num_frame x len(metrics)),
# 'folder2': tensor (num_frame x len(metrics))
# }
- if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run
- self.metric_results = {}
- num_frame_each_folder = Counter(dataset.data_info['folder'])
- for folder, num_frame in num_frame_each_folder.items():
- self.metric_results[folder] = torch.zeros(
- num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
- # initialize the best metric results
- self._initialize_best_metric_results(dataset_name)
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {}
+ num_frame_each_folder = Counter(dataset.data_info['folder'])
+ for folder, num_frame in num_frame_each_folder.items():
+ self.metric_results[folder] = torch.zeros(
+ num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+ # initialize the best metric results
+ self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
rank, world_size = get_dist_info()
if with_metrics:
diff --git a/basicsr/models/video_recurrent_model.py b/basicsr/models/video_recurrent_model.py
index 49fa0e4e2..796ee57d5 100644
--- a/basicsr/models/video_recurrent_model.py
+++ b/basicsr/models/video_recurrent_model.py
@@ -72,14 +72,15 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
# 'folder1': tensor (num_frame x len(metrics)),
# 'folder2': tensor (num_frame x len(metrics))
# }
- if with_metrics and not hasattr(self, 'metric_results'): # only execute in the first run
- self.metric_results = {}
- num_frame_each_folder = Counter(dataset.data_info['folder'])
- for folder, num_frame in num_frame_each_folder.items():
- self.metric_results[folder] = torch.zeros(
- num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
- # initialize the best metric results
- self._initialize_best_metric_results(dataset_name)
+ if with_metrics:
+ if not hasattr(self, 'metric_results'): # only execute in the first run
+ self.metric_results = {}
+ num_frame_each_folder = Counter(dataset.data_info['folder'])
+ for folder, num_frame in num_frame_each_folder.items():
+ self.metric_results[folder] = torch.zeros(
+ num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+ # initialize the best metric results
+ self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
rank, world_size = get_dist_info()
if with_metrics:
diff --git a/options/test/FocalIR/test_FocalIR_x4.yml b/options/test/FocalIR/test_FocalIR_x4.yml
new file mode 100644
index 000000000..7ca663cc0
--- /dev/null
+++ b/options/test/FocalIR/test_FocalIR_x4.yml
@@ -0,0 +1,99 @@
+name: FocalIR_SRx4_DIV2K
+model_type: FocalIRModel
+scale: 4
+num_gpu: 1 # set num_gpu: 0 for cpu mode
+manual_seed: 0
+
+datasets:
+ test_1: # the 1st test dataset
+ name: Set5
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_2: # the 2nd test dataset
+ name: Set14
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set14/HR
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_3:
+ name: B100
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/B100/HR
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_4:
+ name: Urban100
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Urban100/HR
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_5:
+ name: Manga109
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Manga109/HR
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+
+# network structures
+network_g:
+ type: FocalIR
+ upscale: 4
+ in_chans: 3
+ img_size: 48
+ window_size: 4
+ img_range: 1.
+ depths: [6, 6, 6, 6]
+ embed_dim: 60
+ num_heads: [6, 6, 6, 6]
+ drop_path_rate: 0.2
+ focal_levels: [2, 2, 2, 2]
+ expand_sizes: [3, 3, 3, 3]
+ expand_layer: "all"
+ focal_windows: [7, 5, 3, 1]
+ mlp_ratio: 2
+ upsampler: 'pixelshuffle'
+ resi_connection: '1conv'
+ use_shift: False
+
+# path
+path:
+ pretrain_network_g: experiments/train_FocalIR_SRx4_s48g96_DIV2K/models/net_g_5000.pth
+ #pretrain_network_g: experiments/train_SwinIR_SRx4_DIV2K/models/net_g_latest.pth
+ strict_load_g: true
+
+# validation settings
+val:
+ save_img: true
+ suffix: ~ # add suffix to saved images, if None, use exp name
+
+ metrics:
+ psnr: # metric name, can be arbitrary
+ type: calculate_psnr
+ crop_border: 4
+ test_y_channel: true
+ ssim:
+ type: calculate_ssim
+ crop_border: 4
+ test_y_channel: true
+ lpips:
+ type: calculate_lpips
+ crop_border: 4
+ test_y_channel: false
+ better: lower
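This and the following configs register an `lpips` entry with `type: calculate_lpips` and `better: lower`, which is not a stock BasicSR metric. The sketch below is a hypothetical implementation, assuming the `lpips` pip package and the same HWC, [0, 255] image convention as `calculate_psnr`; channel order (BGR vs. RGB) may need flipping depending on how images are read.

```python
# Hypothetical calculate_lpips metric; NOT part of stock BasicSR.
# Assumes the `lpips` package and HWC images in [0, 255].
import lpips
import numpy as np
import torch

from basicsr.utils.registry import METRIC_REGISTRY

_lpips_model = None  # lazily built AlexNet-based LPIPS network


@METRIC_REGISTRY.register()
def calculate_lpips(img, img2, crop_border, **kwargs):
    global _lpips_model
    if _lpips_model is None:
        _lpips_model = lpips.LPIPS(net='alex')
    if crop_border != 0:
        img = img[crop_border:-crop_border, crop_border:-crop_border, ...]
        img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]

    def to_tensor(x):
        # HWC [0, 255] -> NCHW in [-1, 1], the range the lpips package expects
        t = torch.from_numpy(x.astype(np.float32) / 255.).permute(2, 0, 1).unsqueeze(0)
        return t * 2 - 1

    with torch.no_grad():
        return _lpips_model(to_tensor(img), to_tensor(img2)).item()
```

Because `better: lower` is declared in the config, the per-dataset best-metric tracking touched in the video-model hunks above will treat smaller LPIPS values as improvements.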
diff --git a/options/test/SwinIR/test_SwinIR_x4.yml b/options/test/SwinIR/test_SwinIR_x4.yml
new file mode 100644
index 000000000..fe8422983
--- /dev/null
+++ b/options/test/SwinIR/test_SwinIR_x4.yml
@@ -0,0 +1,92 @@
+name: SwinIR_SRx4_DIV2K
+model_type: SwinIRModel
+scale: 4
+num_gpu: 1 # set num_gpu: 0 for cpu mode
+manual_seed: 0
+
+datasets:
+ test_1: # the 1st test dataset
+ name: Set5
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_2: # the 2nd test dataset
+ name: Set14
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set14/HR
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_3:
+ name: B100
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/B100/HR
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_4:
+ name: Urban100
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Urban100/HR
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_5:
+ name: Manga109
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Manga109/HR
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+
+# network structures
+network_g:
+ type: SwinIR
+ upscale: 4
+ in_chans: 3
+ img_size: 48
+ window_size: 8
+ img_range: 1.
+ depths: [6, 6, 6, 6, 6, 6]
+ embed_dim: 180
+ num_heads: [6, 6, 6, 6, 6, 6]
+ mlp_ratio: 2
+ upsampler: 'pixelshuffle'
+ resi_connection: '1conv'
+
+# path
+path:
+ pretrain_network_g: experiments/pretrained_models/SwinIR/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth
+ #pretrain_network_g: experiments/train_SwinIR_SRx4_DIV2K/models/net_g_latest.pth
+ strict_load_g: true
+
+# validation settings
+val:
+ save_img: true
+ suffix: ~ # add suffix to saved images, if None, use exp name
+
+ metrics:
+ psnr: # metric name, can be arbitrary
+ type: calculate_psnr
+ crop_border: 4
+ test_y_channel: true
+ ssim:
+ type: calculate_ssim
+ crop_border: 4
+ test_y_channel: true
+ lpips:
+ type: calculate_lpips
+ crop_border: 4
+ test_y_channel: false
diff --git a/options/test/SwinIR/test_SwinIR_x4_myself.yml b/options/test/SwinIR/test_SwinIR_x4_myself.yml
new file mode 100644
index 000000000..476398bf7
--- /dev/null
+++ b/options/test/SwinIR/test_SwinIR_x4_myself.yml
@@ -0,0 +1,94 @@
+name: SwinIR_SRx4_DIV2K
+model_type: SwinIRModel
+scale: 4
+num_gpu: 1 # set num_gpu: 0 for cpu mode
+manual_seed: 0
+
+datasets:
+ test_1: # the 1st test dataset
+ name: Set5
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_2: # the 2nd test dataset
+ name: Set14
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set14/HR
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_3:
+ name: B100
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/B100/HR
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_4:
+ name: Urban100
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Urban100/HR
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+ test_5:
+ name: Manga109
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Manga109/HR
+ dataroot_lq: datasets/benchmark/Manga109/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+
+# network structures
+network_g:
+ type: SwinIR
+ upscale: 4
+ in_chans: 3
+ #img_size: 32
+ img_size: 48
+ window_size: 8
+ img_range: 1.
+ depths: [6, 6, 6, 6, 6, 6]
+ embed_dim: 180
+ num_heads: [6, 6, 6, 6, 6, 6]
+ mlp_ratio: 2
+ upsampler: 'pixelshuffle'
+ resi_connection: '1conv'
+
+# path
+path:
+ #pretrain_network_g: experiments/pretrained_models/SwinIR/001_classicalSR_DIV2K_s48w8_SwinIR-M_x4.pth
+ pretrain_network_g: experiments/train_SwinIR_SRx4_DIV2K/models/net_g_latest.pth
+ strict_load_g: true
+
+# validation settings
+val:
+ save_img: true
+ suffix: ~ # add suffix to saved images, if None, use exp name
+
+ metrics:
+ psnr: # metric name, can be arbitrary
+ type: calculate_psnr
+ crop_border: 4
+ test_y_channel: true
+ ssim:
+ type: calculate_ssim
+ crop_border: 4
+ test_y_channel: true
+ lpips:
+ type: calculate_lpips
+ crop_border: 4
+ test_y_channel: true
+ better: lower
diff --git a/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu_RGB.yml b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu_RGB.yml
new file mode 100644
index 000000000..d9aee8c12
--- /dev/null
+++ b/options/train/ECBSR/train_ECBSR_x4_m4c16_prelu_RGB.yml
@@ -0,0 +1,139 @@
+# general settings
+name: 100_train_ECBSR_x4_m4c16_prelu_RGB
+model_type: SRModel
+scale: 4
+num_gpu: 1 # set num_gpu: 0 for cpu mode
+manual_seed: 0
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DIV2K
+ type: PairedImageDataset
+ # It is strongly recommended to use lmdb for faster IO speed, especially for small networks
+ dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub.lmdb
+ dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
+ meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
+ filename_tmpl: '{}'
+ io_backend:
+ type: lmdb
+
+ gt_size: 256
+ use_flip: true
+ use_rot: true
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 12
+ batch_size_per_gpu: 32
+ dataset_enlarge_ratio: 10
+ prefetch_mode: ~
+
+  # we use multiple validation datasets. The SR benchmark datasets can be downloaded from: https://cv.snu.ac.kr/research/EDSR/benchmark.tar
+ val:
+ name: Set5
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+
+ val_2:
+ name: Set14
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set14/HR
+ dataroot_lq: datasets/benchmark/Set14/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+
+ val_3:
+ name: B100
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/B100/HR
+ dataroot_lq: datasets/benchmark/B100/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+
+ val_4:
+ name: Urban100
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Urban100/HR
+ dataroot_lq: datasets/benchmark/Urban100/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+
+# network structures
+network_g:
+ type: ECBSR
+ num_in_ch: 3
+ num_out_ch: 3
+ num_block: 4
+ num_channel: 16
+ with_idt: False
+ act_type: prelu
+ scale: 4
+
+# path
+path:
+ pretrain_network_g: ~
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0
+ optim_g:
+ type: Adam
+ lr: !!float 5e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [1600000]
+ gamma: 1
+
+ total_iter: 1600000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+
+# validation settings
+val:
+ val_freq: !!float 1600 # the same as the original setting. # TODO: Can be larger
+ save_img: false
+ pbar: False
+
+ metrics:
+ psnr:
+ type: calculate_psnr
+ crop_border: 4
+ test_y_channel: true
+ better: higher # the higher, the better. Default: higher
+ ssim:
+ type: calculate_ssim
+ crop_border: 4
+ test_y_channel: true
+ better: higher # the higher, the better. Default: higher
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 1600
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+dist_params:
+ backend: nccl
+ port: 29500
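As configured (`num_gpu: 1`), this ECBSR recipe runs through the standard BasicSR entry point. A sketch of the invocation, wrapped in Python only to keep all examples in one language:

```python
# Launch the single-GPU ECBSR training run described by the config above;
# equivalent to calling `python basicsr/train.py -opt ...` from the shell.
import subprocess

subprocess.run(
    ['python', 'basicsr/train.py',
     '-opt', 'options/train/ECBSR/train_ECBSR_x4_m4c16_prelu_RGB.yml'],
    check=True)
```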
diff --git a/options/train/ESRGAN/train_ESRGAN_x4.yml b/options/train/ESRGAN/train_ESRGAN_x4.yml
index 5310e7a14..f13ec6739 100644
--- a/options/train/ESRGAN/train_ESRGAN_x4.yml
+++ b/options/train/ESRGAN/train_ESRGAN_x4.yml
@@ -20,6 +20,7 @@ datasets:
type: disk
# (for lmdb)
# type: lmdb
+ color: n
gt_size: 128
use_flip: true
@@ -29,16 +30,18 @@ datasets:
use_shuffle: true
num_worker_per_gpu: 6
batch_size_per_gpu: 16
- dataset_enlarge_ratio: 100
+ dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
- name: Set14
+ name: Set5
type: PairedImageDataset
- dataroot_gt: datasets/Set14/GTmod12
- dataroot_lq: datasets/Set14/LRbicx4
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
io_backend:
type: disk
+ color: n
# network structures
network_g:
@@ -55,7 +58,7 @@ network_d:
# path
path:
- pretrain_network_g: experiments/051_RRDBNet_PSNR_x4_f64b23_DIV2K_1000k_B16G1_wandb/models/net_g_1000000.pth
+ pretrain_network_g: ~
strict_load_g: true
resume_state: ~
diff --git a/options/train/FocalIR/train_focalIR_SRx4_scratch.yml b/options/train/FocalIR/train_focalIR_SRx4_scratch.yml
new file mode 100644
index 000000000..5d0c1944f
--- /dev/null
+++ b/options/train/FocalIR/train_focalIR_SRx4_scratch.yml
@@ -0,0 +1,128 @@
+# general settings
+name: train_test_FocalIR_SRx4_s48g96_DIV2K
+model_type: FocalIRModel
+scale: 4
+num_gpu: 1
+manual_seed: 0
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DIV2K
+ type: PairedImageDataset
+ #dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
+ #dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub
+ dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
+ dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
+ #meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
+ filename_tmpl: '{}'
+ io_backend:
+ type: lmdb
+
+ gt_size: 96
+ use_flip: true
+ use_rot: true
+ color: n
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 6
+ batch_size_per_gpu: 16
+ dataset_enlarge_ratio: 1
+ #prefetch_mode: ~
+ prefetch_mode: cuda
+ pin_memory: true
+
+ val:
+ name: Set5
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+
+# network structures
+network_g:
+ type: FocalIR
+ upscale: 4
+ in_chans: 3
+ img_size: 48
+ window_size: 4
+ img_range: 1.
+ depths: [6, 6, 6, 6]
+ embed_dim: 60
+ num_heads: [6, 6, 6, 6]
+ drop_path_rate: 0.2
+ focal_levels: [2, 2, 2, 2]
+ expand_sizes: [3, 3, 3, 3]
+ expand_layer: "all"
+ focal_windows: [7, 5, 3, 1]
+ mlp_ratio: 2
+ upsampler: 'pixelshuffle'
+ resi_connection: '1conv'
+ use_shift: False
+
+# path
+path:
+ pretrain_network_g: ~
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 2e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [250000, 400000, 450000, 475000]
+ gamma: 0.5
+
+ total_iter: 500000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+
+# validation settings
+val:
+ val_freq: !!float 5e3
+ save_img: false
+
+ metrics:
+ psnr: # metric name, can be arbitrary
+ type: calculate_psnr
+ crop_border: 4
+ test_y_channel: true
+ ssim:
+ type: calculate_ssim
+ crop_border: 4
+ test_y_channel: true
+ lpips:
+ type: calculate_lpips
+ crop_border: 4
+ test_y_channel: false
+ better: lower
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+#dist_params:
+# backend: nccl
+# port: 29500
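Unlike the other recipes in this patch, this one enables `prefetch_mode: cuda` with `pin_memory: true`, so each batch is copied to the GPU on a side stream while the previous batch is still being processed. A rough sketch of that pattern (not the repository's `CUDAPrefetcher`):

```python
# Rough sketch of a CUDA prefetcher: overlap host-to-device copies of the next
# batch with computation on the current one. Non-tensor dict entries are dropped.
import torch


class SimpleCUDAPrefetcher:
    def __init__(self, loader, device='cuda'):
        self.loader = iter(loader)
        self.device = device
        self.stream = torch.cuda.Stream()
        self.batch = None
        self._preload()

    def _preload(self):
        try:
            batch = next(self.loader)
        except StopIteration:
            self.batch = None
            return
        with torch.cuda.stream(self.stream):  # copy on the side stream
            self.batch = {k: v.to(self.device, non_blocking=True)
                          for k, v in batch.items() if torch.is_tensor(v)}

    def next(self):
        torch.cuda.current_stream().wait_stream(self.stream)  # wait for the copy
        batch = self.batch
        self._preload()
        return batch
```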
diff --git a/options/train/FocalIR/train_focalIR_SRx4_scratch_plus.yml b/options/train/FocalIR/train_focalIR_SRx4_scratch_plus.yml
new file mode 100644
index 000000000..c5fd21dc1
--- /dev/null
+++ b/options/train/FocalIR/train_focalIR_SRx4_scratch_plus.yml
@@ -0,0 +1,129 @@
+# general settings
+name: train_FocalIR_SRx4_s48g96_DIV2K
+model_type: FocalIRModel
+scale: 4
+num_gpu: 1
+manual_seed: 0
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DIV2K
+ type: PairedImageDataset
+ #dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
+ #dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub
+ dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
+ dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
+ #meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
+ filename_tmpl: '{}'
+ io_backend:
+ type: lmdb
+
+ gt_size: 96
+ use_flip: true
+ use_rot: true
+ color: n
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 6
+ batch_size_per_gpu: 16
+ dataset_enlarge_ratio: 1
+ prefetch_mode: ~
+ #prefetch_mode: cuda
+ #pin_memory: true
+
+ val:
+ name: Set5
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+
+# network structures
+network_g:
+ type: FocalIR
+ upscale: 4
+ in_chans: 3
+ img_size: 48
+ window_size: 4
+ img_range: 1.
+ depths: [6, 6, 6, 6, 6, 6]
+ embed_dim: 180
+ num_heads: [6, 6, 6, 6, 6, 6]
+ drop_path_rate: 0.2
+ focal_levels: [2, 2, 2, 2, 2, 2]
+ expand_sizes: [3, 3, 3, 3, 3, 3]
+ expand_layer: "all"
+ focal_windows: [11, 9, 7, 5, 3, 1]
+ mlp_ratio: 2
+ upsampler: 'pixelshuffle'
+ resi_connection: '1conv'
+ use_shift: False
+
+# path
+path:
+ pretrain_network_g: ~
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 2e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [250000, 400000, 450000, 475000]
+ gamma: 0.5
+
+ total_iter: 500000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+
+# validation settings
+val:
+ #val_freq: !!float 5e3
+ val_freq: 1
+ save_img: false
+
+ metrics:
+ psnr: # metric name, can be arbitrary
+ type: calculate_psnr
+ crop_border: 4
+ test_y_channel: true
+ ssim:
+ type: calculate_ssim
+ crop_border: 4
+ test_y_channel: true
+ lpips:
+ type: calculate_lpips
+ crop_border: 4
+ test_y_channel: false
+ better: lower
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+#dist_params:
+# backend: nccl
+# port: 29500
diff --git a/options/train/RDSwinIR/train_RDSwinIR_SRx4_scratch.yml b/options/train/RDSwinIR/train_RDSwinIR_SRx4_scratch.yml
new file mode 100644
index 000000000..296cab105
--- /dev/null
+++ b/options/train/RDSwinIR/train_RDSwinIR_SRx4_scratch.yml
@@ -0,0 +1,121 @@
+# general settings
+name: train_RDSwinIR_learnable_DIV2K
+model_type: SwinIRModel
+scale: 4
+num_gpu: 1
+manual_seed: 0
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DIV2K
+ type: PairedImageDataset
+ #dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
+ #dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub
+ dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
+ dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
+ #meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
+ filename_tmpl: '{}'
+ io_backend:
+ type: lmdb
+
+ gt_size: 64
+ use_flip: true
+ use_rot: true
+ color: n
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 6
+ batch_size_per_gpu: 16
+ dataset_enlarge_ratio: 1
+ prefetch_mode: ~
+
+ val:
+ name: Set5
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+
+# network structures
+network_g:
+ type: RDSwinIR
+ upscale: 4
+ in_chans: 3
+ img_size: 32
+ window_size: 8
+ img_range: 1.
+ depths: [6, 6, 6, 6, 6, 6]
+ embed_dim: 180
+ num_heads: [6, 6, 6, 6, 6, 6]
+ mlp_ratio: 2
+ #patch_norm: False
+ upsampler: 'pixelshuffle'
+ resi_connection: '1conv'
+
+# path
+path:
+ pretrain_network_g: ~
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 2e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [250000, 400000, 450000, 475000]
+ gamma: 0.5
+
+ total_iter: 500000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: 1.0
+ reduction: mean
+
+# validation settings
+val:
+ val_freq: !!float 5e3
+ save_img: false
+
+ metrics:
+ psnr: # metric name, can be arbitrary
+ type: calculate_psnr
+ crop_border: 4
+ test_y_channel: true
+ ssim:
+ type: calculate_ssim
+ crop_border: 4
+ test_y_channel: true
+ lpips:
+ type: calculate_lpips
+ crop_border: 4
+ test_y_channel: false
+ better: lower
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+#dist_params:
+# backend: nccl
+# port: 29500
diff --git a/options/train/SwinIR/train_SwinIR_SRx4_scratch.yml b/options/train/SwinIR/train_SwinIR_SRx4_scratch.yml
index ed1edd006..b7571a213 100644
--- a/options/train/SwinIR/train_SwinIR_SRx4_scratch.yml
+++ b/options/train/SwinIR/train_SwinIR_SRx4_scratch.yml
@@ -1,8 +1,8 @@
# general settings
-name: train_SwinIR_SRx4_scratch_P48W8_DIV2K_500k_B4G8
+name: train_SwinIR_SRx4_s48g96_DIV2K
model_type: SwinIRModel
scale: 4
-num_gpu: auto
+num_gpu: 1
manual_seed: 0
# dataset and data loader settings
@@ -10,31 +10,36 @@ datasets:
train:
name: DIV2K
type: PairedImageDataset
- dataroot_gt: datasets/DF2K/DIV2K_train_HR_sub
- dataroot_lq: datasets/DF2K/DIV2K_train_LR_bicubic_X4_sub
- meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
+ #dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub
+ #dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic/X4_sub
+ dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
+ dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
+ #meta_info_file: basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt
filename_tmpl: '{}'
io_backend:
- type: disk
+ type: lmdb
- gt_size: 192
+ gt_size: 96
use_flip: true
use_rot: true
+ color: n
# data loader
use_shuffle: true
num_worker_per_gpu: 6
- batch_size_per_gpu: 4
+ batch_size_per_gpu: 16
dataset_enlarge_ratio: 1
prefetch_mode: ~
val:
name: Set5
type: PairedImageDataset
- dataroot_gt: datasets/Set5/GTmod12
- dataroot_lq: datasets/Set5/LRbicx4
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
io_backend:
type: disk
+ color: n
# network structures
network_g:
@@ -89,7 +94,16 @@ val:
psnr: # metric name, can be arbitrary
type: calculate_psnr
crop_border: 4
+ test_y_channel: true
+ ssim:
+ type: calculate_ssim
+ crop_border: 4
+ test_y_channel: true
+ lpips:
+ type: calculate_lpips
+ crop_border: 4
test_y_channel: false
+ better: lower
# logging settings
logger:
@@ -101,6 +115,6 @@ logger:
resume_id: ~
# dist training settings
-dist_params:
- backend: nccl
- port: 29500
+#dist_params:
+# backend: nccl
+# port: 29500
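A side effect of the `gt_size: 192 -> 96` change above worth noting: at scale 4 the LQ training patch becomes 96 / 4 = 24 pixels, which still tiles evenly into the 8-pixel attention windows SwinIR uses, so no window padding is needed during training. A quick check:

```python
# Sanity check for the patch sizes set above (gt_size: 96, scale: 4, window_size: 8).
gt_size, scale, window_size = 96, 4, 8
lq_size = gt_size // scale
assert lq_size % window_size == 0, 'LQ patch must be a multiple of the window size'
print(lq_size, lq_size // window_size)  # 24 pixels, i.e. 3 windows per side
```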
diff --git a/options/train/SwinIRGAN/train_SwinIRGAN_x4.yml b/options/train/SwinIRGAN/train_SwinIRGAN_x4.yml
new file mode 100644
index 000000000..916c875a0
--- /dev/null
+++ b/options/train/SwinIRGAN/train_SwinIRGAN_x4.yml
@@ -0,0 +1,151 @@
+# general settings
+name: SwinIRGANModel_x4_DIV2K
+model_type: SwinIRGANModel
+scale: 4
+num_gpu: 1 # set num_gpu: 0 for cpu mode
+manual_seed: 0
+
+# dataset and data loader settings
+datasets:
+ train:
+ name: DIV2K
+ type: PairedImageDataset
+ dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
+ dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
+ # (for lmdb)
+ # dataroot_gt: datasets/DIV2K/DIV2K_train_HR_sub.lmdb
+ # dataroot_lq: datasets/DIV2K/DIV2K_train_LR_bicubic_X4_sub.lmdb
+ filename_tmpl: '{}'
+ io_backend:
+ type: lmdb
+ # (for lmdb)
+ # type: lmdb
+
+ gt_size: 128
+ use_flip: true
+ use_rot: true
+
+ # data loader
+ use_shuffle: true
+ num_worker_per_gpu: 6
+ batch_size_per_gpu: 16
+ dataset_enlarge_ratio: 100
+ prefetch_mode: ~
+
+ val:
+ name: Set5
+ type: PairedImageDataset
+ dataroot_gt: datasets/benchmark/Set5/HR
+ dataroot_lq: datasets/benchmark/Set5/LR_bicubic/X4
+ filename_tmpl: '{}x4'
+ io_backend:
+ type: disk
+ color: n
+
+# network structures
+network_g:
+ type: SwinIR
+ upscale: 4
+ in_chans: 3
+ img_size: 64
+ window_size: 8
+ img_range: 1.
+ depths: [6, 6, 6, 6, 6, 6]
+ embed_dim: 180
+ num_heads: [6, 6, 6, 6, 6, 6]
+ mlp_ratio: 4
+ upsampler: 'pixelshuffle'
+ resi_connection: '1conv'
+
+
+network_d:
+ type: UNetDiscriminatorSN
+ num_in_ch: 3
+ num_feat: 64
+
+# path
+path:
+ pretrain_network_g: ~
+ strict_load_g: true
+ resume_state: ~
+
+# training settings
+train:
+ ema_decay: 0.999
+ optim_g:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+ optim_d:
+ type: Adam
+ lr: !!float 1e-4
+ weight_decay: 0
+ betas: [0.9, 0.99]
+
+ scheduler:
+ type: MultiStepLR
+ milestones: [50000, 100000, 200000, 300000]
+ gamma: 0.5
+
+ total_iter: 400000
+ warmup_iter: -1 # no warm up
+
+ # losses
+ pixel_opt:
+ type: L1Loss
+ loss_weight: !!float 1e-2
+ reduction: mean
+ perceptual_opt:
+ type: PerceptualLoss
+ layer_weights:
+ 'conv5_4': 1 # before relu
+ vgg_type: vgg19
+ use_input_norm: true
+ range_norm: false
+ perceptual_weight: 1.0
+ style_weight: 0
+ criterion: l1
+ gan_opt:
+ type: GANLoss
+ gan_type: vanilla
+ real_label_val: 1.0
+ fake_label_val: 0.0
+ loss_weight: !!float 5e-3
+
+ net_d_iters: 1
+ net_d_init_iters: 0
+
+# validation settings
+val:
+ val_freq: !!float 5e3
+ save_img: true
+
+ metrics:
+ psnr: # metric name, can be arbitrary
+ type: calculate_psnr
+ crop_border: 4
+ test_y_channel: true
+ ssim: # metric name, can be arbitrary
+ type: calculate_ssim
+ crop_border: 4
+ test_y_channel: true
+ lpips: # metric name, can be arbitrary
+ type: calculate_lpips
+ crop_border: 4
+ test_y_channel: true
+ better: lower
+
+# logging settings
+logger:
+ print_freq: 100
+ save_checkpoint_freq: !!float 5e3
+ use_tb_logger: true
+ wandb:
+ project: ~
+ resume_id: ~
+
+# dist training settings
+dist_params:
+ backend: nccl
+ port: 29500
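The generator side of this SwinIRGAN recipe combines the three losses configured above: a lightly weighted L1 pixel term (1e-2), a VGG19 conv5_4 perceptual term (1.0), and a vanilla GAN term (5e-3). The sketch below shows how such a generator step is typically assembled; it is not the actual `SwinIRGANModel` code, and each criterion is assumed to apply its own `loss_weight` internally, as in the discriminator hunk at the top of this patch.

```python
# Schematic generator update for the loss weights configured above
# (pixel 1e-2 * L1, perceptual 1.0 on VGG19 conv5_4, GAN 5e-3).
def generator_step(net_g, net_d, cri_pix, cri_perceptual, cri_gan, optimizer_g, lq, gt):
    optimizer_g.zero_grad()
    output = net_g(lq)
    l_g_total = cri_pix(output, gt)                 # pixel loss
    l_percep, l_style = cri_perceptual(output, gt)  # perceptual / style pair
    if l_percep is not None:
        l_g_total = l_g_total + l_percep
    if l_style is not None:
        l_g_total = l_g_total + l_style
    fake_g_pred = net_d(output)
    l_g_total = l_g_total + cri_gan(fake_g_pred, True, is_disc=False)  # adversarial term
    l_g_total.backward()
    optimizer_g.step()
    return output
```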
diff --git a/setup.py b/setup.py
index 998ba8591..9e31e945e 100644
--- a/setup.py
+++ b/setup.py
@@ -45,12 +45,13 @@ def _minimal_ext_cmd(cmd):
def get_hash():
if os.path.exists('.git'):
sha = get_git_hash()[:7]
- elif os.path.exists(version_file):
- try:
- from basicsr.version import __version__
- sha = __version__.split('+')[-1]
- except ImportError:
- raise ImportError('Unable to get git version')
+    # the version-file fallback below is currently disabled
+ # elif os.path.exists(version_file):
+ # try:
+ # from basicsr.version import __version__
+ # sha = __version__.split('+')[-1]
+ # except ImportError:
+ # raise ImportError('Unable to get git version')
else:
sha = 'unknown'