Commit 7077e9e

minor text fixes
1 parent 9542165 commit 7077e9e

7 files changed: +38 −95 lines


README.md

Lines changed: 8 additions & 7 deletions
@@ -18,7 +18,7 @@ JPG, 0.264 bpp / 90.1 kB
 ```
 ![guess](assets/comparison/camp_jpg_compress.png)
 
-The image shown is an out-of-sample instance from the CLIC-2020 dataset. The HIFIC image is obtained by reconstruction via the learned model. The JPG image is obtained by the command `mogrify -format jpg -quality 42 camp_original.png`. All images are losslessly compressed to PNG format for viewing. Images stored under `assets/comparison`. Note that the learned model was not adapted in any way for evaluation of this image.
+The image shown is an out-of-sample instance from the CLIC-2020 dataset. The HIFIC image is obtained by reconstruction via the learned model. The JPG image is obtained by the command `mogrify -format jpg -quality 42 camp_original.png`. All images are losslessly compressed to PNG format for viewing. Images and other examples are stored under `assets/comparison`. Note that the learned model was not adapted in any way for evaluation of this image.
 
 ## Details
 This repository defines a model for learnable image compression capable of compressing images of arbitrary size and resolution. There are three main components to this model, as described in the original paper:
@@ -61,7 +61,7 @@ python3 train.py --model_type compression --regime low --n_steps 1e6
 ```
 python3 train.py --model_type compression_gan --regime low --n_steps 1e6 --warmstart --ckpt path/to/base/checkpoint
 ```
-* Training after the warmstart for 2e5 steps using a batch size of 16 was sufficient to get reasonable results at sub-0.2 `bpp` on average.
+* Training after the warmstart for 2e5 steps using a batch size of 16 was sufficient to get reasonable results at sub-0.2 `bpp` per image on average, using the default config.
 * If you get out-of-memory errors, try:
   * Reducing the number of residual blocks in the generator (default 7, the original paper used 9).
   * Decreasing the batch size (default 16).
@@ -73,23 +73,24 @@ tensorboard --logdir experiments/my_experiment/tensorboard
 ```
 
 ### Compression
-* To obtain a _theoretical_ measure of the bitrate under some trained model, run `compress.py`. This will report the bits-per-pixel attainable by the compressed representation (`bpp`), some other fun metrics, and perform a forward pass through the model to obtain the reconstructed image. This model will work with images of arbitrary sizes and resolution.
+* To obtain a _theoretical_ measure of the bitrate under some trained model, run `compress.py`. This will report the bits-per-pixel attainable by the compressed representation (`bpp`), some other fun metrics, and perform a forward pass through the model to obtain the reconstructed image (saved as a PNG). It works with images of arbitrary size and resolution (provided you don't run out of memory), and accepts JPG and PNG inputs (without alpha channels).
 ```
-python3 compress.py --img path/to/image/dir --ckpt path/to/trained/model
+python3 compress.py -i path/to/image/dir -ckpt path/to/trained/model
 ```
+* A pretrained model using the OpenImages dataset can be found here: [Drive link]. This model was trained for 2e5 warmup steps and 2e5 steps with the full generative loss. To use this, download the model and point the `-ckpt` argument in the command above to the corresponding path.
+
 * The reported `bpp` is the theoretical bitrate required to losslessly store the quantized latent representation of an image, as determined by the learned probability model provided by the hyperprior, using some entropy coding algorithm. Comparing this (not the size of the reconstruction) against the original size of the image will give you an idea of the reduction in memory footprint. This repository does not currently support actual compression to a bitstring ([TensorFlow Compression](https://github.com/tensorflow/compression) does this well). We're working on an ANS entropy coder to support this in the future.
 
 ### Notes
 * The "size" of the compressed image as reported in `bpp` does not account for the size of the model required to decode the compressed format.
-* The total size of the model (using the original architecture) is around 737 MB. Forward pass time should scale sublinearly provided everything fits in memory.
-* You may get an OOM error when compressing images which are too large. We're working on a fix.
+* The total size of the model (using the original architecture) is around 737 MB. Forward pass time should scale sublinearly provided everything fits in memory. A complete forward pass using a batch of 10 images takes around 45s on a 2.8 GHz Intel Core i7.
+* You may get an OOM error when compressing images which are too large (`>~ 4000 x 4000`). It's possible to get around this by applying the network to evenly sized crops of the input image whose forward passes fit in memory. We're working on a fix to support this automatically.
 
 ### Contributing
 All content in this repository is licensed under the Apache-2.0 license. Feel free to submit any corrections or suggestions as issues.
 
 ### Acknowledgements
 * The code under `hific/perceptual_similarity/` implementing the perceptual distortion loss is modified from the [Perceptual Similarity repository](https://github.com/richzhang/PerceptualSimilarity).
-<!-- * The cat in the main image is my neighbour's. -->
 
 ### Authors
 * Grace Han
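The "theoretical bitrate" described in the Compression bullets can be made concrete with a small sketch. All numbers here are illustrative (the likelihoods and image size are made up, not produced by the model):

```python
import math

# Hypothetical likelihoods assigned by the learned probability model to the
# quantized latent elements of a tiny 4x4-pixel image (illustrative values).
latent_likelihoods = [0.5, 0.25, 0.125, 0.125]

# Ideal entropy-coding cost of a symbol with probability p is -log2(p) bits.
n_bits = sum(-math.log2(p) for p in latent_likelihoods)  # 1 + 2 + 3 + 3 = 9 bits

n_pixels = 4 * 4
bpp = n_bits / n_pixels
print(bpp)  # 0.5625
```

An actual entropy coder (e.g. the planned ANS coder) would approach this bound; the reported `bpp` is this ideal figure, not the size of any emitted bitstring.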

compress.py

Lines changed: 4 additions & 3 deletions
@@ -67,7 +67,8 @@ def compress_batch(args):
         input_filenames_total.extend(filenames)
 
         for subidx in range(reconstruction.shape[0]):
-            fname = os.path.join(args.output_dir, "{}_RECON.png".format(filenames[subidx]))
+            bpp_per_im = float(bpp[subidx].cpu().numpy())
+            fname = os.path.join(args.output_dir, "{}_RECON_{:.3f}bpp.png".format(filenames[subidx], bpp_per_im))
             torchvision.utils.save_image(reconstruction[subidx], fname, normalize=True)
             output_filenames_total.append(fname)
 
@@ -97,7 +98,7 @@ def compress_batch(args):
 
 def main(**kwargs):
 
-    description = "Compresses batch of images using specified learned model."
+    description = "Compresses batch of images using learned model specified via -ckpt argument."
     parser = argparse.ArgumentParser(description=description,
                                      formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     parser.add_argument("-ckpt", "--ckpt_path", type=str, required=True, help="Path to model to be restored")
@@ -106,7 +107,7 @@ def main(**kwargs):
     parser.add_argument("-o", "--output_dir", type=str, default='data/reconstructions',
                         help="Path to directory to store output images")
     parser.add_argument('-bs', '--batch_size', type=int, default=1,
-                        help="Dataloader batch size. Set to 1 for images of different sizes.")
+                        help="Loader batch size. Set to 1 if images in directory are different sizes.")
    args = parser.parse_args()
 
    input_images = glob.glob(os.path.join(args.image_dir, '*.jpg'))
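The changed lines above embed the per-image bpp in the reconstruction filename via `"{:.3f}bpp"` formatting, and the README suggests comparing the reported bpp against an image's on-disk bitrate. A sketch with purely illustrative numbers (the filename stem, byte count, and resolution are made up):

```python
# Filename pattern from the diff, with an illustrative bpp value.
fname = "{}_RECON_{:.3f}bpp.png".format("camp", 0.4583)
print(fname)  # camp_RECON_0.458bpp.png

# On-disk bitrate of an original image: 8 bits per byte over H*W pixels.
orig_bytes = 90_100           # e.g. a 90.1 kB JPG
height, width = 1024, 1536    # hypothetical resolution
file_bpp = 8 * orig_bytes / (height * width)
print(round(file_bpp, 3))  # 0.458
```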

default_config.py

Lines changed: 5 additions & 5 deletions
@@ -13,7 +13,7 @@ class ModelTypes(object):
 
 class ModelModes(object):
     TRAINING = 'training'
-    VALIDATION = 'validation' # Monitoring
+    VALIDATION = 'validation'
     EVALUATION = 'evaluation'
 
 class Datasets(object):
@@ -30,13 +30,13 @@ class directories(object):
     experiments = 'experiments'
 
 class checkpoints(object):
-    gan1 = 'experiments/gan_med_bitrate_openimages_compression_gan_2020_08_14_07_12/checkpoints/gan_med_bitrate_openimages_compression_gan_2020_08_14_07_12_epoch1_idx56776_2020_08_14_18:43.pt'
+    gan1 = 'experiments/lossless.pt'
 
 class args(object):
     """
     Shared config
     """
-    name = 'hific_v0'
+    name = 'hific_v0.1'
     silent = True
     n_epochs = 8
     n_steps = 1e6
@@ -52,8 +52,8 @@ class args(object):
     model_mode = ModelModes.TRAINING
 
     # Architecture params - Table 3a) of [1]
-    latent_channels = 220 #220
-    n_residual_blocks = 7 #7 # Authors use 9 blocks, performance saturates at 5
+    latent_channels = 220
+    n_residual_blocks = 7  # Authors use 9 blocks, performance saturates at 5
     lambda_B = 2**(-4)     # Loose rate
     k_M = 0.075 * 2**(-5)  # Distortion
     k_P = 1.               # Perceptual loss
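The config file uses plain classes as namespaces: attributes are read off the class without ever instantiating it. A minimal sketch of the pattern, with values copied from the diff above (this is an excerpt, not the full config):

```python
class ModelModes(object):
    TRAINING = 'training'
    VALIDATION = 'validation'
    EVALUATION = 'evaluation'

class args(object):
    """Shared config (excerpt)."""
    name = 'hific_v0.1'
    n_residual_blocks = 7   # Authors use 9 blocks, performance saturates at 5
    lambda_B = 2**(-4)      # Loose rate
    k_M = 0.075 * 2**(-5)   # Distortion weight
    k_P = 1.                # Perceptual loss weight

# No instances needed - class attributes act as shared defaults:
print(args.name, args.k_M)  # hific_v0.1 0.00234375
```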

src/helpers/maths.py

Lines changed: 2 additions & 19 deletions
@@ -12,7 +12,7 @@ def backward(ctx, grad_output):
         return grad_output.clone(), None
 
 
-class LowerBoundToward_0(torch.autograd.Function):
+class LowerBoundToward(torch.autograd.Function):
     """
     Assumes output shape is identical to input shape.
     """
@@ -24,26 +24,9 @@ def forward(ctx, tensor, lower_bound):
 
     @staticmethod
     def backward(ctx, grad_output):
-        # gate = torch.autograd.Variable(torch.logical_or(ctx.mask, grad_output.lt(0.)).type(grad_output.dtype))
-        gate = torch.autograd.Variable(torch.logical_or(ctx.mask, grad_output.lt(0.)).type_as(grad_output.data))
+        gate = torch.logical_or(ctx.mask, grad_output.lt(0.)).type(grad_output.dtype)
         return grad_output * gate, None
 
-class LowerBoundToward(torch.autograd.Function):
-    @staticmethod
-    def forward(ctx, inputs, bound):
-        b = torch.ones_like(inputs) * bound
-        ctx.save_for_backward(inputs, b)
-        return torch.max(inputs, b)
-
-    @staticmethod
-    def backward(ctx, grad_output):
-        inputs, b = ctx.saved_tensors
-        pass_through_1 = inputs >= b
-        pass_through_2 = grad_output < 0
-
-        pass_through = pass_through_1 | pass_through_2
-        return pass_through.type(grad_output.dtype) * grad_output, None
-
 def standardized_CDF_gaussian(value):
     # Gaussian
     # return 0.5 * (1. + torch.erf(value/ np.sqrt(2)))
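The surviving `LowerBoundToward` clamps its input from below in the forward pass, but in the backward pass lets gradients through wherever the input already cleared the bound, or wherever the gradient is negative (under gradient descent, a negative gradient pushes a clamped value back up toward the feasible region). A NumPy sketch of just that gating logic, outside autograd (function names are illustrative):

```python
import numpy as np

def lower_bound_forward(x, bound):
    # Forward: elementwise max against the bound; remember which inputs
    # were already at or above it (the mask saved on ctx in the real code).
    mask = x >= bound
    return np.maximum(x, bound), mask

def lower_bound_backward(grad_output, mask):
    # Backward: pass the gradient where the input cleared the bound,
    # or where the gradient is negative (would push the value upward).
    gate = np.logical_or(mask, grad_output < 0.)
    return grad_output * gate.astype(grad_output.dtype)

x = np.array([0.5, 2.0])
y, mask = lower_bound_forward(x, 1.0)
print(y)  # [1. 2.]
print(lower_bound_backward(np.array([1.0, 1.0]), mask))  # [0. 1.]
```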

src/model.py

Lines changed: 3 additions & 18 deletions
@@ -22,7 +22,7 @@
 Intermediates = namedtuple("Intermediates",
     ["input_image",        # [0, 1] (after scaling from [0, 255])
      "reconstruction",     # [0, 1]
-     "latents_quantized",  # Latents post-quantization.
+     "latents_quantized",  # Latents post-quantization.
      "n_bpp",              # Differential entropy estimate.
      "q_bpp"])             # Shannon entropy estimate.
 
@@ -134,6 +134,7 @@ def compression_forward(self, x):
         total_nbpp = hyperinfo.total_nbpp
         total_qbpp = hyperinfo.total_qbpp
 
+        # Use quantized latents as input to G
         reconstruction = self.Generator(latents_quantized)
 
         if self.args.normalize_input_image is True:
@@ -160,7 +161,6 @@ def discriminator_forward(self, intermediates, train_generator):
         D_in = torch.cat([x_real, x_gen], dim=0)
 
         latents = intermediates.latents_quantized.detach()
-        # latents = torch.cat([latents, latents], dim=0)
         latents = torch.repeat_interleave(latents, 2, dim=0)
 
         D_out, D_out_logits = self.Discriminator(D_in, latents)
@@ -170,14 +170,11 @@ def discriminator_forward(self, intermediates, train_generator):
         D_real, D_gen = torch.chunk(D_out, 2, dim=0)
         D_real_logits, D_gen_logits = torch.chunk(D_out_logits, 2, dim=0)
 
-        # Tensorboard
-        # real_response, gen_response = D_real.mean(), D_fake.mean()
-
         return Disc_out(D_real, D_gen, D_real_logits, D_gen_logits)
 
     def distortion_loss(self, x_gen, x_real):
         # loss in [0,255] space but normalized by 255 to not be too big
-        # - Delegate to weighting
+        # - Delegate scaling to weighting
         sq_err = self.squared_difference(x_gen*255., x_real*255.) # / 255.
         return torch.mean(sq_err)
 
@@ -196,30 +193,18 @@ def compression_loss(self, intermediates, hyperinfo):
         x_real = (x_real + 1.) / 2.
         x_gen = (x_gen + 1.) / 2.
 
-        # print('X REAL MAX', x_real.max())
-        # print('X REAL MIN', x_real.min())
-        # print('X GEN MAX', x_gen.max())
-        # print('X GEN MIN', x_gen.min())
-
         distortion_loss = self.distortion_loss(x_gen, x_real)
         perceptual_loss = self.perceptual_loss_wrapper(x_gen, x_real, normalize=True)
 
         weighted_distortion = self.args.k_M * distortion_loss
         weighted_perceptual = self.args.k_P * perceptual_loss
 
-        # print('Distortion loss size', weighted_distortion.size())
-        # print('Perceptual loss size', weighted_perceptual.size())
-
         weighted_rate, rate_penalty = losses.weighted_rate_loss(self.args, total_nbpp=intermediates.n_bpp,
             total_qbpp=intermediates.q_bpp, step_counter=self.step_counter)
 
-        # print('Weighted rate loss size', weighted_rate.size())
         weighted_R_D_loss = weighted_rate + weighted_distortion
         weighted_compression_loss = weighted_R_D_loss + weighted_perceptual
 
-        # print('Weighted R-D loss size', weighted_R_D_loss.size())
-        # print('Weighted compression loss size', weighted_compression_loss.size())
-
         # Bookkeeping
         if (self.step_counter % self.log_interval == 1):
             self.store_loss('rate_penalty', rate_penalty)
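One subtlety in `discriminator_forward` above: `torch.repeat_interleave(latents, 2, dim=0)` repeats each sample consecutively, whereas the removed `torch.cat([latents, latents], dim=0)` repeats the whole batch; the two orderings only coincide for batch size 1. A NumPy sketch of the difference (`np.repeat` and `np.concatenate` mirror the two torch calls):

```python
import numpy as np

latents = np.array([[1, 2, 3],
                    [4, 5, 6]])  # illustrative batch of 2

interleaved = np.repeat(latents, 2, axis=0)         # like torch.repeat_interleave(., 2, dim=0)
tiled = np.concatenate([latents, latents], axis=0)  # like torch.cat([., .], dim=0)

print(interleaved[:, 0].tolist())  # [1, 1, 4, 4] - each sample repeated in place
print(tiled[:, 0].tolist())        # [1, 4, 1, 4] - whole batch repeated
```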

src/network/hyperprior.py

Lines changed: 11 additions & 29 deletions
@@ -45,17 +45,15 @@ def _quantize(self, x, mode='noise', means=None):
 
         if mode == 'noise':
             quantization_noise = torch.nn.init.uniform_(torch.zeros_like(x), -0.5, 0.5)
-            # quantization_noise = torch.rand(x.size()).to(x) - 0.5
             x = x + quantization_noise
-        elif mode == 'quantize':
 
+        elif mode == 'quantize':
             if means is not None:
                 x = x - means
                 x = torch.floor(x + 0.5)
                 x = x + means
             else:
                 x = torch.floor(x + 0.5)
-                # x = torch.round(x)
         else:
             raise NotImplementedError
 
@@ -71,16 +69,8 @@ def _estimate_entropy(self, likelihood, spatial_shape):
         n_pixels = np.prod(spatial_shape)
 
         log_likelihood = torch.log(likelihood + EPS)
-        # print('LOG LIKELIHOOD', log_likelihood.mean().item())
         n_bits = torch.sum(log_likelihood) / (batch_size * quotient)
         bpp = n_bits / n_pixels
-        # print('N_PIXELS', n_pixels)
-        # print('BATCH SIZE', batch_size)
-        # print('LH', likelihood)
-        #print('LH MAX', likelihood.max())
-        #print('LH MAX', likelihood.min())
-        #print('NB', n_bits)
-        #print('BPP', bpp)
 
         return n_bits, bpp
 
@@ -192,13 +182,13 @@ def likelihood(self, x):
 
         # Numerical stability using some sigmoid identities
         # to avoid subtraction of two numbers close to 1
-        # sign = -torch.sign(cdf_upper + cdf_lower)
-        # sign = sign.detach()
-        # likelihood_ = torch.abs(
-        #     torch.sigmoid(sign * cdf_upper) - torch.sigmoid(sign * cdf_lower))
+        sign = -torch.sign(cdf_upper + cdf_lower)
+        sign = sign.detach()
+        likelihood_ = torch.abs(
+            torch.sigmoid(sign * cdf_upper) - torch.sigmoid(sign * cdf_lower))
 
         # Naive
-        likelihood_ = torch.sigmoid(cdf_upper) - torch.sigmoid(cdf_lower)
+        # likelihood_ = torch.sigmoid(cdf_upper) - torch.sigmoid(cdf_lower)
 
         # Reshape to (N,C,H,W)
         likelihood_ = torch.reshape(likelihood_, shape)
@@ -268,13 +258,13 @@ def latent_likelihood(self, x, mean, scale):
 
         # Assumes 1 - CDF(x) = CDF(-x)
         x = x - mean
-        # x = torch.abs(x)
-        # cdf_upper = self.standardized_CDF((0.5 - x) / scale)
-        # cdf_lower = self.standardized_CDF(-(0.5 + x) / scale)
+        x = torch.abs(x)
+        cdf_upper = self.standardized_CDF((0.5 - x) / scale)
+        cdf_lower = self.standardized_CDF(-(0.5 + x) / scale)
 
         # Naive
-        cdf_upper = self.standardized_CDF( (x + 0.5) / scale )
-        cdf_lower = self.standardized_CDF( (x - 0.5) / scale )
+        # cdf_upper = self.standardized_CDF( (x + 0.5) / scale )
+        # cdf_lower = self.standardized_CDF( (x - 0.5) / scale )
 
         likelihood_ = cdf_upper - cdf_lower
         likelihood_ = lower_bound_toward(likelihood_, self.min_likelihood)
@@ -298,9 +288,6 @@ def forward(self, latents, spatial_shape, **kwargs):
         quantized_hyperlatent_bits, quantized_hyperlatent_bpp = self._estimate_entropy(
             quantized_hyperlatent_likelihood, spatial_shape)
 
-        #print('QUANT HL', quantized_hyperlatents)
-        #print('maxQUANT HL', quantized_hyperlatents.max())
-        #print('minQUANT HL', quantized_hyperlatents.min())
         if self.training is True:
             hyperlatents_decoded = noisy_hyperlatents
         else:
@@ -343,11 +330,6 @@ def forward(self, latents, spatial_shape, **kwargs):
             side_bitstring=None, # TODO
         )
 
-        # print(quantized_latents)
-        # print(quantized_hyperlatents)
-        # print(noisy_latents)
-        # print(noisy_hyperlatents)
-
         return info
 
 class HyperpriorAnalysis(nn.Module):
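The `_quantize` method touched above has two modes: additive uniform noise as a differentiable surrogate during training, and hard round-to-nearest (optionally about a predicted mean) for evaluation. A NumPy sketch of the same logic (the function and toy values are illustrative, not the repo's API):

```python
import numpy as np

def quantize(x, mode='noise', means=None):
    if mode == 'noise':
        # Differentiable surrogate: additive uniform noise in [-0.5, 0.5).
        rng = np.random.default_rng(0)
        return x + rng.uniform(-0.5, 0.5, size=x.shape)
    elif mode == 'quantize':
        if means is not None:
            # Round the residual about the predicted mean, then restore it.
            return np.floor(x - means + 0.5) + means
        return np.floor(x + 0.5)  # round-to-nearest
    raise NotImplementedError

x = np.array([0.4, 1.6, -0.7])
print(quantize(x, 'quantize').tolist())                       # [0.0, 2.0, -1.0]
print(quantize(x, 'quantize', means=np.full(3, 0.5)).tolist())  # [0.5, 1.5, -0.5]
```

Rounding about the mean means the expensive-to-code residual, not the raw latent, is what gets quantized.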

train.py

Lines changed: 5 additions & 14 deletions
@@ -27,7 +27,6 @@
 from default_config import hific_args, mse_lpips_args, directories, ModelModes, ModelTypes
 
 # go fast boi!!
-# Optimizes cuda kernels by benchmarking - no dynamic input sizes!
 torch.backends.cudnn.benchmark = True
 
 def create_model(args, device, logger, storage, storage_test):
@@ -304,19 +303,9 @@ def train(args, model, train_loader, test_loader, device, logger, optimizers):
     else:
         model = create_model(args, device, logger, storage, storage_test)
     model = model.to(device)
-    # amortization_parameters = itertools.chain.from_iterable(
-    #     [am.parameters() for am in model.amortization_models])
-
-    amort_names, amortization_parameters = list(), list()
-    for n, p in model.named_parameters():
-        if ('Encoder' in n) or ('Generator' in n):
-            amort_names.append(n)
-            amortization_parameters.append(p)
-            logger.info(f'AM {n} - {p.shape}')
-        if ('analysis' in n) or ('synthesis' in n):
-            amort_names.append(n)
-            amortization_parameters.append(p)
-            logger.info(f'AM {n} - {p.shape}')
+    amortization_parameters = itertools.chain.from_iterable(
+        [am.parameters() for am in model.amortization_models])
+
     hyperlatent_likelihood_parameters = model.Hyperprior.hyperlatent_likelihood.parameters()
 
     amortization_opt = torch.optim.Adam(amortization_parameters,
@@ -332,6 +321,8 @@ def train(args, model, train_loader, test_loader, device, logger, optimizers):
 
     n_gpus = torch.cuda.device_count()
     if n_gpus > 1 and args.multigpu is True:
+        # Not supported at this time
+        raise NotImplementedError('MultiGPU not supported yet.')
         logger.info('Using {} GPUs.'.format(n_gpus))
         model = nn.DataParallel(model)
