Skip to content

Commit 10bccb5

Browse files
committed
feat: Update SplitLayer to use Reshape TensorOperation
SplitLayer now uses the existing Reshape operation for autodiff backward pass. Key features: - Uses existing Reshape operation (no new operation needed) - Split operation in this layer is effectively a reshape: [batch, size] → [batch, numSplits, size/numSplits] - Full autodiff support with gradient computation Layers with autodiff: 20 (19 previous + 1 new)
1 parent 951c093 commit 10bccb5

File tree

1 file changed

+70
-4
lines changed

1 file changed

+70
-4
lines changed

src/NeuralNetworks/Layers/SplitLayer.cs

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -254,14 +254,80 @@ private Tensor<T> BackwardManual(Tensor<T> outputGradient)
254254
/// <returns>The gradient of the loss with respect to the layer's input.</returns>
255255
/// <remarks>
256256
/// <para>
257-
/// This method uses automatic differentiation to compute gradients. Specialized operations
258-
/// are not yet available in TensorOperations, so this falls back to the manual implementation.
257+
/// This method uses automatic differentiation to compute gradients using the Reshape operation.
258+
/// The split layer is effectively a reshape operation that adds a new dimension by dividing
259+
/// one dimension into two.
259260
/// </para>
260261
/// </remarks>
261262
private Tensor<T> BackwardViaAutodiff(Tensor<T> outputGradient)
{
    // Validate up front: a null seed gradient would otherwise surface as an
    // opaque failure deep inside the backward loop below.
    if (outputGradient == null)
        throw new ArgumentNullException(nameof(outputGradient));
    if (_lastInput == null)
        throw new InvalidOperationException("Forward pass must be called before backward pass.");

    // Wrap the cached forward input in a computation node so gradients flow back to it.
    var inputNode = Autodiff.TensorOperations<T>.Variable(_lastInput, "input", requiresGradient: true);

    // Split is effectively a reshape: [batch, inputSize] -> [batch, numSplits, splitSize].
    int batchSize = _lastInput.Shape[0];
    int inputSize = _lastInput.Shape[1];

    // Guard against silent truncation: the integer division below would drop
    // elements when the input size is not an exact multiple of the split count,
    // yielding a reshape with the wrong element count.
    if (_numSplits <= 0 || inputSize % _numSplits != 0)
        throw new InvalidOperationException(
            $"Input size {inputSize} is not divisible by the number of splits {_numSplits}.");

    int splitSize = inputSize / _numSplits;
    var outputShape = new int[] { batchSize, _numSplits, splitSize };

    var outputNode = Autodiff.TensorOperations<T>.Reshape(inputNode, outputShape);

    // Seed the output gradient, then propagate it through the graph in reverse
    // topological order (output first, inputs last).
    outputNode.Gradient = outputGradient;
    var topoOrder = GetTopologicalOrder(outputNode);
    for (int i = topoOrder.Count - 1; i >= 0; i--)
    {
        var node = topoOrder[i];
        if (node.RequiresGradient && node.BackwardFunction != null && node.Gradient != null)
        {
            node.BackwardFunction(node.Gradient);
        }
    }

    // Extract the accumulated input gradient; a missing gradient here means the
    // backward traversal never reached the input node.
    return inputNode.Gradient ?? throw new InvalidOperationException("Gradient computation failed.");
}
293+
294+
/// <summary>
/// Returns the nodes of the computation graph rooted at <paramref name="root"/>
/// in topological order (each node appears after all of its parents).
/// </summary>
/// <param name="root">The root (output) node of the computation graph.</param>
/// <returns>The graph's nodes, parents before children.</returns>
private List<Autodiff.ComputationNode<T>> GetTopologicalOrder(Autodiff.ComputationNode<T> root)
{
    var ordering = new List<Autodiff.ComputationNode<T>>();
    var finished = new HashSet<Autodiff.ComputationNode<T>>();

    // Iterative post-order DFS. Each node is pushed twice: once to expand its
    // parents (ChildrenDone == false) and once to emit it (ChildrenDone == true).
    var pending = new Stack<(Autodiff.ComputationNode<T> Node, bool ChildrenDone)>();
    pending.Push((root, false));

    while (pending.Count > 0)
    {
        var (current, childrenDone) = pending.Pop();

        // A node reachable along several paths may be pushed more than once;
        // skip any entry for a node that has already been emitted.
        if (finished.Contains(current))
            continue;

        if (!childrenDone)
        {
            // Revisit this node only after all of its parents have been emitted.
            pending.Push((current, true));
            foreach (var parent in current.Parents)
            {
                if (!finished.Contains(parent))
                    pending.Push((parent, false));
            }
        }
        else
        {
            finished.Add(current);
            ordering.Add(current);
        }
    }

    return ordering;
}
266332

267333

0 commit comments

Comments
 (0)