Skip to content

Commit a950ee7

Browse files
author
Justin Tan (unimelb)
committed
refactor everything
2 parents 0e9d010 + 68694d6 commit a950ee7

File tree

5 files changed

+17
-23
lines changed

5 files changed

+17
-23
lines changed

src/compression/ans.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ def push(x, starts, freqs, precisions):
6060
precision: Determines normalization factor of probability distribution.
6161
"""
6262
head, tail = x
63-
6463
assert head.shape == starts.shape == freqs.shape, (
6564
f"Inconsistent encoder shapes! head: {head.shape} | "
6665
f"starts: {starts.shape} | freqs: {freqs.shape}")

src/compression/entropy_coding.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
OVERFLOW_WIDTH = 4
99
OVERFLOW_CODE = 1 << (1 << OVERFLOW_WIDTH)
10-
PATCH_SIZE = (1,2)
10+
PATCH_SIZE = (1,1)
1111

1212
import torch
1313
import numpy as np
@@ -57,11 +57,10 @@ def _enc_statfun(value):
5757
# (coding_shape) = (C,H,W) by default but canbe generalized
5858
# cdf_i: [(coding_shape), pmf_length + 2]
5959
# value: [(coding_shape)]
60-
lower = np.squeeze(np.take_along_axis(cdf_i,
61-
np.expand_dims(value, -1), axis=-1))
62-
upper = np.squeeze(np.take_along_axis(cdf_i,
63-
np.expand_dims(value + 1, -1), axis=-1))
64-
60+
lower = np.take_along_axis(cdf_i,
61+
np.expand_dims(value, -1), axis=-1)[..., 0]
62+
upper = np.take_along_axis(cdf_i,
63+
np.expand_dims(value + 1, -1), axis=-1)[..., 0]
6564
return lower, upper - lower
6665

6766
return _enc_statfun
@@ -280,14 +279,11 @@ def vec_ans_index_encoder(symbols, indices, cdf, cdf_length, cdf_offset, precisi
280279
if B == 1:
281280
# Vectorize on patches - there's probably a way to interlace patches with
282281
# batch elements for B > 1 ...
283-
print('og', values.shape)
284282
if ((symbols_shape[2] % PATCH_SIZE[0] == 0) and (symbols_shape[3] % PATCH_SIZE[1] == 0)) is False:
285283
values = utils.pad_factor(torch.Tensor(values), symbols_shape[2:],
286284
factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
287285
indices = utils.pad_factor(torch.Tensor(indices), symbols_shape[2:],
288286
factor=PATCH_SIZE).cpu().numpy().astype(np.int32)
289-
print(values.shape)
290-
print(symbols.shape)
291287

292288
assert (values.shape[2] % PATCH_SIZE[0] == 0) and (values.shape[3] % PATCH_SIZE[1] == 0)
293289
assert (indices.shape[2] % PATCH_SIZE[0] == 0) and (indices.shape[3] % PATCH_SIZE[1] == 0)

src/compression/prior_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -317,15 +317,15 @@ def forward(self, x, mean, scale, **kwargs):
317317

318318
import time
319319

320-
n_channels = 64
320+
n_channels = 256
321321
use_blocks = True
322322
vectorize = True
323323
prior_density = PriorDensity(n_channels)
324324
prior_entropy_model = PriorEntropyModel(distribution=prior_density)
325325

326326
loc, scale = 2.401, 3.43
327327
n_data = 1
328-
toy_shape = (n_data, n_channels, 149, 175)
328+
toy_shape = (n_data, n_channels, 34, 50)
329329
bottleneck, means = torch.randn(toy_shape), torch.randn(toy_shape)
330330
scales = torch.randn(toy_shape) * np.sqrt(scale) + loc
331331
scales = torch.clamp(scales, min=MIN_SCALE)

src/hyperprior.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ class Hyperprior(CodingModel):
143143

144144
def __init__(self, bottleneck_capacity=220, hyperlatent_filters=LARGE_HYPERLATENT_FILTERS,
145145
mode='large', likelihood_type='gaussian', scale_lower_bound=MIN_SCALE, entropy_code=False,
146-
vectorize_encoding=False, block_encode=True):
146+
vectorize_encoding=True, block_encode=True):
147147

148148
"""
149149
Introduces probabilistic model over latents of

src/model.py

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,6 @@ def compression_forward(self, x):
133133
if self.model_mode == ModelModes.EVALUATION and (self.training is False):
134134
n_encoder_downsamples = self.Encoder.n_downsampling_layers
135135
factor = 2 ** n_encoder_downsamples
136-
self.logger.info('Padding input image by {}'.format(factor))
137136
x = utils.pad_factor(x, x.size()[2:], factor)
138137

139138
# Encoder forward pass
@@ -142,7 +141,6 @@ def compression_forward(self, x):
142141
if self.model_mode == ModelModes.EVALUATION and (self.training is False):
143142
n_hyperencoder_downsamples = self.Hyperprior.analysis_net.n_downsampling_layers
144143
factor = 2 ** n_hyperencoder_downsamples
145-
self.logger.info('Padding latents by {}'.format(factor))
146144
y = utils.pad_factor(y, y.size()[2:], factor)
147145

148146
hyperinfo = self.Hyperprior(y, spatial_shape=x.size()[2:])
@@ -281,7 +279,6 @@ def compress(self, x):
281279
if self.model_mode == ModelModes.EVALUATION and (self.training is False):
282280
n_encoder_downsamples = self.Encoder.n_downsampling_layers
283281
factor = 2 ** n_encoder_downsamples
284-
self.logger.info('Padding input image to {}'.format(factor))
285282
x = utils.pad_factor(x, x.size()[2:], factor)
286283

287284
# Encoder forward pass
@@ -290,21 +287,23 @@ def compress(self, x):
290287
if self.model_mode == ModelModes.EVALUATION and (self.training is False):
291288
n_hyperencoder_downsamples = self.Hyperprior.analysis_net.n_downsampling_layers
292289
factor = 2 ** n_hyperencoder_downsamples
293-
self.logger.info('Padding latents to {}'.format(factor))
294290
y = utils.pad_factor(y, y.size()[2:], factor)
295291

296292
compression_output = self.Hyperprior.compress_forward(y, spatial_shape)
297293
attained_hbpp = 32 * len(compression_output.hyperlatents_encoded) / np.prod(spatial_shape)
298294
attained_lbpp = 32 * len(compression_output.latents_encoded) / np.prod(spatial_shape)
299295
attained_bpp = 32 * ((len(compression_output.hyperlatents_encoded) +
300296
len(compression_output.latents_encoded)) / np.prod(spatial_shape))
301-
print('BPP', compression_output.total_bpp)
302-
print('h BPP', compression_output.hyperlatent_bpp)
303-
print('l BPP', compression_output.latent_bpp)
304297

305-
print('Actual BPP', attained_bpp)
306-
print('h BPP', attained_hbpp)
307-
print('l BPP', attained_lbpp)
298+
self.logger.info('[ESTIMATED]')
299+
self.logger.info(f'BPP: {compression_output.total_bpp:.3f}')
300+
self.logger.info(f'HL BPP: {compression_output.hyperlatent_bpp:.3f}')
301+
self.logger.info(f'L BPP: {compression_output.latent_bpp:.3f}')
302+
303+
self.logger.info('[ATTAINED]')
304+
self.logger.info(f'BPP: {attained_bpp:.3f}')
305+
self.logger.info(f'HL BPP: {attained_hbpp:.3f}')
306+
self.logger.info(f'L BPP: {attained_lbpp:.3f}')
308307
return compression_output
309308

310309

0 commit comments

Comments
 (0)