Skip to content

Commit b22a209

Browse files
committed
Fix chain rule in HopeNetwork backward pass
The backward pass was incorrectly breaking the chain rule by:
- Iterating over learning levels instead of actual CMS blocks
- Using modulo indexing (level % _numCMSLevels), which broke gradient flow
- Reusing the same gradient for all blocks instead of chaining them
- Accumulating gradients incorrectly

Fixed by:
- Processing context flow gradients in reverse, accumulating them into the upstream gradient
- Iterating CMS blocks in reverse order (last to first) without modulo
- Properly chaining gradients: each block receives the accumulated gradient from the previous block
- Returning the final chained gradient as the true derivative w.r.t. the HOPE input

This ensures proper backpropagation through the entire HOPE architecture.
1 parent 94398a6 commit b22a209

File tree

1 file changed

+15
-15
lines changed

1 file changed

+15
-15
lines changed

src/NeuralNetworks/HopeNetwork.cs

Lines changed: 15 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -174,28 +174,28 @@ public Tensor<T> Backward(Tensor<T> outputGradient)
174174
gradient = _recurrentLayers[i].Backward(gradient);
175175
}
176176

177-
// Backprop through CMS blocks and context flow
178-
Tensor<T>? totalGradient = null;
179-
177+
// Backprop through context flow levels (applied after CMS blocks in forward pass)
178+
// Context flow blended with the output of CMS blocks, so we propagate gradients through
180179
for (int level = _inContextLearningLevels - 1; level >= 0; level--)
181180
{
182-
// Compute context flow gradients
181+
// Compute and accumulate context flow gradients for this level
183182
var contextGrad = _contextFlow.ComputeContextGradients(gradient.ToVector(), level);
183+
var contextTensor = new Tensor<T>(new[] { _hiddenDim }, contextGrad);
184184

185-
int cmsIndex = level % _numCMSLevels;
186-
var cmsGrad = _cmsBlocks[cmsIndex].Backward(gradient);
185+
// Add context gradient to current upstream gradient (blending was additive in forward)
186+
gradient = AddTensors(gradient, contextTensor);
187+
}
187188

188-
if (totalGradient == null)
189-
{
190-
totalGradient = cmsGrad;
191-
}
192-
else
193-
{
194-
totalGradient = AddTensors(totalGradient, cmsGrad);
195-
}
189+
// Backprop through CMS blocks in reverse order (no modulo - proper chain rule)
190+
// Each block receives the accumulated gradient from the previous block
191+
for (int i = _numCMSLevels - 1; i >= 0; i--)
192+
{
193+
// Pass combined gradient to this CMS block's backward
194+
gradient = _cmsBlocks[i].Backward(gradient);
195+
// gradient now contains the downstream gradient for the next (previous) block
196196
}
197197

198-
return totalGradient!;
198+
return gradient;
199199
}
200200

201201
/// <summary>

0 commit comments

Comments
 (0)