@@ -188,14 +188,36 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
188188
189189 params = y .new_zeros ((B , C * 2 , H , W ))
190190
191- y_hat_anchors = self ._forward_twopass_step (
192- y , side_params , params , self ._y_ctx_zero (y ), "anchor"
193- )
191+ y_hat_ = []
194192
195- y_hat_non_anchors = self ._forward_twopass_step (
196- y , side_params , params , self .context_prediction (y_hat_anchors ), "non_anchor"
197- )
193+ # NOTE: The _i variables contain only the current step's pixels.
194+ # i=0: step=anchor
195+ # i=1: step=non_anchor
196+
197+ for step in ("anchor" , "non_anchor" ):
198+ if step == "anchor" :
199+ y_ctx = self ._y_ctx_zero (y )
200+ else : # step == "non_anchor"
201+ y_hat_anchors = y_hat_ [0 ]
202+ y_ctx = self .context_prediction (y_hat_anchors )
203+
204+ params_i = self .entropy_parameters (self .merge (y_ctx , side_params ))
205+
206+ # Save params for current step. This is later used for entropy estimation.
207+ self ._copy (params , params_i , step )
208+
209+ # Keep only elements needed for current step.
210+ # It's not necessary to mask the rest out just yet, but it doesn't hurt.
211+ params_i = self ._keep_only (params_i , step )
212+ y_i = self ._keep_only (y , step )
213+
214+ # Determine y_hat for current step, and mask out the other pixels.
215+ _ , means_i = self .latent_codec ["y" ]._chunk (params_i )
216+ y_hat_i = self ._keep_only (quantize_ste (y_i - means_i ) + means_i , step )
217+
218+ y_hat_ .append (y_hat_i )
198219
220+ [y_hat_anchors , y_hat_non_anchors ] = y_hat_
199221 y_hat = y_hat_anchors + y_hat_non_anchors
200222 y_out = self .latent_codec ["y" ](y , params )
201223
@@ -206,28 +228,6 @@ def _forward_twopass(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
206228 "y_hat" : y_hat ,
207229 }
208230
209- def _forward_twopass_step (
210- self , y : Tensor , side_params : Tensor , params : Tensor , y_ctx : Tensor , step : str
211- ) -> Tensor :
212- # NOTE: The _i variables contain only the current step's pixels.
213- assert step in ("anchor" , "non_anchor" )
214-
215- params_i = self .entropy_parameters (self .merge (y_ctx , side_params ))
216-
217- # Save params for current step. This is later used for entropy estimation.
218- self ._copy (params , params_i , step )
219-
220- # Keep only elements needed for current step.
221- # It's not necessary to mask the rest out just yet, but it doesn't hurt.
222- params_i = self ._keep_only (params_i , step )
223- y_i = self ._keep_only (y , step )
224-
225- # Determine y_hat for current step, and mask out the other pixels.
226- _ , means_i = self .latent_codec ["y" ]._chunk (params_i )
227- y_hat_i = self ._keep_only (quantize_ste (y_i - means_i ) + means_i , step )
228-
229- return y_hat_i
230-
231231 def _forward_twopass_faster (self , y : Tensor , side_params : Tensor ) -> Dict [str , Any ]:
232232 """Runs the entropy parameters network in two passes.
233233
0 commit comments