@@ -254,14 +254,80 @@ private Tensor<T> BackwardManual(Tensor<T> outputGradient)
254254 /// <returns>The gradient of the loss with respect to the layer's input.</returns>
255255 /// <remarks>
256256 /// <para>
257- /// This method uses automatic differentiation to compute gradients. Specialized operations
258- /// are not yet available in TensorOperations, so this falls back to the manual implementation.
257+ /// This method uses automatic differentiation to compute gradients using the Reshape operation.
258+ /// The split layer is effectively a reshape operation that adds a new dimension by dividing
259+ /// one dimension into two.
259260 /// </para>
260261 /// </remarks>
private Tensor<T> BackwardViaAutodiff(Tensor<T> outputGradient)
{
    if (_lastInput == null)
        throw new InvalidOperationException("Forward pass must be called before backward pass.");

    int batchSize = _lastInput.Shape[0];
    int inputSize = _lastInput.Shape[1];

    // Guard against silent truncation: the integer division below would otherwise
    // yield a reshape whose element count no longer matches the input, surfacing
    // as a confusing error (or wrong shape) deep inside the reshape operation.
    if (_numSplits <= 0 || inputSize % _numSplits != 0)
        throw new InvalidOperationException(
            $"Input size {inputSize} is not divisible by the number of splits {_numSplits}.");

    int splitSize = inputSize / _numSplits;

    // Build the computation graph. Split is effectively a reshape:
    // [batch, inputSize] -> [batch, numSplits, splitSize].
    var inputNode = Autodiff.TensorOperations<T>.Variable(_lastInput, "input", requiresGradient: true);
    var outputShape = new int[] { batchSize, _numSplits, splitSize };
    var outputNode = Autodiff.TensorOperations<T>.Reshape(inputNode, outputShape);

    // Seed the output gradient, then propagate it through the graph by visiting
    // nodes in reverse topological order (children before their parents).
    outputNode.Gradient = outputGradient;
    var topoOrder = GetTopologicalOrder(outputNode);
    for (int i = topoOrder.Count - 1; i >= 0; i--)
    {
        var node = topoOrder[i];
        if (node.RequiresGradient && node.BackwardFunction != null && node.Gradient != null)
        {
            node.BackwardFunction(node.Gradient);
        }
    }

    // The reshape's backward function is expected to have populated the input
    // node's gradient; a null here means the graph was not wired correctly.
    return inputNode.Gradient ?? throw new InvalidOperationException("Gradient computation failed.");
}
293+
/// <summary>
/// Gets the topological order of nodes in the computation graph.
/// </summary>
/// <param name="root">The root node of the computation graph.</param>
/// <returns>A list of nodes in topological order (parents appear before children).</returns>
private List<Autodiff.ComputationNode<T>> GetTopologicalOrder(Autodiff.ComputationNode<T> root)
{
    var seen = new HashSet<Autodiff.ComputationNode<T>>();
    var ordered = new List<Autodiff.ComputationNode<T>>();

    // Iterative depth-first search. Each stack entry carries a flag indicating
    // whether the node's parents have already been scheduled:
    // false => expand (push parents), true => emit into the result.
    var pending = new Stack<(Autodiff.ComputationNode<T> Node, bool Expanded)>();
    pending.Push((root, false));

    while (pending.Count > 0)
    {
        var (current, expanded) = pending.Pop();

        if (seen.Contains(current))
        {
            continue;
        }

        if (!expanded)
        {
            // Re-visit this node only after every one of its parents has been handled.
            pending.Push((current, true));
            foreach (var parent in current.Parents)
            {
                if (!seen.Contains(parent))
                {
                    pending.Push((parent, false));
                }
            }
        }
        else
        {
            seen.Add(current);
            ordered.Add(current);
        }
    }

    return ordered;
}
266332
267333
0 commit comments