Skip to content

Commit e37edc0

Browse files
committed
add AE example
1 parent b09336e commit e37edc0

File tree

8 files changed

+316
-4
lines changed

8 files changed

+316
-4
lines changed

examples/as_loss/as_loss.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
1+
import os
2+
import sys
13
import paddle
2-
from paddle.optimizer import Adam
3-
from PIL import Image
44
import numpy as np
5-
import sys, os
6-
import paddle.nn.functional as F
5+
6+
from PIL import Image
7+
from paddle.optimizer import Adam
78
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), '..'))
89
from paddle_msssim import ssim, ms_ssim, SSIM, MS_SSIM
910

11+
1012
loss_type = 'msssim'
1113
assert loss_type in ['ssim', 'msssim']
1214

examples/auto_encoder/README.md

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Train an autoencoder with SSIM & MS-SSIM
2+
3+
## Prepare dataset
4+
* Download the CLIC dataset from http://clic.compression.cc/2021/tasks/index.html.
5+
6+
* Unzip them into datasets.
7+
8+
* The structure of the directory:
9+
10+
```yaml
11+
- datasets
12+
- CLIC
13+
- train
14+
- *.png
15+
- ...
16+
- valid
17+
- *.png
18+
- ...
19+
```
20+
21+
## Train
22+
* SSIM loss:
23+
24+
```bash
25+
$ python train.py --loss_type ssim
26+
```
27+
28+
* MS-SSIM loss:
29+
30+
```bash
31+
$ python train.py --loss_type ms_ssim
32+
```
33+
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .image_dataset import ImageDataset
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import os
2+
from PIL import Image
3+
from paddle.io import Dataset
4+
5+
6+
class ImageDataset(Dataset):
    """Dataset over a flat directory of image files.

    Each item is a 1-tuple containing one (optionally transformed) image,
    matching the ``(images,)`` unpacking used by the training loop.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # Sort filenames so iteration order is deterministic across runs.
        self.images = sorted(os.listdir(root))

    def __getitem__(self, idx):
        path = os.path.join(self.root, self.images[idx])
        sample = Image.open(path)
        if self.transform is not None:
            sample = self.transform(sample)
        return (sample,)

    def __len__(self):
        return len(self.images)
Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .autoencoder import AutoEncoder
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import paddle.nn as nn
2+
import paddle.nn.functional as F
3+
4+
from .gdn import GDN
5+
6+
7+
# https://arxiv.org/pdf/1611.01704.pdf
# A simplified version without quantization.
class AutoEncoder(nn.Layer):
    """Convolutional autoencoder: image -> Encoder -> latent code -> Decoder -> image."""

    def __init__(self, C=128, M=128, in_chan=3, out_chan=3):
        super(AutoEncoder, self).__init__()
        self.encoder = Encoder(C=C, M=M, in_chan=in_chan)
        self.decoder = Decoder(C=C, M=M, out_chan=out_chan)

    def forward(self, x, **kargs):
        # Encode to the latent representation, then reconstruct from it.
        latent = self.encoder(x)
        return self.decoder(latent)
19+
20+
21+
class Encoder(nn.Layer):
    """Analysis transform: four stride-2 5x5 convolutions.

    A GDN nonlinearity follows each stage except the last, which projects
    from M feature channels down to the C-channel latent code.
    """

    def __init__(self, C=32, M=128, in_chan=3):
        super(Encoder, self).__init__()
        stages = [
            nn.Conv2D(in_channels=in_chan, out_channels=M,
                      kernel_size=5, stride=2, padding=2, bias_attr=False),
            GDN(M),
        ]
        # Two identical middle stages (M -> M), each followed by GDN.
        for _ in range(2):
            stages.append(nn.Conv2D(in_channels=M, out_channels=M, kernel_size=5,
                                    stride=2, padding=2, bias_attr=False))
            stages.append(GDN(M))
        # Final projection to the latent channel count; no nonlinearity.
        stages.append(nn.Conv2D(in_channels=M, out_channels=C, kernel_size=5,
                                stride=2, padding=2, bias_attr=False))
        self.enc = nn.Sequential(*stages)

    def forward(self, x):
        return self.enc(x)
46+
47+
48+
class Decoder(nn.Layer):
    """Synthesis transform: four stride-2 5x5 transposed convolutions.

    An inverse-GDN follows each stage except the last; the output is passed
    through a sigmoid so reconstructions lie in [0, 1].
    """

    def __init__(self, C=32, M=128, out_chan=3):
        super(Decoder, self).__init__()
        stages = [
            nn.Conv2DTranspose(in_channels=C, out_channels=M, kernel_size=5,
                               stride=2, padding=2, output_padding=1, bias_attr=False),
            GDN(M, inverse=True),
        ]
        # Two identical middle stages (M -> M), each followed by inverse GDN.
        for _ in range(2):
            stages.append(nn.Conv2DTranspose(in_channels=M, out_channels=M, kernel_size=5,
                                             stride=2, padding=2, output_padding=1, bias_attr=False))
            stages.append(GDN(M, inverse=True))
        # Final expansion back to the image channel count.
        stages.append(nn.Conv2DTranspose(in_channels=M, out_channels=out_chan, kernel_size=5,
                                         stride=2, padding=2, output_padding=1, bias_attr=False))
        self.dec = nn.Sequential(*stages)

    def forward(self, q):
        return F.sigmoid(self.dec(q))
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
import paddle
2+
import paddle.nn as nn
3+
import paddle.nn.functional as F
4+
5+
6+
class GDN(nn.Layer):
    """Generalized Divisive Normalization layer.

    Normalizes each channel by a learned combination of the squared
    activations of all channels:

        y_c = x_c / sqrt(beta_c + sum_k gamma_ck * x_k^2)

    With ``inverse=True`` the division becomes a multiplication (IGDN),
    as used on the decoder side of the autoencoder.
    """

    def __init__(self,
                 num_features,
                 inverse=False,
                 gamma_init=.1,
                 beta_bound=1e-6,
                 gamma_bound=0.0,
                 reparam_offset=2**-18,
                 ):
        super(GDN, self).__init__()
        self._inverse = inverse
        self.num_features = num_features
        self.reparam_offset = reparam_offset
        # Parameters are stored as sqrt(value + pedestal) and squared back in
        # _reparam; the pedestal keeps the stored values away from zero.
        self.pedestal = self.reparam_offset**2

        # beta starts at 1 for every channel (stored in sqrt-space).
        beta_init = paddle.sqrt(paddle.ones((num_features, ), dtype=paddle.float32) + self.pedestal)
        # gamma starts as gamma_init on the diagonal only, i.e. each channel
        # initially normalizes by itself alone (stored in sqrt-space).
        gama_init = paddle.sqrt(paddle.full((num_features, num_features), fill_value=gamma_init, dtype=paddle.float32)
                                * paddle.eye(num_features, dtype=paddle.float32) + self.pedestal)

        self.beta = self.create_parameter(
            shape=beta_init.shape, default_initializer=nn.initializer.Assign(beta_init))
        self.gamma = self.create_parameter(
            shape=gama_init.shape, default_initializer=nn.initializer.Assign(gama_init))

        # Lower bounds applied in the reparameterized (sqrt) domain.
        self.beta_bound = (beta_bound + self.pedestal) ** 0.5
        self.gamma_bound = (gamma_bound + self.pedestal) ** 0.5

    def _reparam(self, var, bound):
        # Clip in sqrt-space, then undo the sqrt/pedestal reparameterization.
        var = paddle.clip(var, min=bound)
        return (var**2) - self.pedestal

    def forward(self, x):
        gamma = self._reparam(self.gamma, self.gamma_bound).reshape((self.num_features, self.num_features, 1, 1)) # expand to (C, C, 1, 1)
        beta = self._reparam(self.beta, self.beta_bound)
        # 1x1 convolution computes beta_c + sum_k gamma_ck * x_k^2 per pixel.
        norm_pool = F.conv2d(x ** 2, gamma, bias=beta, stride=1, padding=0)
        norm_pool = paddle.sqrt(norm_pool)

        if self._inverse:
            # IGDN: multiply by the norm (decoder side).
            norm_pool = x * norm_pool
        else:
            # GDN: divide by the norm (encoder side).
            norm_pool = x / norm_pool
        return norm_pool

examples/auto_encoder/train.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import os
2+
import sys
3+
import paddle
4+
import argparse
5+
6+
from PIL import Image
7+
from models import AutoEncoder
8+
from datas import ImageDataset
9+
from paddle.vision import transforms
10+
from paddle.optimizer import Adam
11+
from paddle.io import DataLoader
12+
13+
sys.path.append(os.path.join(os.path.dirname(os.path.dirname(os.path.realpath(__file__))), '..'))
14+
from paddle_msssim import ssim, ms_ssim, SSIM, MS_SSIM
15+
16+
17+
class MS_SSIM_Loss(MS_SSIM):
    """MS-SSIM turned into a minimizable loss: 100 * (1 - MS-SSIM)."""

    def forward(self, img1, img2):
        similarity = super().forward(img1, img2)
        return 100 * (1 - similarity)
20+
21+
22+
class SSIM_Loss(SSIM):
    """SSIM turned into a minimizable loss: 100 * (1 - SSIM)."""

    def forward(self, img1, img2):
        similarity = super().forward(img1, img2)
        return 100 * (1 - similarity)
25+
26+
27+
def get_argparser():
    """Build the command-line argument parser for the training script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--ckpt", type=str, default=None,
                        help="path to trained model. Leave it None if you want to retrain your model")
    parser.add_argument("--loss_type", type=str,
                        choices=['ssim', 'ms_ssim'], default='ssim')
    parser.add_argument("--batch_size", type=int, default=8)
    parser.add_argument("--log_interval", type=int, default=10)
    parser.add_argument("--total_epochs", type=int, default=50)
    return parser
