Skip to content

Commit fda2fd6

Browse files
committed
refactor: improve readability
1 parent 68ea0ab commit fda2fd6

File tree

2 files changed

+21
-24
lines changed

2 files changed

+21
-24
lines changed

compressai/latent_codecs/checkerboard.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -208,7 +208,7 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
208208

209209
def _forward_twopass_step(
210210
self, y: Tensor, side_params: Tensor, params: Tensor, y_ctx: Tensor, step: str
211-
) -> Dict[str, Any]:
211+
) -> Tensor:
212212
# NOTE: The _i variables contain only the current step's pixels.
213213
assert step in ("anchor", "non_anchor")
214214

compressai/latent_codecs/rasterscan.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(
9898
self.gaussian_conditional = gaussian_conditional or GaussianConditional(None)
9999
self.entropy_parameters = entropy_parameters or nn.Identity()
100100
self.context_prediction = context_prediction or MaskedConv2d()
101-
self.kernel_size = _reduce_seq(self.context_prediction.kernel_size)
101+
self.kernel_size = _to_single(self.context_prediction.kernel_size)
102102
self.padding = (self.kernel_size - 1) // 2
103103

104104
def forward(self, y: Tensor, params: Tensor) -> Dict[str, Any]:
@@ -113,8 +113,11 @@ def forward(self, y: Tensor, params: Tensor) -> Dict[str, Any]:
113113

114114
def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
115115
n, _, y_height, y_width = y.shape
116-
ds = [
117-
self._compress_single(
116+
ds = []
117+
for i in range(n):
118+
encoder = BufferedRansEncoder()
119+
y_hat = raster_scan_compress_single_stream(
120+
encoder=encoder,
118121
y=y[i : i + 1, :, :, :],
119122
params=ctx_params[i : i + 1, :, :, :],
120123
gaussian_conditional=self.gaussian_conditional,
@@ -126,16 +129,10 @@ def compress(self, y: Tensor, ctx_params: Tensor) -> Dict[str, Any]:
126129
kernel_size=self.kernel_size,
127130
merge=self.merge,
128131
)
129-
for i in range(n)
130-
]
132+
y_strings = encoder.flush()
133+
ds.append({"strings": [y_strings], "y_hat": y_hat.squeeze(0)})
131134
return {**default_collate(ds), "shape": y.shape[2:4]}
132135

133-
def _compress_single(self, **kwargs):
134-
encoder = BufferedRansEncoder()
135-
y_hat = raster_scan_compress_single_stream(encoder=encoder, **kwargs)
136-
y_strings = encoder.flush()
137-
return {"strings": [y_strings], "y_hat": y_hat.squeeze(0)}
138-
139136
def decompress(
140137
self,
141138
strings: List[List[bytes]],
@@ -145,9 +142,12 @@ def decompress(
145142
) -> Dict[str, Any]:
146143
(y_strings,) = strings
147144
y_height, y_width = shape
148-
ds = [
149-
self._decompress_single(
150-
y_string=y_strings[i],
145+
ds = []
146+
for i in range(len(y_strings)):
147+
decoder = RansDecoder()
148+
decoder.set_stream(y_strings[i])
149+
y_hat = raster_scan_decompress_single_stream(
150+
decoder=decoder,
151151
params=ctx_params[i : i + 1, :, :, :],
152152
gaussian_conditional=self.gaussian_conditional,
153153
entropy_parameters=self.entropy_parameters,
@@ -159,16 +159,9 @@ def decompress(
159159
device=ctx_params.device,
160160
merge=self.merge,
161161
)
162-
for i in range(len(y_strings))
163-
]
162+
ds.append({"y_hat": y_hat.squeeze(0)})
164163
return default_collate(ds)
165164

166-
def _decompress_single(self, y_string, **kwargs):
167-
decoder = RansDecoder()
168-
decoder.set_stream(y_string)
169-
y_hat = raster_scan_decompress_single_stream(decoder=decoder, **kwargs)
170-
return {"y_hat": y_hat.squeeze(0)}
171-
172165
@staticmethod
173166
def merge(*args):
174167
return torch.cat(args, dim=1)
@@ -312,12 +305,16 @@ def _pad_2d(x: Tensor, padding: int) -> Tensor:
312305
return F.pad(x, (padding, padding, padding, padding))
313306

314307

315-
def _reduce_seq(xs):
308+
def _to_single(xs):
316309
assert all(x == xs[0] for x in xs)
317310
return xs[0]
318311

319312

320313
def default_collate(batch: List[Dict[K, V]]) -> Dict[K, List[V]]:
314+
"""Combines a list of dictionaries into a single dictionary.
315+
316+
Workaround to ``torch.utils.data.default_collate`` bug in PyTorch 2.0.0.
317+
"""
321318
if not isinstance(batch, list) or any(not isinstance(d, dict) for d in batch):
322319
raise NotImplementedError
323320

0 commit comments

Comments
 (0)