@@ -211,7 +211,6 @@ def __init__(
211211 dd = dict (device = device , dtype = dtype )
212212 super ().__init__ ()
213213 self .k = kernel_size
214- self .unfold = nn .Unfold (kernel_size = (kernel_size , kernel_size ), stride = (kernel_size , kernel_size ))
215214 self .transform = Dct2d (kernel_size , kernel_type , orthonormal , ** dd )
216215 self .permutation = _zigzag_permutation (kernel_size , kernel_size )
217216
@@ -222,19 +221,21 @@ def __init__(
222221
223222 self .register_buffer ('mean' , torch .tensor (_DCT_MEAN , device = device ), persistent = False )
224223 self .register_buffer ('var' , torch .tensor (_DCT_VAR , device = device ), persistent = False )
225- self .register_buffer ('imagenet_mean' , torch .tensor ([0.485 , 0.456 , 0.406 ], device = device ), persistent = False )
226- self .register_buffer ('imagenet_std' , torch .tensor ([0.229 , 0.224 , 0.225 ], device = device ), persistent = False )
224+ # Shape (3, 1, 1) for BCHW broadcasting
225+ self .register_buffer ('imagenet_mean' , torch .tensor ([0.485 , 0.456 , 0.406 ], device = device ).view (3 , 1 , 1 ), persistent = False )
226+ self .register_buffer ('imagenet_std' , torch .tensor ([0.229 , 0.224 , 0.225 ], device = device ).view (3 , 1 , 1 ), persistent = False )
227227
228228 def _denormalize (self , x : torch .Tensor ) -> torch .Tensor :
229229 """Convert from ImageNet normalized to [0, 255] range."""
230230 return x .mul (self .imagenet_std ).add_ (self .imagenet_mean ) * 255
231231
232232 def _rgb_to_ycbcr (self , x : torch .Tensor ) -> torch .Tensor :
233- """Convert RGB to YCbCr color space."""
234- y = (x [:, :, :, 0 ] * 0.299 ) + (x [:, :, :, 1 ] * 0.587 ) + (x [:, :, :, 2 ] * 0.114 )
235- cb = 0.564 * (x [:, :, :, 2 ] - y ) + 128
236- cr = 0.713 * (x [:, :, :, 0 ] - y ) + 128
237- return torch .stack ([y , cb , cr ], dim = - 1 )
233+ """Convert RGB to YCbCr color space (BCHW input/output)."""
234+ r , g , b = x [:, 0 ], x [:, 1 ], x [:, 2 ]
235+ y = r * 0.299 + g * 0.587 + b * 0.114
236+ cb = 0.564 * (b - y ) + 128
237+ cr = 0.713 * (r - y ) + 128
238+ return torch .stack ([y , cb , cr ], dim = 1 )
238239
239240 def _frequency_normalize (self , x : torch .Tensor ) -> torch .Tensor :
240241 """Normalize DCT coefficients using precomputed statistics."""
@@ -243,12 +244,11 @@ def _frequency_normalize(self, x: torch.Tensor) -> torch.Tensor:
243244
244245 def forward (self , x : torch .Tensor ) -> torch .Tensor :
245246 b , c , h , w = x .shape
246- x = x .permute (0 , 2 , 3 , 1 )
247247 x = self ._denormalize (x )
248248 x = self ._rgb_to_ycbcr (x )
249- x = x . permute ( 0 , 3 , 1 , 2 )
250- x = self . unfold ( x ). transpose ( - 1 , - 2 )
251- x = x .reshape ( b , h // self . k , w // self . k , c , self . k , self . k )
249+ # Extract non-overlapping k x k patches
250+ x = x . reshape ( b , c , h // self . k , self . k , w // self . k , self . k ) # (B, C, H//k, k, W//k, k )
251+ x = x .permute ( 0 , 2 , 4 , 1 , 3 , 5 ) # (B, H//k, W// k, C, k, k)
252252 x = self .transform (x )
253253 x = x .reshape (- 1 , c , self .k * self .k )
254254 x = x [:, :, self .permutation ]
@@ -275,14 +275,14 @@ def __init__(
275275 dd = dict (device = device , dtype = dtype )
276276 super ().__init__ ()
277277 self .k = kernel_size
278- self .unfold = nn .Unfold (kernel_size = (kernel_size , kernel_size ), stride = (kernel_size , kernel_size ))
279278 self .transform = Dct2d (kernel_size , kernel_type , orthonormal , ** dd )
280279 self .permutation = _zigzag_permutation (kernel_size , kernel_size )
281280
282281 def forward (self , x : torch .Tensor ) -> Tuple [torch .Tensor , torch .Tensor ]:
283282 b , c , h , w = x .shape
284- x = self .unfold (x ).transpose (- 1 , - 2 )
285- x = x .reshape (b , h // self .k , w // self .k , c , self .k , self .k )
283+ # Extract non-overlapping k x k patches
284+ x = x .reshape (b , c , h // self .k , self .k , w // self .k , self .k ) # (B, C, H//k, k, W//k, k)
285+ x = x .permute (0 , 2 , 4 , 1 , 3 , 5 ) # (B, H//k, W//k, C, k, k)
286286 x = self .transform (x )
287287 x = x .reshape (- 1 , c , self .k * self .k )
288288 x = x [:, :, self .permutation ]
0 commit comments