37+
38+
39+
def main():
    """Train the autoencoder on CLIC with an SSIM or MS-SSIM loss.

    Saves 'latest_model.pdparams' after every epoch and 'best_model.pdparams'
    whenever the validation score improves on the best seen so far.
    """
    opts = get_argparser().parse_args()

    # dataset
    train_transform = transforms.Compose([
        transforms.RandomCrop(size=512, pad_if_needed=True),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ToTensor(),
    ])

    val_transform = transforms.Compose([
        transforms.CenterCrop(size=512),
        transforms.ToTensor()
    ])

    train_loader = DataLoader(
        ImageDataset(root='datasets/CLIC/train', transform=train_transform),
        batch_size=opts.batch_size, shuffle=True, num_workers=0, drop_last=True)

    val_loader = DataLoader(
        ImageDataset(root='datasets/CLIC/valid', transform=val_transform),
        batch_size=opts.batch_size, shuffle=False, num_workers=0)

    print("Train set: %d, Val set: %d" %
          (len(train_loader.dataset), len(val_loader.dataset)))
    model = AutoEncoder(C=128, M=128, in_chan=3, out_chan=3)

    # optimizer
    optimizer = Adam(parameters=model.parameters(),
                     learning_rate=1e-4,
                     weight_decay=1e-5)

    # checkpoint
    best_score = 0.0
    if opts.ckpt is not None and os.path.isfile(opts.ckpt):
        model.set_dict(paddle.load(opts.ckpt))
    else:
        print("[!] Retrain")

    if opts.loss_type == 'ssim':
        criterion = SSIM_Loss(data_range=1.0, size_average=True, channel=3)
    else:
        criterion = MS_SSIM_Loss(data_range=1.0, size_average=True, channel=3)

    #========== Train Loop ==========#
    for cur_epoch in range(opts.total_epochs):
        # ===== Train =====
        model.train()
        for cur_step, (images, ) in enumerate(train_loader):
            optimizer.clear_grad()
            outputs = model(images)

            loss = criterion(outputs, images)
            loss.backward()

            optimizer.step()

            if cur_step % opts.log_interval == 0:
                print("Epoch %d, Batch %d/%d, loss=%.6f" %
                      (cur_epoch, cur_step, len(train_loader), loss.item()))

        # ===== Save Latest Model =====
        paddle.save(model.state_dict(), 'latest_model.pdparams')

        # ===== Validation =====
        print("Val...")
        cur_score = test(opts, model, val_loader)
        print("%s = %.6f" % (opts.loss_type, cur_score))
        # ===== Save Best Model =====
        # BUG FIX: best_score was reset to 0.0 at every epoch, so every epoch's
        # checkpoint overwrote 'best_model.pdparams'. Track the running best
        # across epochs instead.
        if cur_score > best_score:  # save best model
            best_score = cur_score
            paddle.save(model.state_dict(), 'best_model.pdparams')
            # BUG FIX: message previously claimed 'best_model.pt' while the
            # file actually written is 'best_model.pdparams'.
            print("Best model saved as best_model.pdparams")
