@@ -3522,5 +3522,106 @@ void BackwardFunction(Tensor<T> gradient)
35223522
35233523 return node ;
35243524 }
3525+
/// <summary>
/// Performs pixel shuffle (depth-to-space) operation for sub-pixel convolution.
/// </summary>
/// <remarks>
/// Every group of r² consecutive input channels is folded into one output channel:
/// <c>input[b, c, h, w]</c> is copied to <c>output[b, c / r², h * r + rH, w * r + rW]</c>,
/// where <c>rH = (c % r²) / r</c> and <c>rW = (c % r²) % r</c>.
/// The backward pass walks the same index mapping and adds the incoming gradient
/// into <c>a.Gradient</c>, allocating it first if it is null.
/// </remarks>
/// <param name="a">The input computation node with shape [batch, channels, height, width].</param>
/// <param name="upscaleFactor">The upscaling factor (r). Must be at least 1, and channels must be divisible by r².</param>
/// <returns>A computation node with shape [batch, channels/(r²), height*r, width*r].</returns>
/// <exception cref="ArgumentException">
/// Thrown when the input is not 4D, or when channels is not divisible by r².
/// </exception>
/// <exception cref="ArgumentOutOfRangeException">Thrown when <paramref name="upscaleFactor"/> is less than 1.</exception>
/// <exception cref="OverflowException">Thrown when r², height*r or width*r does not fit in an <see cref="int"/>.</exception>
public static ComputationNode<T> PixelShuffle(ComputationNode<T> a, int upscaleFactor)
{
    var numOps = MathHelper.GetNumericOperations<T>();
    var inputShape = a.Value.Shape;

    if (inputShape.Length != 4)
        throw new ArgumentException("PixelShuffle expects 4D input [batch, channels, height, width]");

    // A zero factor would otherwise surface as DivideByZeroException in the modulo below,
    // and a negative one would pass the divisibility check and produce negative dimensions.
    if (upscaleFactor < 1)
        throw new ArgumentOutOfRangeException(
            nameof(upscaleFactor),
            upscaleFactor,
            "PixelShuffle upscale factor must be at least 1");

    int batch = inputShape[0];
    int channels = inputShape[1];
    int inH = inputShape[2];
    int inW = inputShape[3];
    int r = upscaleFactor;

    // checked: fail loudly instead of letting r * r wrap (e.g. to 0 or a negative value).
    int r2 = checked(r * r);

    if (channels % r2 != 0)
        throw new ArgumentException($"Channels {channels} must be divisible by upscale_factor² ({r2})");

    int outC = channels / r2;
    int outH = checked(inH * r);
    int outW = checked(inW * r);

    var outputShape = new int[] { batch, outC, outH, outW };
    var result = new Tensor<T>(outputShape);

    // Forward: rearrange channels into spatial dimensions
    // input[b, c, h, w] -> output[b, c/(r²), h*r + r_h, w*r + r_w]
    // where c = c_out * r² + r_h * r + r_w
    for (int b = 0; b < batch; b++)
    {
        for (int c = 0; c < channels; c++)
        {
            int c_out = c / r2;
            int c_offset = c % r2;
            int r_h = c_offset / r;
            int r_w = c_offset % r;

            for (int h = 0; h < inH; h++)
            {
                int out_h = h * r + r_h;

                for (int w = 0; w < inW; w++)
                {
                    int out_w = w * r + r_w;
                    result[b, c_out, out_h, out_w] = a.Value[b, c, h, w];
                }
            }
        }
    }

    void BackwardFunction(Tensor<T> gradient)
    {
        if (!a.RequiresGradient) return;

        if (a.Gradient == null)
            a.Gradient = new Tensor<T>(inputShape);

        // Backward: reverse the rearrangement. Each input element fed exactly one
        // output element, so its gradient is that output's gradient, accumulated
        // onto whatever a.Gradient already holds.
        for (int b = 0; b < batch; b++)
        {
            for (int c = 0; c < channels; c++)
            {
                int c_out = c / r2;
                int c_offset = c % r2;
                int r_h = c_offset / r;
                int r_w = c_offset % r;

                for (int h = 0; h < inH; h++)
                {
                    int out_h = h * r + r_h;

                    for (int w = 0; w < inW; w++)
                    {
                        int out_w = w * r + r_w;
                        a.Gradient[b, c, h, w] = numOps.Add(
                            a.Gradient[b, c, h, w],
                            gradient[b, c_out, out_h, out_w]);
                    }
                }
            }
        }
    }

    var node = new ComputationNode<T>(
        value: result,
        requiresGradient: a.RequiresGradient,
        parents: new List<ComputationNode<T>> { a },
        backwardFunction: BackwardFunction,
        name: null);

    // Register the node with the active tape only while it is recording.
    var tape = GradientTape<T>.Current;
    if (tape != null && tape.IsRecording)
        tape.RecordOperation(node);

    return node;
}
35253626}
35263627}
0 commit comments