
Commit 5f9ce2a

Merge pull request #3 from Justin-Tan/ans_compression
Adds rANS compression support.
2 parents 56cb383 + 39d29a1 commit 5f9ce2a

File tree

17 files changed: +2633 −382 lines


README.md

Lines changed: 61 additions & 72 deletions
Large diffs are not rendered by default.

assets/EXAMPLES.md

Lines changed: 62 additions & 0 deletions
@@ -0,0 +1,62 @@
A | B
:-------------------------:|:-------------------------:
![guess](assets/hific/CLIC2020_5_RECON_0.160bpp.png) | ![guess](assets/originals/CLIC2020_5.png)

<details>

<summary>Image 1</summary>

```
Original: B (11.6 bpp) | HIFIC: A (0.160 bpp). Ratio: 72.5
```

</details>

A | B
:-------------------------:|:-------------------------:
![guess](assets/originals/CLIC2020_20.png) | ![guess](assets/hific/CLIC2020_20_RECON_0.330bpp.png)

<details>

<summary>Image 2</summary>

```
Original: A (14.6 bpp) | HIFIC: B (0.330 bpp). Ratio: 44.2
```

</details>

A | B
:-------------------------:|:-------------------------:
![guess](assets/originals/CLIC2020_18.png) | ![guess](assets/hific/CLIC2020_18_RECON_0.209bpp.png)

<details>

<summary>Image 3</summary>

```
Original: A (12.3 bpp) | HIFIC: B (0.209 bpp). Ratio: 58.9
```

</details>

A | B
:-------------------------:|:-------------------------:
![guess](assets/hific/CLIC2020_19_RECON_0.565bpp.png) | ![guess](assets/originals/CLIC2020_19.png)

<details>

<summary>Image 4</summary>

```
Original: B (19.9 bpp) | HIFIC: A (0.565 bpp). Ratio: 35.2
```

</details>

| Tables | Are | Cool |
|:------------- |:-------------:| -----:|
| col 3 is | right-aligned | $1600 |
| col 2 is | centered | $12 |
| col 1 is | left-aligned | $42 |
| zebra stripes | are neat | $1 |
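For reference, the compression ratios quoted in the reveals above are simply the raw bitrate divided by the compressed bitrate:

```python
def compression_ratio(original_bpp, compressed_bpp):
    # e.g. Image 1: 11.6 bpp raw vs. 0.160 bpp compressed
    return original_bpp / compressed_bpp
```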

assets/USAGE_GUIDE.md

Lines changed: 125 additions & 0 deletions
@@ -0,0 +1,125 @@
# Usage Guide

## Details

This repository defines a model for learnable image compression capable of compressing images of arbitrary size and resolution, based on the paper ["High-Fidelity Generative Image Compression" (HIFIC) by Mentzer et al.](https://hific.github.io/). There are three main components to this model, as described in the original paper:

1. An autoencoding architecture defining a nonlinear transform to latent space. This is used in place of the linear transforms used by traditional image codecs.
2. A hierarchical (two-level in this case) entropy model over the quantized latent representation, enabling lossless compression through standard entropy coding.
3. A generator-discriminator component that encourages the decoder/generator to yield realistic reconstructions.

The model is trained end-to-end by optimizing a modified rate-distortion Lagrangian. Loosely, the model can be thought of as 'amortizing' the storage requirements for a generic input image by learning a compression/decompression scheme. The method is further described in the original paper [[0](https://arxiv.org/abs/2006.09965)]. The model yields reconstructions that are perceptually similar to the input and tend to be more visually pleasing than those of standard image codecs operating at comparable or higher bitrates.

This repository also includes a partial port of the [Tensorflow Compression library](https://github.com/tensorflow/compression), providing general tools for neural image compression.
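As a loose illustration of the objective described above (a toy sketch, not the repository's code; the additive-noise quantization proxy follows the Ballé et al. references below, and the full HIFIC loss adds perceptual and GAN terms):

```python
import random

def quantize(y, training):
    # During training, additive uniform noise stands in for rounding so the
    # rate term stays differentiable; at evaluation time, latents are rounded.
    if training:
        return [v + random.uniform(-0.5, 0.5) for v in y]
    return [float(round(v)) for v in y]

def rd_lagrangian(distortion, rate_bpp, lmbda):
    # Rate-distortion Lagrangian: lambda trades bitrate against
    # reconstruction quality.
    return distortion + lmbda * rate_bpp
```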
## Training

* Download a large (> 100,000 images) dataset of diverse color images. We found that using 1-2 training divisions of the [OpenImages](https://storage.googleapis.com/openimages/web/index.html) dataset was sufficient to produce satisfactory results on arbitrary images. [Fabian Mentzer's L3C repo](https://github.com/fab-jul/L3C-PyTorch/) provides utility functions for downloading and preprocessing OpenImages (the trained models did not use this exact split). Add the dataset path under the `DatasetPaths` class in `default_config.py`. Check the default config/command line arguments:

```bash
vim default_config.py
python3 train.py -h
```

* For best results, as described in the paper, train an initial base model using the rate-distortion loss only, together with the hyperprior model, e.g. to target low bitrates:

```bash
# Train initial autoencoding model
python3 train.py --model_type compression --regime low --n_steps 1e6
```

* Then use the checkpoint of the trained base model to 'warmstart' the GAN architecture. Training the generator and discriminator from scratch was found to result in unstable training, but YMMV.

```bash
# Train using full generator-discriminator loss
python3 train.py --model_type compression_gan --regime low --n_steps 1e6 --warmstart --ckpt path/to/base/checkpoint
```

* Training after the warmstart for 2e5 steps with a batch size of 16 was sufficient to get reasonable results at sub-0.2 `bpp` per validation image, on average, using the default config in the `low` regime. You can change the regime to `med` or `high` to trade off perceptual quality for increased bitrate.

* Perceptual distortion metrics and `bpp` tend to decrease with a Pareto-like distribution over training, so model quality can probably be improved significantly by training for a very large number of steps.

* If you get out-of-memory errors, try, in decreasing order of priority:
  * Decreasing the batch size (default 16).
  * Decreasing the number of channels of the latent representation (`latent_channels`, default 220). You may be able to reduce this quite aggressively, as the network is highly over-parameterized - many values of the latent representation are near-deterministic.
  * Reducing the number of residual blocks in the generator (`n_residual_blocks`, default 7; the original paper used 9).
  * Training on smaller crops (`crop_size`, default `256 x 256`).

* Logs for each experiment, including image reconstructions, are automatically created and periodically saved under `experiments/` with the appropriate name/timestamp. Metrics can be visualized via `tensorboard`:

```bash
tensorboard --logdir experiments/my_experiment/tensorboard --port 2401
```

Sample logs for several models can be found below:

* [Low bitrate regime (warmstart)](https://tensorboard.dev/experiment/xJV4hjbxRFy3TzrdYl7MXA/)
* [Low bitrate regime (full GAN loss)](https://tensorboard.dev/experiment/ETa0JIeOS0ONNZuNkIdrQw/)
* [High bitrate regime (full GAN loss)](https://tensorboard.dev/experiment/hAf1NYrqSVieKoDOcNpoGw/)
## Compression
* `compress.py` compresses generic images under a specified entropy model. This performs a forward pass through the model to obtain the compressed representation, optionally coding the representation using a vectorized rANS entropy coder. As the model architecture is fully convolutional, compression works with images of arbitrary size/resolution (subject to memory constraints).

* For message transmission, separate entropy models over the latents and hyperlatents must be instantiated and shared between sender and receiver.
  * The sender computes the bottleneck tensor and calls the `compress()` method in `src/model.py` to obtain the compressed representation for transmission.
  * The receiver calls the `decompress()` method in `src/model.py` to obtain the quantized bottleneck tensor, which is then passed through the generator to obtain the reconstruction.

* The compression scheme is hierarchical in the sense that two 'levels' of information, representing the latent and hyperlatent variables, must be compressed and stored in the message, together with the shape of the encoded data.
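The sender/receiver flow above can be mocked with a toy stand-in for the shared entropy model (all names here are illustrative, not the repository's API; real code entropy-codes against shared probability tables rather than recording raw indices):

```python
class ToyEntropyModel:
    def __init__(self, table):
        self.table = table  # must be identical on sender and receiver

    def compress(self, latents):
        # Stand-in for entropy coding: record symbol indices plus the
        # shape needed to reassemble the tensor on the receiver side.
        return {"indices": [self.table.index(v) for v in latents],
                "shape": (len(latents),)}

    def decompress(self, message):
        assert len(message["indices"]) == message["shape"][0]
        return [self.table[i] for i in message["indices"]]

shared_table = [-1, 0, 1, 2]  # agreed upon before transmission
sender, receiver = ToyEntropyModel(shared_table), ToyEntropyModel(shared_table)
message = sender.compress([0, 2, -1])
```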
```bash
# Check arguments
python3 compress.py -h

python3 compress.py -i path/to/image/dir -ckpt path/to/trained/model --reconstruct
```

* Optionally, reconstructions from the compressed format can be generated by passing the `--reconstruct` flag. Decoding without executing the rANS coder takes around 2-3 seconds for ~megapixel images on GPU, but this can definitely be optimized. As the hyperprior entropy model involves a series of matrix multiplications, decoding is significantly faster on GPU.

* Executing the rANS coder is currently quite slow and represents a performance bottleneck. Passing the `--vectorize` flag is much faster, but incurs a constant-bit overhead. The batch size needs to be quite large to make this overhead negligible, which is suitable for e.g. video frames but not so good for general images. Working on a fix.
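For background on the coder itself, a minimal single-state, byte-renormalizing rANS codec can be sketched as follows (an illustrative toy, unrelated to the repository's vectorized implementation):

```python
from bisect import bisect_right

def rans_encode(symbols, freqs, cdf, scale_bits):
    # Encode in reverse so the decoder emits symbols in forward order.
    L = 1 << 16                       # lower bound of the normalized interval
    x, out = L, []
    for s in reversed(symbols):
        f, c = freqs[s], cdf[s]
        x_max = ((L >> scale_bits) << 8) * f
        while x >= x_max:             # renormalize: push low bytes
            out.append(x & 0xFF)
            x >>= 8
        x = ((x // f) << scale_bits) + (x % f) + c
    for _ in range(4):                # flush final state (fits in 4 bytes)
        out.append(x & 0xFF)
        x >>= 8
    return out

def rans_decode(out, n, freqs, cdf, scale_bits):
    L, M = 1 << 16, 1 << scale_bits
    out = list(out)
    x = 0
    for _ in range(4):                # recover the flushed state
        x = (x << 8) | out.pop()
    symbols = []
    for _ in range(n):
        slot = x & (M - 1)
        s = bisect_right(cdf, slot) - 1
        symbols.append(s)
        x = freqs[s] * (x >> scale_bits) + slot - cdf[s]
        while x < L:                  # renormalize: pull bytes back
            x = (x << 8) | out.pop()
    return symbols

# Frequencies must sum to 2**scale_bits (here 16); cdf is the cumulative table.
freqs, cdf = [8, 4, 4], [0, 8, 12, 16]
encoded = rans_encode([0, 1, 2, 0, 0, 2], freqs, cdf, scale_bits=4)
```

Decoding `encoded` with the same tables recovers the original symbol sequence exactly.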
## Pretrained Models
* Pretrained models using the OpenImages dataset can be found below. The examples at the end of this readme were produced using the `HIFIC-med` model. Each model was trained for around `2e5` warmup steps and `2e5` steps with the full generative loss. Note that the original paper trained for `1e6` steps in each mode, so you can probably get better performance by training from scratch yourself.

* To use a pretrained model, download the selected model (~2 GB) and point the `-ckpt` argument in the command above to the corresponding path. If you want to finetune this model, e.g. on some domain-specific dataset, use the following options for each respective model (you will probably need to adapt the learning rate and rate-penalty schedule yourself):

| Target bitrate (bpp) | Weights | Training Instructions |
| ----------- | -------------------------------- | ---------------------- |
| 0.14 | [`HIFIC-low`](https://drive.google.com/open?id=1hfFTkZbs_VOBmXQ-M4bYEPejrD76lAY9) | <pre lang=bash>`python3 train.py --model_type compression_gan --regime low --warmstart -ckpt path/to/trained/model -nrb 9 -norm`</pre> |
| 0.30 | [`HIFIC-med`](https://drive.google.com/open?id=1QNoX0AGKTBkthMJGPfQI0dT0_tnysYUb) | <pre lang=bash>`python3 train.py --model_type compression_gan --regime med --warmstart -ckpt path/to/trained/model --likelihood_type logistic`</pre> |
| 0.45 | [`HIFIC-high`](https://drive.google.com/open?id=1BFYpvhVIA_Ek2QsHBbKnaBE8wn1GhFyA) | <pre lang=bash>`python3 train.py --model_type compression_gan --regime high --warmstart -ckpt path/to/trained/model -nrb 9 -norm`</pre> |
## Extensibility
* Network architectures can be modified by changing the respective files under `src/network`.
* The entropy model for both latents and hyperlatents can be changed by modifying `src/network/hyperprior`. For reference, there is an implementation of a discrete logistic mixture model for the latents in place of the default mean-scale Gaussian model.
* The exact compression algorithm used can be replaced with any entropy coder that makes use of indexed probability tables.
## Notes
* The reported `bpp` is the theoretical bitrate required to losslessly store the quantized latent representation of an image. Comparing this (not the size of the reconstruction) against the original size of the image will give you an idea of the reduction in memory footprint.
* The total size of the model (using the original architecture) is around 737 MB. Forward pass time should scale sublinearly provided everything fits in memory. A complete forward pass using a batch of 10 `256 x 256` images takes around 45 s on a 2.8 GHz Intel Core i7.
* You may get an OOM error when compressing images which are too large (`>~ 4000 x 4000` on a typical consumer GPU). It's possible to get around this by splitting the input into distinct crops whose forward passes fit in memory. We're working on a fix to support this automatically.
* Compression of >~ megapixel images takes around 8 GB of RAM.
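Two of the notes above can be made concrete with toy helpers (illustrative only, not repository code): the reported `bpp` is the self-information of the quantized latents under the entropy model, divided by pixel count, and the OOM workaround amounts to tiling the input into crops:

```python
import math

def bits_per_pixel(likelihoods, height, width):
    # Self-information of the quantized latents under the entropy model,
    # expressed per pixel of the original image.
    total_bits = -sum(math.log2(p) for p in likelihoods)
    return total_bits / (height * width)

def split_into_crops(height, width, crop):
    # Tile coordinates (top, left, bottom, right) covering the image,
    # so each crop's forward pass fits in memory.
    return [(t, l, min(t + crop, height), min(l + crop, width))
            for t in range(0, height, crop)
            for l in range(0, width, crop)]
```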
## Contributing

Feel free to submit any questions/corrections/suggestions/bugs as issues. Pull requests are welcome. Thanks to Grace for helping refactor my code.

### References

The following additional papers were useful for understanding implementation details:

0. Fabian Mentzer, George Toderici, Michael Tschannen, Eirikur Agustsson. High-Fidelity Generative Image Compression. [arXiv:2006.09965 (2020)](https://arxiv.org/abs/2006.09965).
1. Johannes Ballé, David Minnen, Saurabh Singh, Sung Jin Hwang, Nick Johnston. Variational image compression with a scale hyperprior. [arXiv:1802.01436 (2018)](https://arxiv.org/abs/1802.01436).
2. David Minnen, Johannes Ballé, George Toderici. Joint Autoregressive and Hierarchical Priors for Learned Image Compression. [arXiv:1809.02736 (2018)](https://arxiv.org/abs/1809.02736).
3. Johannes Ballé, Valero Laparra, Eero P. Simoncelli. End-to-end optimization of nonlinear transform codes for perceptual quality. [arXiv:1607.05006 (2016)](https://arxiv.org/abs/1607.05006).
4. Fabian Mentzer, Eirikur Agustsson, Michael Tschannen, Radu Timofte, Luc Van Gool. Practical Full Resolution Learned Lossless Image Compression. [arXiv:1811.12817 (2018)](https://arxiv.org/abs/1811.12817).

## TODO (priority descending)

* Include `torchac` support for entropy coding.
* Implement universal code for overflow values.
* Investigate bit overhead in vectorized rANS implementation.
* Rewrite rANS implementation for speed.

compress.py

Lines changed: 21 additions & 6 deletions
```diff
@@ -32,7 +32,7 @@ def make_deterministic(seed=42):
 def compress_batch(args):
 
     # Reproducibility
-    make_deterministic()
+    # make_deterministic()
     perceptual_loss_fn = ps.PerceptualLoss(model='net-lin', net='alex', use_gpu=torch.cuda.is_available())
 
     # Load model
@@ -41,20 +41,26 @@ def compress_batch(args):
     loaded_args, model, _ = utils.load_model(args.ckpt_path, logger, device, model_mode=ModelModes.EVALUATION,
         current_args_d=None, prediction=True, strict=False)
 
+    # Override current arguments with recorded
     dictify = lambda x: dict((n, getattr(x, n)) for n in dir(x) if not (n.startswith('__') or 'logger' in n))
     loaded_args_d, args_d = dictify(loaded_args), dictify(args)
     loaded_args_d.update(args_d)
     args = utils.Struct(**loaded_args_d)
     logger.info(loaded_args_d)
 
+    # Build probability tables
+    model.Hyperprior.hyperprior_entropy_model.build_tables()
+
+
     eval_loader = datasets.get_dataloaders('evaluation', root=args.image_dir, batch_size=args.batch_size,
                                            logger=logger, shuffle=False, normalize=args.normalize_input_image)
 
     n, N = 0, len(eval_loader.dataset)
     input_filenames_total = list()
     output_filenames_total = list()
     bpp_total, q_bpp_total, LPIPS_total = torch.Tensor(N), torch.Tensor(N), torch.Tensor(N)
-
+    utils.makedirs(args.output_dir)
+
     start_time = time.time()
 
     with torch.no_grad():
@@ -63,7 +69,14 @@
             data = data.to(device, dtype=torch.float)
             B = data.size(0)
 
-            reconstruction, q_bpp, n_bpp = model(data, writeout=False)
+            if args.reconstruct is True:
+                # Reconstruction without compression
+                reconstruction, q_bpp = model(data, writeout=False)
+            else:
+                # Perform entropy coding
+                compressed_output = model.compress(data)
+                reconstruction = model.decompress(compressed_output)
+                q_bpp = compressed_output.total_bpp
 
             if args.normalize_input_image is True:
                 # [-1., 1.] -> [0., 1.]
@@ -77,13 +90,14 @@
                 if B > 1:
                     q_bpp_per_im = float(q_bpp.cpu().numpy()[subidx])
                 else:
-                    q_bpp_per_im = float(q_bpp.item())
+                    q_bpp_per_im = float(q_bpp.item()) if type(q_bpp) == torch.Tensor else float(q_bpp)
+
                 fname = os.path.join(args.output_dir, "{}_RECON_{:.3f}bpp.png".format(filenames[subidx], q_bpp_per_im))
                 torchvision.utils.save_image(reconstruction[subidx], fname, normalize=True)
                 output_filenames_total.append(fname)
 
             bpp_total[n:n + B] = bpp.data
-            q_bpp_total[n:n + B] = q_bpp.data
+            q_bpp_total[n:n + B] = q_bpp.data if type(q_bpp) == torch.Tensor else q_bpp
             LPIPS_total[n:n + B] = perceptual_loss.data
             n += B
 
@@ -116,6 +130,7 @@ def main(**kwargs):
                         help="Path to directory to store output images")
     parser.add_argument('-bs', '--batch_size', type=int, default=1,
                         help="Loader batch size. Set to 1 if images in directory are different sizes.")
+    parser.add_argument("-rc", "--reconstruct", help="Reconstruct input image without compression.", action="store_true")
     args = parser.parse_args()
 
     input_images = glob.glob(os.path.join(args.image_dir, '*.jpg'))
@@ -125,7 +140,7 @@ def main(**kwargs):
 
     print('Input images')
     pprint(input_images)
-    # Launch training
+
     compress_batch(args)
 
 if __name__ == '__main__':
```

default_config.py

Lines changed: 1 addition & 7 deletions
```diff
@@ -14,7 +14,7 @@ class ModelTypes(object):
 class ModelModes(object):
     TRAINING = 'training'
     VALIDATION = 'validation'
-    EVALUATION = 'evaluation'
+    EVALUATION = 'evaluation'  # actual entropy coding
 
 class Datasets(object):
     OPENIMAGES = 'openimages'
@@ -29,12 +29,6 @@ class DatasetPaths(object):
 class directories(object):
     experiments = 'experiments'
 
-class checkpoints(object):
-    low_rate1 = 'experiments/norm_low_rate_openimages_compression_2020_08_19_16_13/checkpoints/norm_low_rate_openimages_compression_2020_08_19_16_13_epoch2_idx168720_2020_08_21_04:00.pt'
-    low_rate_nrb9 = 'experiments/low_rate9_norm_openimages_compression_2020_08_19_16_59/checkpoints/low_rate9_norm_openimages_compression_2020_08_19_16_59_epoch4_idx237436_2020_08_22_00:21.pt'
-    # python3 train.py -n low_rate_gan_v1_norm -mt compression_gan -bs 8 -norm --regime low -steps 1e6 --warmstart -ckpt
-    # experiments/norm_low_rate_openimages_compression_2020_08_19_16_13/checkpoints/norm_low_rate_openimages_compression_2020_08_19_16_13_epoch2_idx168720_2020_08_21_04:00.pt
-
 class args(object):
     """
     Shared config
```
