@@ -263,7 +263,7 @@ def _forward_twopass_faster(self, y: Tensor, side_params: Tensor) -> Dict[str, A
263263 @torch .no_grad ()
264264 def _y_ctx_zero (self , y : Tensor ) -> Tensor :
265265 """Create a zero tensor with correct shape for y_ctx."""
266- return self ._mask (self .context_prediction (y ).detach (), "all" )
266+ return self ._mask_all (self .context_prediction (y ).detach ())
267267
268268 def compress (self , y : Tensor , side_params : Tensor ) -> Dict [str , Any ]:
269269 n , c , h , w = y .shape
@@ -275,7 +275,7 @@ def compress(self, y: Tensor, side_params: Tensor) -> Dict[str, Any]:
275275 for i in range (2 ):
276276 y_ctx_i = self .unembed (self .context_prediction (self .embed (y_hat_ )))[i ]
277277 if i == 0 :
278- y_ctx_i = self ._mask (y_ctx_i , "all" )
278+ y_ctx_i = self ._mask_all (y_ctx_i )
279279 params_i = self .entropy_parameters (self .merge (y_ctx_i , side_params_ [i ]))
280280 y_out = self .latent_codec ["y" ].compress (y_ [i ], params_i )
281281 y_hat_ [i ] = y_out ["y_hat" ]
@@ -309,7 +309,7 @@ def decompress(
309309 for i in range (2 ):
310310 y_ctx_i = self .unembed (self .context_prediction (self .embed (y_hat_ )))[i ]
311311 if i == 0 :
312- y_ctx_i = self ._mask (y_ctx_i , "all" )
312+ y_ctx_i = self ._mask_all (y_ctx_i )
313313 params_i = self .entropy_parameters (self .merge (y_ctx_i , side_params_ [i ]))
314314 y_out = self .latent_codec ["y" ].decompress (
315315 [y_strings_ [i ]], y_i_shape , params_i
@@ -380,25 +380,21 @@ def _copy(self, dest: Tensor, src: Tensor, step: str) -> None:
380380 dest [..., 0 ::2 , 1 ::2 ] = src [..., 0 ::2 , 1 ::2 ]
381381 dest [..., 1 ::2 , 0 ::2 ] = src [..., 1 ::2 , 0 ::2 ]
382382
383- def _keep_only (self , y : Tensor , step : str , inplace : bool = False ) -> Tensor :
383+ def _keep_only (self , y : Tensor , step : str ) -> Tensor :
384384 """Keep only pixels in the current step, and zero out the rest."""
385- return self ._mask (
386- y ,
387- parity = self .non_anchor_parity if step == "anchor" else self .anchor_parity ,
388- inplace = inplace ,
389- )
390-
391- def _mask (self , y : Tensor , parity : str , inplace : bool = False ) -> Tensor :
392- if not inplace :
393- y = y .clone ()
385+ y = y .clone ()
386+ parity = self .anchor_parity if step == "anchor" else self .non_anchor_parity
394387 if parity == "even" :
395- y [..., 0 ::2 , 0 ::2 ] = 0
396- y [..., 1 ::2 , 1 ::2 ] = 0
397- elif parity == "odd" :
398388 y [..., 0 ::2 , 1 ::2 ] = 0
399389 y [..., 1 ::2 , 0 ::2 ] = 0
400- elif parity == "all" :
401- y [:] = 0
390+ elif parity == "odd" :
391+ y [..., 0 ::2 , 0 ::2 ] = 0
392+ y [..., 1 ::2 , 1 ::2 ] = 0
393+ return y
394+
395+ def _mask_all (self , y : Tensor ) -> Tensor :
396+ y = y .clone ()
397+ y [:] = 0
402398 return y
403399
404400 def merge (self , * args ):
0 commit comments