Skip to content

Commit 7163a04

Browse files
committed
refactor: inline method
1 parent daca23c commit 7163a04

File tree

1 file changed

+28
-28
lines changed

1 file changed

+28
-28
lines changed

compressai/latent_codecs/checkerboard.py

Lines changed: 28 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -188,14 +188,36 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
188188

189189
params = y.new_zeros((B, C * 2, H, W))
190190

191-
y_hat_anchors = self._forward_twopass_step(
192-
y, side_params, params, self._y_ctx_zero(y), "anchor"
193-
)
191+
y_hat_ = []
194192

195-
y_hat_non_anchors = self._forward_twopass_step(
196-
y, side_params, params, self.context_prediction(y_hat_anchors), "non_anchor"
197-
)
193+
# NOTE: The _i variables contain only the current step's pixels.
194+
# i=0: step=anchor
195+
# i=1: step=non_anchor
196+
197+
for step in ("anchor", "non_anchor"):
198+
if step == "anchor":
199+
y_ctx = self._y_ctx_zero(y)
200+
else: # step == "non_anchor"
201+
y_hat_anchors = y_hat_[0]
202+
y_ctx = self.context_prediction(y_hat_anchors)
203+
204+
params_i = self.entropy_parameters(self.merge(y_ctx, side_params))
205+
206+
# Save params for current step. This is later used for entropy estimation.
207+
self._copy(params, params_i, step)
208+
209+
# Keep only elements needed for current step.
210+
# It's not necessary to mask the rest out just yet, but it doesn't hurt.
211+
params_i = self._keep_only(params_i, step)
212+
y_i = self._keep_only(y, step)
213+
214+
# Determine y_hat for current step, and mask out the other pixels.
215+
_, means_i = self.latent_codec["y"]._chunk(params_i)
216+
y_hat_i = self._keep_only(quantize_ste(y_i - means_i) + means_i, step)
217+
218+
y_hat_.append(y_hat_i)
198219

220+
[y_hat_anchors, y_hat_non_anchors] = y_hat_
199221
y_hat = y_hat_anchors + y_hat_non_anchors
200222
y_out = self.latent_codec["y"](y, params)
201223

@@ -206,28 +228,6 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
206228
"y_hat": y_hat,
207229
}
208230

209-
def _forward_twopass_step(
210-
self, y: Tensor, side_params: Tensor, params: Tensor, y_ctx: Tensor, step: str
211-
) -> Tensor:
212-
# NOTE: The _i variables contain only the current step's pixels.
213-
assert step in ("anchor", "non_anchor")
214-
215-
params_i = self.entropy_parameters(self.merge(y_ctx, side_params))
216-
217-
# Save params for current step. This is later used for entropy estimation.
218-
self._copy(params, params_i, step)
219-
220-
# Keep only elements needed for current step.
221-
# It's not necessary to mask the rest out just yet, but it doesn't hurt.
222-
params_i = self._keep_only(params_i, step)
223-
y_i = self._keep_only(y, step)
224-
225-
# Determine y_hat for current step, and mask out the other pixels.
226-
_, means_i = self.latent_codec["y"]._chunk(params_i)
227-
y_hat_i = self._keep_only(quantize_ste(y_i - means_i) + means_i, step)
228-
229-
return y_hat_i
230-
231231
def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
232232
"""Runs the entropy parameters network in two passes.
233233

0 commit comments

Comments
 (0)