55//------------------------------------------------------------------------------
66namespace ParallelReverseAutoDiff . RMAD
77{
8- using ParallelReverseAutoDiff . GravNetExample . Common ;
98 using System ;
109 using System . Linq ;
1110 using System . Threading . Tasks ;
11+ using ParallelReverseAutoDiff . GravNetExample . Common ;
1212
1313 /// <summary>
1414 /// Element-wise cartesian tiled summation operation.
@@ -51,6 +51,10 @@ public Matrix Forward(Matrix input1, Matrix input2, Matrix weights)
5151 this . calculatedValues = new CalculatedValues [ brokenInput1 . GetLength ( 0 ) , brokenInput1 . GetLength ( 1 ) ] [ , ] ;
5252 this . summationX = new double [ brokenInput1 . GetLength ( 0 ) , brokenInput1 . GetLength ( 1 ) ] [ ] ;
5353 this . summationY = new double [ brokenInput1 . GetLength ( 0 ) , brokenInput1 . GetLength ( 1 ) ] [ ] ;
54+ this . input1 = new Matrix [ brokenInput1 . GetLength ( 0 ) , brokenInput1 . GetLength ( 1 ) ] ;
55+ this . input2 = new Matrix [ brokenInput2 . GetLength ( 0 ) , brokenInput2 . GetLength ( 1 ) ] ;
56+ this . weights = new Matrix [ brokenWeights . GetLength ( 0 ) , brokenWeights . GetLength ( 1 ) ] ;
57+ this . output = new Matrix [ brokenInput1 . GetLength ( 0 ) , brokenInput1 . GetLength ( 1 ) ] ;
5458
5559 Parallel . For ( 0 , brokenInput1 . GetLength ( 0 ) , i =>
5660 {
@@ -60,7 +64,7 @@ public Matrix Forward(Matrix input1, Matrix input2, Matrix weights)
6064 }
6165 } ) ;
6266
63- this . Output = CommonMatrixUtils . PieceTogether ( this . output ) ;
67+ this . Output = CommonMatrixUtils . PieceTogetherExactly ( this . output ) ;
6468
6569 return this . Output ;
6670 }
@@ -76,7 +80,7 @@ private void InnerForward(int ii, int jj, Matrix input1, Matrix input2, Matrix w
7680 double [ ] summationX = new double [ input1 . Rows ] ;
7781 double [ ] summationY = new double [ input1 . Rows ] ;
7882 double [ , ] resultVectors = new double [ input1 . Rows * ( input1 . Cols / 2 ) , 2 ] ;
79- Parallel . For ( 0 , input1 . Rows , i =>
83+ for ( int i = 0 ; i < input1 . Rows ; i ++ )
8084 {
8185 double sumX = 0.0d ;
8286 double sumY = 0.0d ;
@@ -169,17 +173,20 @@ private void InnerForward(int ii, int jj, Matrix input1, Matrix input2, Matrix w
169173 calculatedValues . DLocalSumY_DMagnitude = dLocalSumY_dMagnitude ;
170174 calculatedValues . DLocalSumY_DWMagnitude = dLocalSumY_dWMagnitude ;
171175
176+ this . calculatedValues [ ii , jj ] [ i , j ] = calculatedValues ;
177+
172178 sumX += localSumX ;
173179 sumY += localSumY ;
174180 }
175181
176182 summationX [ i ] = sumX ;
177183 summationY [ i ] = sumY ;
178- } ) ;
184+ }
179185
180186 this . summationX [ ii , jj ] = summationX ;
181187 this . summationY [ ii , jj ] = summationY ;
182188
189+ this . output [ ii , jj ] = new Matrix ( 1 , 2 ) ;
183190 this . output [ ii , jj ] [ 0 , 0 ] = this . summationX [ ii , jj ] . Sum ( ) ;
184191 this . output [ ii , jj ] [ 0 , 1 ] = this . summationY [ ii , jj ] . Sum ( ) ;
185192 }
@@ -190,13 +197,13 @@ public override BackwardResult Backward(Matrix dOutput)
190197 this . dInput1 = new Matrix [ this . input1 . GetLength ( 0 ) , this . input1 . GetLength ( 1 ) ] ;
191198 this . dInput2 = new Matrix [ this . input2 . GetLength ( 0 ) , this . input2 . GetLength ( 1 ) ] ;
192199 this . dWeights = new Matrix [ this . weights . GetLength ( 0 ) , this . weights . GetLength ( 1 ) ] ;
193- var dOutputSections = CommonMatrixUtils . BreakIntoSections ( dOutput , 8 ) ;
200+ var dOutputSections = CommonMatrixUtils . BreakIntoSectionsExactly ( dOutput , 8 ) ;
194201
195202 Parallel . For ( 0 , this . dInput1 . GetLength ( 0 ) , i =>
196203 {
197204 for ( int j = 0 ; j < this . dInput2 . GetLength ( 1 ) ; j ++ )
198205 {
199- this . InnerBackward ( i , j , this . dInput1 [ i , j ] , this . dInput2 [ i , j ] , this . dWeights [ i , j ] , dOutputSections [ i , j ] ) ;
206+ this . InnerBackward ( i , j , dOutputSections [ i , j ] ) ;
200207 }
201208 } ) ;
202209
@@ -207,19 +214,19 @@ public override BackwardResult Backward(Matrix dOutput)
207214 . Build ( ) ;
208215 }
209216
210- private void InnerBackward ( int ii , int jj , Matrix input1 , Matrix input2 , Matrix weights , Matrix dOutput )
217+ private void InnerBackward ( int ii , int jj , Matrix dOutput )
211218 {
212- Matrix dInput1 = new Matrix ( input1 . Rows , input1 . Cols ) ;
213- Matrix dInput2 = new Matrix ( input2 . Rows , input2 . Cols ) ;
214- Matrix dWeights = new Matrix ( weights . Rows , weights . Cols ) ;
219+ Matrix dInput1 = new Matrix ( this . input1 [ ii , jj ] . Rows , this . input1 [ ii , jj ] . Cols ) ;
220+ Matrix dInput2 = new Matrix ( this . input2 [ ii , jj ] . Rows , this . input2 [ ii , jj ] . Cols ) ;
221+ Matrix dWeights = new Matrix ( this . weights [ ii , jj ] . Rows , this . weights [ ii , jj ] . Cols ) ;
215222
216223 double dSummationXOutput = dOutput [ 0 , 0 ] ; // Gradient of the loss function with respect to the output X
217224 double dSummationYOutput = dOutput [ 0 , 1 ] ; // Gradient of the loss function with respect to the output Y
218225
219226 // Updating gradients with respect to resultMagnitude and resultAngle
220- Parallel . For ( 0 , input1 . Rows , i =>
227+ for ( int i = 0 ; i < this . input1 [ ii , jj ] . Rows ; i ++ )
221228 {
222- for ( int j = 0 ; j < input1 . Cols / 2 ; j ++ )
229+ for ( int j = 0 ; j < this . input1 [ ii , jj ] . Cols / 2 ; j ++ )
223230 {
224231 var values = this . calculatedValues [ ii , jj ] [ i , j ] ;
225232
@@ -228,12 +235,12 @@ private void InnerBackward(int ii, int jj, Matrix input1, Matrix input2, Matrix
228235
229236 // Apply chain rule to propagate back to dInput1 and dInput2
230237 dInput1 [ i , j ] = dSummationXOutput * values . DLocalSumX_DMagnitude + dSummationYOutput * values . DLocalSumY_DMagnitude ;
231- dInput1 [ i , j + ( input1 . Cols / 2 ) ] = dSummationXOutput * values . DLocalSumX_DAngle + dSummationYOutput * values . DLocalSumY_DAngle ;
238+ dInput1 [ i , j + ( this . input1 [ ii , jj ] . Cols / 2 ) ] = dSummationXOutput * values . DLocalSumX_DAngle + dSummationYOutput * values . DLocalSumY_DAngle ;
232239
233240 dInput2 [ i , j ] = dSummationXOutput * values . DLocalSumX_DWMagnitude + dSummationYOutput * values . DLocalSumY_DWMagnitude ;
234- dInput2 [ i , j + ( input2 . Cols / 2 ) ] = dSummationXOutput * values . DLocalSumX_DWAngle + dSummationYOutput * values . DLocalSumY_DWAngle ;
241+ dInput2 [ i , j + ( this . input2 [ ii , jj ] . Cols / 2 ) ] = dSummationXOutput * values . DLocalSumX_DWAngle + dSummationYOutput * values . DLocalSumY_DWAngle ;
235242 }
236- } ) ;
243+ }
237244
238245 this . dInput1 [ ii , jj ] = dInput1 ;
239246 this . dInput2 [ ii , jj ] = dInput2 ;
0 commit comments