Skip to content

Commit 1b8b29d

Browse files
committed
feat: Add PixelShuffle operation for sub-pixel convolution
PixelShuffle (depth-to-space) rearranges channel data into spatial dimensions: - Input: [batch, channels, height, width] - Output: [batch, channels/(r²), height*r, width*r] Used for: - SubpixelConvolutionalLayer (super-resolution, upsampling) - Efficient learned upsampling in GANs and autoencoders Forward: Rearranges r² channels into an r×r spatial block Backward: Reverses the rearrangement Total TensorOperations: 34 (33 previous + 1 new)
1 parent 32986e9 commit 1b8b29d

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

src/Autodiff/TensorOperations.cs

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3522,5 +3522,106 @@ void BackwardFunction(Tensor<T> gradient)
35223522

35233523
return node;
35243524
}
3525+
3526+
/// <summary>
/// Performs the pixel shuffle (depth-to-space) operation used by sub-pixel convolution.
/// Each group of r² input channels is rearranged into an r×r spatial block, so the
/// output trades channel depth for spatial resolution (super-resolution upsampling).
/// </summary>
/// <param name="a">The input computation node with shape [batch, channels, height, width].</param>
/// <param name="upscaleFactor">The upscaling factor (r). Must be positive, and channels must be divisible by r².</param>
/// <returns>A computation node with shape [batch, channels/(r²), height*r, width*r].</returns>
/// <exception cref="ArgumentException">
/// Thrown when the input is not 4D, when <paramref name="upscaleFactor"/> is not positive,
/// or when the channel count is not divisible by the squared upscale factor.
/// </exception>
public static ComputationNode<T> PixelShuffle(ComputationNode<T> a, int upscaleFactor)
{
    var numOps = MathHelper.GetNumericOperations<T>();
    var inputShape = a.Value.Shape;

    if (inputShape.Length != 4)
    {
        throw new ArgumentException("PixelShuffle expects 4D input [batch, channels, height, width]");
    }

    // Guard the factor explicitly: r == 0 would divide by zero in the divisibility
    // check below, and a negative r would silently produce nonsense output shapes.
    if (upscaleFactor <= 0)
    {
        throw new ArgumentException($"upscaleFactor must be positive, got {upscaleFactor}", nameof(upscaleFactor));
    }

    int batch = inputShape[0];
    int channels = inputShape[1];
    int inH = inputShape[2];
    int inW = inputShape[3];
    int r = upscaleFactor;
    int r2 = r * r;

    if (channels % r2 != 0)
    {
        throw new ArgumentException($"Channels {channels} must be divisible by upscale_factor² ({r2})");
    }

    int outC = channels / r2;
    int outH = inH * r;
    int outW = inW * r;

    var outputShape = new int[] { batch, outC, outH, outW };
    var result = new Tensor<T>(outputShape);

    // Forward: rearrange channels into spatial dimensions.
    // input[b, c, h, w] -> output[b, c/(r²), h*r + rowOffset, w*r + colOffset]
    // where c = cOut * r² + rowOffset * r + colOffset (matches the standard
    // depth-to-space channel ordering used by sub-pixel convolution).
    for (int b = 0; b < batch; b++)
    {
        for (int c = 0; c < channels; c++)
        {
            int cOut = c / r2;
            int cOffset = c % r2;
            int rowOffset = cOffset / r;
            int colOffset = cOffset % r;

            for (int h = 0; h < inH; h++)
            {
                for (int w = 0; w < inW; w++)
                {
                    int outH2 = h * r + rowOffset;
                    int outW2 = w * r + colOffset;
                    result[b, cOut, outH2, outW2] = a.Value[b, c, h, w];
                }
            }
        }
    }

    void BackwardFunction(Tensor<T> gradient)
    {
        if (!a.RequiresGradient)
        {
            return;
        }

        // Lazily allocate the gradient buffer; accumulation below assumes it
        // starts zero-initialized (new Tensor<T> — TODO confirm zero-init).
        if (a.Gradient == null)
        {
            a.Gradient = new Tensor<T>(inputShape);
        }

        // Backward: the forward pass is a pure permutation of elements, so the
        // gradient is routed back through the inverse index mapping and
        // accumulated (Add) into any existing gradient.
        for (int b = 0; b < batch; b++)
        {
            for (int c = 0; c < channels; c++)
            {
                int cOut = c / r2;
                int cOffset = c % r2;
                int rowOffset = cOffset / r;
                int colOffset = cOffset % r;

                for (int h = 0; h < inH; h++)
                {
                    for (int w = 0; w < inW; w++)
                    {
                        int outH2 = h * r + rowOffset;
                        int outW2 = w * r + colOffset;
                        a.Gradient[b, c, h, w] = numOps.Add(
                            a.Gradient[b, c, h, w],
                            gradient[b, cOut, outH2, outW2]);
                    }
                }
            }
        }
    }

    var node = new ComputationNode<T>(
        value: result,
        requiresGradient: a.RequiresGradient,
        parents: new List<ComputationNode<T>> { a },
        backwardFunction: BackwardFunction,
        name: null);

    // Register with the active gradient tape, if any, so the op participates
    // in reverse-mode autodiff.
    var tape = GradientTape<T>.Current;
    if (tape != null && tape.IsRecording)
    {
        tape.RecordOperation(node);
    }

    return node;
}
35253626
}
35263627
}

0 commit comments

Comments
 (0)