Skip to content

Commit 72f9c45

Browse files
committed
Update tiled net.
1 parent 7e9ff48 commit 72f9c45

File tree

2 files changed

+304
-2
lines changed

2 files changed

+304
-2
lines changed

examples/gravnet/ParallelReverseAutoDiff.GravNetExample/TiledNetwork/Architecture/tilednet.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
},
6767
{
6868
"id": "vector_keys",
69-
"type": "NewGpuMatrixMultiplyOperation",
69+
"type": "NewTiledGpuMatrixMultiplyOperation",
7070
"inputs": [ "vector_add", "Keys" ],
7171
"gradientResultTo": [ null, "DKeys" ]
7272
},
@@ -83,7 +83,7 @@
8383
},
8484
{
8585
"id": "vector_queries",
86-
"type": "NewGpuMatrixMultiplyOperation",
86+
"type": "NewTiledGpuMatrixMultiplyOperation",
8787
"inputs": [ "vector_add", "Queries" ],
8888
"gradientResultTo": [ null, "DQueries" ]
8989
},
Lines changed: 302 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,302 @@
1+
//------------------------------------------------------------------------------
2+
// <copyright file="NewTiledGpuMatrixMultiplyOperation.cs" author="ameritusweb" date="5/2/2023">
3+
// Copyright (c) 2023 ameritusweb All rights reserved.
4+
// </copyright>
5+
//------------------------------------------------------------------------------
6+
namespace ParallelReverseAutoDiff.RMAD
7+
{
8+
using System;
9+
using System.Linq;
10+
using ILGPU;
11+
using ILGPU.Runtime;
12+
using ParallelReverseAutoDiff.Exceptions;
13+
using ParallelReverseAutoDiff.GravNetExample.Common;
14+
15+
/// <summary>
16+
/// GPU Tiled matrix multiplication operation.
17+
/// </summary>
18+
public class NewTiledGpuMatrixMultiplyOperation : Operation
19+
{
20+
private const int TILESIZE = 8;
21+
private Matrix[,] input1;
22+
private Matrix[,] input2;
23+
private Matrix[,] output;
24+
private Matrix[,] dInput1;
25+
private Matrix[,] dInput2;
26+
27+
/// <summary>
28+
/// A common method for instantiating an operation.
29+
/// </summary>
30+
/// <param name="net">The neural network.</param>
31+
/// <returns>The instantiated operation.</returns>
32+
public static IOperation Instantiate(NeuralNetwork net)
33+
{
34+
return new NewTiledGpuMatrixMultiplyOperation();
35+
}
36+
37+
/// <summary>
38+
/// The tiled matrix multiplication kernel that runs on the accelerated device.
39+
/// </summary>
40+
/// <param name="aView">An input matrix of size MxK.</param>
41+
/// <param name="bView">An input matrix of size KxN.</param>
42+
/// <param name="cView">An output matrix of size MxN.</param>
43+
public static void MatrixMultiplyTiledKernel(
44+
ArrayView2D<double, Stride2D.DenseX> aView,
45+
ArrayView2D<double, Stride2D.DenseX> bView,
46+
ArrayView2D<double, Stride2D.DenseX> cView)
47+
{
48+
var global = Grid.GlobalIndex.XY;
49+
var x = Group.IdxX;
50+
var y = Group.IdxY;
51+
52+
var aTile = SharedMemory.Allocate2D<double, Stride2D.DenseX>(new Index2D(TILESIZE, TILESIZE), new Stride2D.DenseX(TILESIZE));
53+
var bTile = SharedMemory.Allocate2D<double, Stride2D.DenseX>(new Index2D(TILESIZE, TILESIZE), new Stride2D.DenseX(TILESIZE));
54+
55+
var total = 0.0d; // Initialize accumulator for sums across tiles
56+
57+
for (var i = 0; i < aView.IntExtent.Y; i += TILESIZE)
58+
{
59+
var sum = 0.0d;
60+
61+
if (global.X < aView.IntExtent.X && y + i < aView.IntExtent.Y)
62+
{
63+
aTile[x, y] = aView[global.X, y + i];
64+
}
65+
else
66+
{
67+
aTile[x, y] = 0;
68+
}
69+
70+
if (x + i < bView.IntExtent.X && global.Y < bView.IntExtent.Y)
71+
{
72+
bTile[x, y] = bView[x + i, global.Y];
73+
}
74+
else
75+
{
76+
bTile[x, y] = 0;
77+
}
78+
79+
Group.Barrier();
80+
81+
var kk = 0;
82+
83+
for (var k = 0; k < TILESIZE; k++)
84+
{
85+
sum += aTile[new Index2D(x, k)] * bTile[new Index2D(k, y)];
86+
}
87+
88+
Group.Barrier();
89+
90+
total += sum;
91+
}
92+
93+
if (global.X < cView.IntExtent.X && global.Y < cView.IntExtent.Y)
94+
{
95+
cView[global] = total;
96+
}
97+
}
98+
99+
/// <summary>
100+
/// Performs the forward operation for the matrix multiply function.
101+
/// </summary>
102+
/// <param name="input1">The first input to the matrix multiply operation.</param>
103+
/// <param name="input2">The second input to the matrix multiply operation.</param>
104+
/// <returns>The output of the matrix multiply operation.</returns>
105+
public Matrix Forward(Matrix input1, Matrix input2)
106+
{
107+
if (!CudaBlas.Instance.IsInitialized)
108+
{
109+
throw new CudaNotInitializedException();
110+
}
111+
112+
var brokenInput1 = CommonMatrixUtils.BreakIntoSections(input1, 8);
113+
var brokenInput2 = CommonMatrixUtils.BreakIntoSections(input2, 8);
114+
115+
this.input1 = new Matrix[brokenInput1.GetLength(0), brokenInput1.GetLength(1)];
116+
this.input2 = new Matrix[brokenInput2.GetLength(0), brokenInput2.GetLength(1)];
117+
this.output = new Matrix[brokenInput1.GetLength(0), brokenInput2.GetLength(1)];
118+
119+
Parallel.For(0, 8, i =>
120+
{
121+
for (int j = 0; j < 8; j++)
122+
{
123+
var i1 = brokenInput1[i, j];
124+
var i2 = brokenInput2[i, j];
125+
126+
this.InnerForward(i, j, i1, i2);
127+
}
128+
});
129+
130+
this.Output = CommonMatrixUtils.PieceTogether(this.output);
131+
return this.Output;
132+
}
133+
134+
/// <inheritdoc />
135+
public override BackwardResult Backward(Matrix dOutput)
136+
{
137+
if (!CudaBlas.Instance.IsInitialized)
138+
{
139+
throw new CudaNotInitializedException();
140+
}
141+
142+
this.dInput1 = new Matrix[this.input1.GetLength(0), this.input1.GetLength(1)];
143+
this.dInput2 = new Matrix[this.input2.GetLength(0), this.input2.GetLength(1)];
144+
var dOutputSections = CommonMatrixUtils.BreakIntoSections(dOutput, 8);
145+
146+
Parallel.For(0, this.dInput1.GetLength(0), i =>
147+
{
148+
for (int j = 0; j < this.dInput2.GetLength(1); j++)
149+
{
150+
this.InnerBackward(i, j, dOutputSections[i, j]);
151+
}
152+
});
153+
154+
return new BackwardResultBuilder()
155+
.AddInputGradient(CommonMatrixUtils.PieceTogether(this.dInput1))
156+
.AddInputGradient(CommonMatrixUtils.PieceTogether(this.dInput2))
157+
.Build();
158+
}
159+
160+
/// <summary>
161+
/// Multiplies two dense matrices and returns the resultant matrix (using tiling).
162+
/// </summary>
163+
/// <param name="accelerator">The Accelerator to run the multiplication on.</param>
164+
/// <param name="a">A dense MxK matrix.</param>
165+
/// <param name="b">A dense KxN matrix.</param>
166+
/// <returns>A dense MxN matrix.</returns>
167+
public double[,] MatrixMultiplyTiled(Accelerator accelerator, double[,] a, double[,] b)
168+
{
169+
var m = a.GetLength(0);
170+
var ka = a.GetLength(1);
171+
var kb = b.GetLength(0);
172+
var n = b.GetLength(1);
173+
174+
if (ka != kb)
175+
{
176+
throw new ArgumentException($"Cannot multiply {m}x{ka} matrix by {n}x{kb} matrix", nameof(b));
177+
}
178+
179+
var kernel = accelerator.LoadStreamKernel<
180+
ArrayView2D<double, Stride2D.DenseX>,
181+
ArrayView2D<double, Stride2D.DenseX>,
182+
ArrayView2D<double, Stride2D.DenseX>>(
183+
MatrixMultiplyTiledKernel);
184+
var groupSize = new Index2D(TILESIZE, TILESIZE);
185+
var numGroups = new Index2D((m + TILESIZE - 1) / TILESIZE, (n + TILESIZE - 1) / TILESIZE);
186+
187+
using var aBuffer = accelerator.Allocate2DDenseX<double>(new Index2D(m, ka));
188+
using var bBuffer = accelerator.Allocate2DDenseX<double>(new Index2D(ka, n));
189+
using var cBuffer = accelerator.Allocate2DDenseX<double>(new Index2D(m, n));
190+
aBuffer.CopyFromCPU(a);
191+
bBuffer.CopyFromCPU(b);
192+
193+
kernel((numGroups, groupSize), aBuffer, bBuffer, cBuffer);
194+
195+
// Reads data from the GPU buffer into a new CPU array.
196+
// Implicitly calls accelerator.DefaultStream.Synchronize() to ensure
197+
// that the kernel and memory copy are completed first.
198+
return cBuffer.GetAsArray2D();
199+
}
200+
201+
private void InnerForward(int ii, int jj, Matrix input1, Matrix input2)
202+
{
203+
this.input1[ii, jj] = input1;
204+
this.input2[ii, jj] = input2;
205+
int input1Cols = input1[0].Length;
206+
int input2Rows = input2.Length;
207+
208+
if (input1Cols != input2Rows)
209+
{
210+
throw new InvalidOperationException("Input 1 columns do not match Input 2 rows");
211+
}
212+
213+
var acceleratedTiledResult = this.MatrixMultiplyTiled(CudaBlas.Instance.Accelerator, this.To2D(input1.ToArray(), false), this.To2D(input2.ToArray(), false));
214+
215+
this.output[ii, jj] = new Matrix(this.ToJagged(acceleratedTiledResult));
216+
}
217+
218+
private void InnerBackward(int ii, int jj, Matrix dOutput)
219+
{
220+
// Calculate gradient w.r.t. input1
221+
222+
// Compute dInput1 using MatrixMultiply
223+
var acceleratedTiledResult1 = this.MatrixMultiplyTiled(CudaBlas.Instance.Accelerator, this.To2D(dOutput.ToArray(), false), this.To2D(this.input2[ii, jj].ToArray(), true));
224+
this.dInput1[ii, jj] = new Matrix(this.ToJagged(acceleratedTiledResult1));
225+
226+
// Calculate gradient w.r.t. input2
227+
228+
// Compute dInput2 using MatrixMultiply
229+
var acceleratedTiledResult2 = this.MatrixMultiplyTiled(CudaBlas.Instance.Accelerator, this.To2D(this.input1[ii, jj].ToArray(), true), this.To2D(dOutput.ToArray(), false));
230+
this.dInput2[ii, jj] = new Matrix(this.ToJagged(acceleratedTiledResult2));
231+
}
232+
233+
/// <summary>
234+
/// Converts a jagged array to a 2D array.
235+
/// </summary>
236+
/// <param name="source">The jagged array.</param>
237+
/// <param name="transpose">Whether to transpose the array.</param>
238+
/// <returns>The 2-D array.</returns>
239+
private double[,] To2D(double[][] source, bool transpose)
240+
{
241+
try
242+
{
243+
int firstDim = source.Length;
244+
int secondDim = source.GroupBy(row => row.Length).Single().Key; // throws InvalidOperationException if source is not rectangular
245+
246+
if (transpose)
247+
{
248+
var result = new double[secondDim, firstDim];
249+
for (int i = 0; i < secondDim; ++i)
250+
{
251+
for (int j = 0; j < firstDim; ++j)
252+
{
253+
result[i, j] = source[j][i];
254+
}
255+
}
256+
257+
return result;
258+
}
259+
else
260+
{
261+
var result = new double[firstDim, secondDim];
262+
for (int i = 0; i < firstDim; ++i)
263+
{
264+
for (int j = 0; j < secondDim; ++j)
265+
{
266+
result[i, j] = source[i][j];
267+
}
268+
}
269+
270+
return result;
271+
}
272+
}
273+
catch (InvalidOperationException)
274+
{
275+
throw new InvalidOperationException("The given jagged array is not rectangular.");
276+
}
277+
}
278+
279+
/// <summary>
280+
/// Converts a 2D array to a jagged array.
281+
/// </summary>
282+
/// <param name="source">The 2-D array.</param>
283+
/// <returns>The jagged array.</returns>
284+
private double[][] ToJagged(double[,] source)
285+
{
286+
int firstDim = source.GetLength(0);
287+
int secondDim = source.GetLength(1);
288+
var result = new double[firstDim][];
289+
290+
for (int i = 0; i < firstDim; ++i)
291+
{
292+
result[i] = new double[secondDim];
293+
for (int j = 0; j < secondDim; ++j)
294+
{
295+
result[i][j] = source[i, j];
296+
}
297+
}
298+
299+
return result;
300+
}
301+
}
302+
}

0 commit comments

Comments
 (0)