1+ using System . Collections . Generic ;
12using AiDotNet . Helpers ;
23using AiDotNet . Interfaces ;
34using AiDotNet . LinearAlgebra ;
@@ -78,6 +79,13 @@ public class RelationalDistillationStrategy<T> : DistillationStrategyBase<Vector
7879 private readonly int _maxSamplesPerBatch ;
7980 private readonly RelationalDistanceMetric _distanceMetric ;
8081
82+ // Batch accumulation buffers for relational loss computation
83+ private readonly List < Vector < T > > _batchStudentOutputs = new ( ) ;
84+ private readonly List < Vector < T > > _batchTeacherOutputs = new ( ) ;
85+ private T _accumulatedRelationalLoss = default ! ;
86+ private int _samplesSinceRelationalCompute = 0 ;
87+ private readonly int _relationalBatchSize ;
88+
8189 /// <summary>
8290 /// Initializes a new instance of the RelationalDistillationStrategy class.
8391 /// </summary>
@@ -137,21 +145,44 @@ public RelationalDistillationStrategy(
137145 _angleWeight = angleWeight ;
138146 _maxSamplesPerBatch = maxSamplesPerBatch ;
139147 _distanceMetric = distanceMetric ;
148+ _relationalBatchSize = maxSamplesPerBatch ;
149+ _accumulatedRelationalLoss = NumOps . Zero ;
140150 }
141151
142152 /// <summary>
143- /// Computes standard output loss ( relational loss computed separately) .
153+ /// Computes combined output loss and relational loss.
144154 /// </summary>
155+ /// <remarks>
156+ /// <para>This method accumulates student/teacher outputs and computes relational loss
157+ /// when a batch is complete. The relational loss is then amortized across subsequent samples.</para>
158+ /// </remarks>
145159 public override T ComputeLoss ( Vector < T > studentOutput , Vector < T > teacherOutput , Vector < T > ? trueLabels = null )
146160 {
147161 ValidateOutputDimensions ( studentOutput , teacherOutput , v => v . Length ) ;
148162
163+ // Accumulate outputs for relational loss computation
164+ _batchStudentOutputs . Add ( studentOutput ) ;
165+ _batchTeacherOutputs . Add ( teacherOutput ) ;
166+
167+ // When batch is full, compute relational loss
168+ if ( _batchStudentOutputs . Count >= _relationalBatchSize )
169+ {
170+ _accumulatedRelationalLoss = ComputeRelationalLoss (
171+ _batchStudentOutputs . ToArray ( ) ,
172+ _batchTeacherOutputs . ToArray ( ) ) ;
173+ _samplesSinceRelationalCompute = 0 ;
174+ // Clear buffers for next batch
175+ _batchStudentOutputs . Clear ( ) ;
176+ _batchTeacherOutputs . Clear ( ) ;
177+ }
178+
149179 // Standard distillation loss
150180 var studentSoft = Softmax ( studentOutput , Temperature ) ;
151181 var teacherSoft = Softmax ( teacherOutput , Temperature ) ;
152182 var softLoss = KLDivergence ( teacherSoft , studentSoft ) ;
153183 softLoss = NumOps . Multiply ( softLoss , NumOps . FromDouble ( Temperature * Temperature ) ) ;
154184
185+ T baseLoss ;
155186 if ( trueLabels != null )
156187 {
157188 ValidateLabelDimensions ( studentOutput , trueLabels , v => v . Length ) ;
@@ -161,24 +192,43 @@ public override T ComputeLoss(Vector<T> studentOutput, Vector<T> teacherOutput,
161192 var alphaT = NumOps . FromDouble ( Alpha ) ;
162193 var oneMinusAlpha = NumOps . FromDouble ( 1.0 - Alpha ) ;
163194
164- return NumOps . Add (
195+ baseLoss = NumOps . Add (
165196 NumOps . Multiply ( alphaT , hardLoss ) ,
166197 NumOps . Multiply ( oneMinusAlpha , softLoss ) ) ;
167198 }
199+ else
200+ {
201+ baseLoss = softLoss ;
202+ }
203+
204+ // Add amortized relational loss (distributed across batch samples)
205+ T relationalContribution = NumOps . Zero ;
206+ if ( _samplesSinceRelationalCompute < _relationalBatchSize )
207+ {
208+ relationalContribution = NumOps . Divide (
209+ _accumulatedRelationalLoss ,
210+ NumOps . FromDouble ( _relationalBatchSize ) ) ;
211+ _samplesSinceRelationalCompute ++ ;
212+ }
168213
169- return softLoss ;
214+ return NumOps . Add ( baseLoss , relationalContribution ) ;
170215 }
171216
172217 /// <summary>
173- /// Computes gradient of output loss.
218+ /// Computes gradient of combined output loss and relational loss.
174219 /// </summary>
220+ /// <remarks>
221+ /// <para>The gradient includes both the standard distillation gradient and the relational gradient
222+ /// computed from pairwise distances and angular relationships in the accumulated batch.</para>
223+ /// </remarks>
175224 public override Vector < T > ComputeGradient ( Vector < T > studentOutput , Vector < T > teacherOutput , Vector < T > ? trueLabels = null )
176225 {
177226 ValidateOutputDimensions ( studentOutput , teacherOutput , v => v . Length ) ;
178227
179228 int n = studentOutput . Length ;
180229 var gradient = new Vector < T > ( n ) ;
181230
231+ // Standard soft gradient
182232 var studentSoft = Softmax ( studentOutput , Temperature ) ;
183233 var teacherSoft = Softmax ( teacherOutput , Temperature ) ;
184234
@@ -188,6 +238,7 @@ public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> tea
188238 gradient [ i ] = NumOps . Multiply ( diff , NumOps . FromDouble ( Temperature * Temperature ) ) ;
189239 }
190240
241+ // Add hard gradient if labels provided
191242 if ( trueLabels != null )
192243 {
193244 ValidateLabelDimensions ( studentOutput , trueLabels , v => v . Length ) ;
@@ -210,6 +261,194 @@ public override Vector<T> ComputeGradient(Vector<T> studentOutput, Vector<T> tea
210261 }
211262 }
212263
264+ // Add relational gradient contribution
265+ // Since we're accumulating batch-level relational loss, we compute gradients
266+ // for all pairs/triplets involving this sample when the batch is complete
267+ if ( _batchStudentOutputs . Count > 1 && _batchStudentOutputs . Count <= _relationalBatchSize )
268+ {
269+ var relationalGrad = ComputeRelationalGradient (
270+ studentOutput ,
271+ teacherOutput ,
272+ _batchStudentOutputs ,
273+ _batchTeacherOutputs ) ;
274+
275+ for ( int i = 0 ; i < n ; i ++ )
276+ {
277+ gradient [ i ] = NumOps . Add ( gradient [ i ] , relationalGrad [ i ] ) ;
278+ }
279+ }
280+
281+ return gradient ;
282+ }
283+
/// <summary>
/// Computes the gradient of the relational loss with respect to a single student output.
/// </summary>
/// <param name="currentStudentOutput">The current sample's student output.</param>
/// <param name="currentTeacherOutput">The current sample's teacher output.</param>
/// <param name="batchStudentOutputs">Accumulated batch of student outputs; the current sample is expected to be the last entry.</param>
/// <param name="batchTeacherOutputs">Accumulated batch of teacher outputs, parallel to <paramref name="batchStudentOutputs"/>.</param>
/// <returns>Gradient vector for the current sample; all zeros when fewer than two samples are buffered.</returns>
/// <remarks>
/// <para>Accumulates the weighted distance-wise gradient over every pair (current, j) and the
/// weighted angle-wise gradient over a bounded subset of triplets (current, j, k), then
/// normalizes the sum by the number of buffered samples.</para>
/// <para>NOTE(review): this assumes the batch buffers still contain the current sample when the
/// gradient is requested, i.e. that ComputeLoss has not already cleared them for a completed
/// batch — confirm the ComputeLoss/ComputeGradient call-ordering contract.</para>
/// </remarks>
private Vector<T> ComputeRelationalGradient(
    Vector<T> currentStudentOutput,
    Vector<T> currentTeacherOutput,
    List<Vector<T>> batchStudentOutputs,
    List<Vector<T>> batchTeacherOutputs)
{
    int dim = currentStudentOutput.Length;
    var gradient = new Vector<T>(dim);

    // Start from an explicit zero vector.
    for (int i = 0; i < dim; i++)
    {
        gradient[i] = NumOps.Zero;
    }

    // Relations require at least two samples; with fewer there is nothing to relate.
    if (batchStudentOutputs.Count < 2)
        return gradient;

    // Distance-wise contribution: one pair per previously buffered sample j.
    if (_distanceWeight > 0)
    {
        for (int j = 0; j < batchStudentOutputs.Count - 1; j++)
        {
            var distGrad = ComputePairwiseDistanceGradient(
                currentStudentOutput,
                batchStudentOutputs[j],
                currentTeacherOutput,
                batchTeacherOutputs[j]);

            for (int k = 0; k < dim; k++)
            {
                var weighted = NumOps.Multiply(distGrad[k], NumOps.FromDouble(_distanceWeight));
                gradient[k] = NumOps.Add(gradient[k], weighted);
            }
        }
    }

    // Angle-wise contribution: only a capped number of triplets, to bound the cost.
    if (_angleWeight > 0 && batchStudentOutputs.Count >= 3)
    {
        int maxTriplets = Math.Min(10, batchStudentOutputs.Count - 1);
        for (int t = 0; t < maxTriplets; t++)
        {
            if (t >= batchStudentOutputs.Count - 1)
                break;

            // j and k index two distinct earlier samples (k wraps around the buffer).
            int j = t;
            int k = (t + 1) % (batchStudentOutputs.Count - 1);

            var angleGrad = ComputeTripletAngleGradient(
                currentStudentOutput,
                batchStudentOutputs[j],
                batchStudentOutputs[k],
                currentTeacherOutput,
                batchTeacherOutputs[j],
                batchTeacherOutputs[k]);

            for (int d = 0; d < dim; d++)
            {
                var weighted = NumOps.Multiply(angleGrad[d], NumOps.FromDouble(_angleWeight));
                gradient[d] = NumOps.Add(gradient[d], weighted);
            }
        }
    }

    // Average the accumulated contribution over the buffered batch.
    // (Fix: removed the unused local `currentIdx` the original computed and never read.)
    for (int i = 0; i < dim; i++)
    {
        gradient[i] = NumOps.Divide(gradient[i], NumOps.FromDouble(batchStudentOutputs.Count));
    }

    return gradient;
}
372+
/// <summary>
/// Computes the gradient of the distance-wise relational loss for one sample pair,
/// taken with respect to the first student output.
/// </summary>
/// <param name="studentI">Student output of the sample being differentiated.</param>
/// <param name="studentJ">Student output of the paired sample.</param>
/// <param name="teacherI">Teacher output of the sample being differentiated.</param>
/// <param name="teacherJ">Teacher output of the paired sample.</param>
/// <returns>Per-dimension gradient for <paramref name="studentI"/>.</returns>
/// <remarks>
/// <para>NOTE(review): the inner derivative (studentI - studentJ) / dist assumes a
/// Euclidean-like metric; confirm this matches what ComputeDistance does for the
/// configured <c>_distanceMetric</c>.</para>
/// </remarks>
private Vector<T> ComputePairwiseDistanceGradient(
    Vector<T> studentI,
    Vector<T> studentJ,
    Vector<T> teacherI,
    Vector<T> teacherJ)
{
    int dim = studentI.Length;
    var result = new Vector<T>(dim);

    // Mismatch between the student-space and teacher-space pairwise distances.
    var studentDistance = ComputeDistance(studentI, studentJ);
    var teacherDistance = ComputeDistance(teacherI, teacherJ);
    double mismatch = Convert.ToDouble(NumOps.Subtract(studentDistance, teacherDistance));

    // Huber-style outer derivative: quadratic region inside |mismatch| < 1, linear outside.
    double outerGrad = Math.Abs(mismatch) < 1.0
        ? 2.0 * mismatch
        : 2.0 * Math.Sign(mismatch);

    // Epsilon keeps the division safe when the two student outputs coincide.
    double safeDistance = Convert.ToDouble(studentDistance) + Epsilon;

    for (int d = 0; d < dim; d++)
    {
        double delta = Convert.ToDouble(NumOps.Subtract(studentI[d], studentJ[d]));
        result[d] = NumOps.FromDouble(outerGrad * delta / safeDistance);
    }

    return result;
}
412+
/// <summary>
/// Computes the gradient of the angle-wise relational loss for one triplet, with respect
/// to the first student output, using forward-difference numerical differentiation.
/// </summary>
/// <param name="studentI">Student output being differentiated (vertex of the angle triplet).</param>
/// <param name="studentJ">Second student output in the triplet.</param>
/// <param name="studentK">Third student output in the triplet.</param>
/// <param name="teacherI">Teacher output corresponding to <paramref name="studentI"/>.</param>
/// <param name="teacherJ">Teacher output corresponding to <paramref name="studentJ"/>.</param>
/// <param name="teacherK">Teacher output corresponding to <paramref name="studentK"/>.</param>
/// <returns>Per-dimension gradient approximating d((studentAngle - teacherAngle)^2)/d(studentI).</returns>
private Vector<T> ComputeTripletAngleGradient(
    Vector<T> studentI,
    Vector<T> studentJ,
    Vector<T> studentK,
    Vector<T> teacherI,
    Vector<T> teacherJ,
    Vector<T> teacherK)
{
    int dim = studentI.Length;
    var gradient = new Vector<T>(dim);

    var studentAngle = ComputeAngle(studentI, studentJ, studentK);
    var teacherAngle = ComputeAngle(teacherI, teacherJ, teacherK);

    // Outer derivative of the squared angle error: 2 * (studentAngle - teacherAngle).
    var angleDiff = NumOps.Subtract(studentAngle, teacherAngle);
    double diffVal = Convert.ToDouble(angleDiff);

    // Forward-difference step size for the numerical inner derivative.
    double eps = 0.001;
    var epsT = NumOps.FromDouble(eps);

    // Perf fix: copy studentI once and perturb/restore one coordinate per iteration,
    // instead of re-copying the whole vector inside the loop (was O(dim^2) copies).
    // Numerically identical to building a fresh perturbed copy each time.
    var perturbed = new Vector<T>(dim);
    for (int k = 0; k < dim; k++)
    {
        perturbed[k] = studentI[k];
    }

    for (int d = 0; d < dim; d++)
    {
        perturbed[d] = NumOps.Add(studentI[d], epsT);

        var perturbedAngle = ComputeAngle(perturbed, studentJ, studentK);
        var angleGrad = NumOps.Subtract(perturbedAngle, studentAngle);
        var numGrad = NumOps.Divide(angleGrad, epsT);

        gradient[d] = NumOps.Multiply(numGrad, NumOps.FromDouble(2.0 * diffVal));

        // Restore the coordinate before moving to the next dimension.
        perturbed[d] = studentI[d];
    }

    return gradient;
}
215454
0 commit comments