Skip to content

Commit 05bf9d9

Browse files
committed
refactor: inline _mask
1 parent 7163a04 commit 05bf9d9

File tree

1 file changed

+14
-18
lines changed

1 file changed

+14
-18
lines changed

compressai/latent_codecs/checkerboard.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,7 @@ def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, A
263263
@torch.no_grad()
264264
def _y_ctx_zero(self, y: Tensor) -> Tensor:
265265
"""Create a zero tensor with correct shape for y_ctx."""
266-
return self._mask(self.context_prediction(y).detach(), "all")
266+
return self._mask_all(self.context_prediction(y).detach())
267267

268268
def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
269269
n, c, h, w = y.shape
@@ -275,7 +275,7 @@ def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
275275
for i in range(2):
276276
y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i]
277277
if i == 0:
278-
y_ctx_i = self._mask(y_ctx_i, "all")
278+
y_ctx_i = self._mask_all(y_ctx_i)
279279
params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i]))
280280
y_out = self.latent_codec["y"].compress(y_[i], params_i)
281281
y_hat_[i] = y_out["y_hat"]
@@ -309,7 +309,7 @@ def decompress(
309309
for i in range(2):
310310
y_ctx_i = self.unembed(self.context_prediction(self.embed(y_hat_)))[i]
311311
if i == 0:
312-
y_ctx_i = self._mask(y_ctx_i, "all")
312+
y_ctx_i = self._mask_all(y_ctx_i)
313313
params_i = self.entropy_parameters(self.merge(y_ctx_i, side_params_[i]))
314314
y_out = self.latent_codec["y"].decompress(
315315
[y_strings_[i]], y_i_shape, params_i
@@ -380,25 +380,21 @@ def _copy(self, dest: Tensor, src: Tensor, step: str) -> None:
380380
dest[..., 0::2, 1::2] = src[..., 0::2, 1::2]
381381
dest[..., 1::2, 0::2] = src[..., 1::2, 0::2]
382382

383-
def _keep_only(self, y: Tensor, step: str, inplace: bool = False) -> Tensor:
383+
def _keep_only(self, y: Tensor, step: str) -> Tensor:
384384
"""Keep only pixels in the current step, and zero out the rest."""
385-
return self._mask(
386-
y,
387-
parity=self.non_anchor_parity if step == "anchor" else self.anchor_parity,
388-
inplace=inplace,
389-
)
390-
391-
def _mask(self, y: Tensor, parity: str, inplace: bool = False) -> Tensor:
392-
if not inplace:
393-
y = y.clone()
385+
y = y.clone()
386+
parity = self.anchor_parity if step == "anchor" else self.non_anchor_parity
394387
if parity == "even":
395-
y[..., 0::2, 0::2] = 0
396-
y[..., 1::2, 1::2] = 0
397-
elif parity == "odd":
398388
y[..., 0::2, 1::2] = 0
399389
y[..., 1::2, 0::2] = 0
400-
elif parity == "all":
401-
y[:] = 0
390+
elif parity == "odd":
391+
y[..., 0::2, 0::2] = 0
392+
y[..., 1::2, 1::2] = 0
393+
return y
394+
395+
def _mask_all(self, y: Tensor) -> Tensor:
396+
y = y.clone()
397+
y[:] = 0
402398
return y
403399

404400
def merge(self, *args):

0 commit comments

Comments
 (0)