Skip to content

Commit f7c0a0b

Browse files
committed
Fix ModifiedGradientDescentOptimizer to use correct projection
The UpdateVector method was using an incorrect scalar heuristic that uniformly scaled all parameters by (1 - ||x||²), which required clipping when ||x||² >= 1 and completely discarded the parameter term.

Issue:
- Used modFactor = 1 - ||x||² as a scalar multiplier
- Clipped to zero when ||x||² >= 1, dropping currentParameters entirely
- This is not the correct vector equivalent of W * (I - x x^T)

Fix: Replaced with the correct projection for vector parameter w:
- w * (I - x x^T) = w - x*(x^T*w) = w - x*dot(w,x)
- Compute dot = dot(currentParameters, input)
- Projection: currentParameters - input * dot
- Then subtract gradient: -η * gradient
- Final: w_{t+1} = w_t - x_t*dot(w_t,x_t) - η*gradient

Benefits:
- Mathematically correct implementation of Equations 27-29
- No clipping needed — the projection is always numerically stable
- Parameters are never discarded, regardless of input norm
- Added validation for dimension matching

This ensures the Modified Gradient Descent optimizer correctly implements the paper's formulation for vector parameters.
1 parent 18e08cf commit f7c0a0b

File tree

1 file changed

+27
-27
lines changed

1 file changed

+27
-27
lines changed

src/Optimizers/ModifiedGradientDescentOptimizer.cs

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -67,45 +67,45 @@ public Matrix<T> UpdateMatrix(Matrix<T> currentParameters, Vector<T> input, Vect
6767
/// <summary>
6868
/// Updates a parameter vector using modified gradient descent.
6969
///
70-
/// NOTE: This is a simplified scalar approximation of the matrix operation.
71-
/// The matrix form W_t * (I - x_t x_t^T) is always stable, but this scalar
72-
/// version using (1 - ||x_t||²) requires clipping to prevent instability
73-
/// when input norm exceeds 1.
70+
/// For a vector parameter w, the matrix operation W * (I - x x^T) becomes:
71+
/// w_new = w * (I - x x^T) = w - x*(x^T*w) = w - x*dot(w,x)
72+
///
73+
/// Full update: w_{t+1} = w_t - x_t*dot(w_t,x_t) - η * gradient
7474
/// </summary>
75-
/// <param name="currentParameters">Current parameters</param>
76-
/// <param name="input">Input vector</param>
77-
/// <param name="outputGradient">Output gradient</param>
78-
/// <returns>Updated parameters</returns>
75+
/// <param name="currentParameters">Current parameter vector w_t</param>
76+
/// <param name="input">Input vector x_t</param>
77+
/// <param name="outputGradient">Output gradient ∇_y L(w_t; x_t)</param>
78+
/// <returns>Updated parameters w_{t+1}</returns>
7979
public Vector<T> UpdateVector(Vector<T> currentParameters, Vector<T> input, Vector<T> outputGradient)
8080
{
81+
if (currentParameters.Length != input.Length)
82+
throw new ArgumentException($"Parameter length ({currentParameters.Length}) must match input length ({input.Length})");
83+
84+
if (currentParameters.Length != outputGradient.Length)
85+
throw new ArgumentException($"Parameter length ({currentParameters.Length}) must match gradient length ({outputGradient.Length})");
86+
8187
var updated = new Vector<T>(currentParameters.Length);
8288

83-
// For vector form: apply element-wise operations
84-
// This is a simplified version that preserves the spirit of the modification
85-
T inputNormSquared = _numOps.Zero;
86-
for (int i = 0; i < input.Length; i++)
89+
// Compute dot(w_t, x_t) = x_t^T * w_t
90+
T dotProduct = _numOps.Zero;
91+
for (int i = 0; i < currentParameters.Length; i++)
8792
{
88-
inputNormSquared = _numOps.Add(inputNormSquared, _numOps.Square(input[i]));
93+
dotProduct = _numOps.Add(dotProduct, _numOps.Multiply(currentParameters[i], input[i]));
8994
}
9095

91-
// Apply modified update rule
96+
// Apply modified update rule: w_{t+1} = w_t - x_t*dot(w_t,x_t) - η*gradient
9297
for (int i = 0; i < currentParameters.Length; i++)
9398
{
94-
// Standard GD component: -η * gradient
95-
T gradComponent = _numOps.Multiply(outputGradient[i], _learningRate);
99+
// Projection term: w_t - x_t*dot(w_t,x_t)
100+
// This is the vector equivalent of W_t * (I - x_t*x_t^T)
101+
T projectionComponent = _numOps.Multiply(input[i], dotProduct);
102+
T projectedParam = _numOps.Subtract(currentParameters[i], projectionComponent);
96103

97-
// Modification: scale by (1 - ||x_t||²) factor for regularization
98-
// CRITICAL: Clip to prevent negative scaling when ||x_t||² > 1
99-
// Without clipping, parameters would explode when input norm exceeds 1
100-
T modFactor = _numOps.Subtract(_numOps.One, inputNormSquared);
101-
if (_numOps.LessThan(modFactor, _numOps.Zero))
102-
{
103-
modFactor = _numOps.Zero;
104-
}
105-
106-
T paramComponent = _numOps.Multiply(currentParameters[i], modFactor);
104+
// Gradient term: -η * gradient
105+
T gradComponent = _numOps.Multiply(outputGradient[i], _learningRate);
107106

108-
updated[i] = _numOps.Subtract(paramComponent, gradComponent);
107+
// Final update: w_{t+1} = (w_t - x_t*dot(w_t,x_t)) - η*gradient
108+
updated[i] = _numOps.Subtract(projectedParam, gradComponent);
109109
}
110110

111111
return updated;

0 commit comments

Comments (0)