116+
117+
def test(opts, model, val_loader):
    """Evaluate `model` on `val_loader` and return the mean SSIM / MS-SSIM.

    Also writes one reconstructed image ('recons_<loss_type>.png') for visual
    inspection.
    """
    model.eval()
    cur_score = 0.0

    metric = ssim if opts.loss_type == 'ssim' else ms_ssim

    with paddle.no_grad():
        for i, (images, ) in enumerate(val_loader):
            outputs = model(images)
            # Save the first image of batch 20 as a sample reconstruction.
            # BUG FIX: the original `.squeeze(0)` did not drop the batch axis
            # when batch_size > 1, so the 3-axis transpose raised on a 4-D
            # array; index the first image of the batch explicitly instead.
            if i == 20:
                sample = (outputs[0] * 255).detach().numpy().astype('uint8')
                Image.fromarray(sample.transpose(1, 2, 0)).save(
                    'recons_%s.png' % (opts.loss_type))
            # metric(...) returns the mean score over the batch
            # (size_average defaults to True).
            cur_score += metric(outputs, images, data_range=1.0)
    # BUG FIX: we accumulated per-batch means, so normalize by the number of
    # batches, not the number of images in the dataset.
    cur_score /= len(val_loader)
    return cur_score
133+
134+
135+
# Script entry point: run training when executed directly.
if __name__ == '__main__':
    main()

0 commit comments

Comments
 (0)