Skip to content

Commit 4cd6018

Browse files
committed
Add batch matrix multiply.
1 parent 8d8f21e commit 4cd6018

File tree

2 files changed

+98
-2
lines changed

2 files changed

+98
-2
lines changed
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
//------------------------------------------------------------------------------
2+
// <copyright file="BatchMatrixMultiplyOperation.cs" author="ameritusweb" date="5/2/2023">
3+
// Copyright (c) 2023 ameritusweb All rights reserved.
4+
// </copyright>
5+
//------------------------------------------------------------------------------
6+
namespace ParallelReverseAutoDiff.RMAD
7+
{
8+
using System;
9+
using System.Linq;
10+
using System.Threading.Tasks;
11+
12+
/// <summary>
13+
/// Batch matrix multiply operation.
14+
/// </summary>
15+
public class BatchMatrixMultiplyOperation : BatchOperation<MatrixMultiplyOperation>
16+
{
17+
private BatchMatrixMultiplyOperation(NeuralNetwork net)
18+
: base(net)
19+
{
20+
this.Operations = new MatrixMultiplyOperation[net.Parameters.BatchSize];
21+
}
22+
23+
/// <summary>
24+
/// A common method for instantiating an operation.
25+
/// </summary>
26+
/// <param name="net">The neural network.</param>
27+
/// <returns>The instantiated operation.</returns>
28+
public static IBatchOperation Instantiate(NeuralNetwork net)
29+
{
30+
return new BatchMatrixMultiplyOperation(net);
31+
}
32+
33+
/// <inheritdoc />
34+
public override void Store(Guid id)
35+
{
36+
this.IntermediateOperationArrays.AddOrUpdate(id, this.Operations, (key, oldValue) => this.Operations);
37+
}
38+
39+
/// <inheritdoc />
40+
public override void Restore(Guid id)
41+
{
42+
this.Operations = this.IntermediateOperationArrays[id].OfType<MatrixMultiplyOperation>().ToArray();
43+
}
44+
45+
/// <summary>
46+
/// Performs the forward operation for the matrix multiply function.
47+
/// </summary>
48+
/// <param name="input1">The first input to the matrix multiply operation.</param>
49+
/// <param name="input2">The second input to the matrix multiply operation.</param>
50+
/// <returns>The output of the Hadamard product operation.</returns>
51+
public DeepMatrix Forward(DeepMatrix input1, DeepMatrix input2)
52+
{
53+
this.ExtendOperations();
54+
var matrixArray = new Matrix[input1.Depth];
55+
for (int i = 0; i < input1.Depth; i++)
56+
{
57+
this.Operations[i] = new MatrixMultiplyOperation();
58+
matrixArray[i] = this.Operations[i].Forward(input1[i], input2[i]);
59+
}
60+
61+
this.DeepOutput = new DeepMatrix(matrixArray);
62+
return this.DeepOutput;
63+
}
64+
65+
/// <summary>
66+
/// Performs the forward operation for the matrix multiply function.
67+
/// </summary>
68+
/// <param name="input1">The first input to the matrix multiply operation.</param>
69+
/// <param name="input2">The second input to the matrix multiply operation.</param>
70+
/// <returns>The output of the Hadamard product operation.</returns>
71+
public DeepMatrix Forward(DeepMatrix input1, Matrix input2)
72+
{
73+
this.ExtendOperations();
74+
var matrixArray = new Matrix[input1.Depth];
75+
for (int i = 0; i < input1.Depth; i++)
76+
{
77+
this.Operations[i] = new MatrixMultiplyOperation();
78+
matrixArray[i] = this.Operations[i].Forward(input1[i], input2);
79+
}
80+
81+
this.DeepOutput = new DeepMatrix(matrixArray);
82+
return this.DeepOutput;
83+
}
84+
85+
/// <inheritdoc />
86+
public override BackwardResult[] Backward(DeepMatrix dOutput)
87+
{
88+
var result = new BackwardResult[dOutput.Depth];
89+
Parallel.For(0, dOutput.Depth, i =>
90+
{
91+
result[i] = this.Operations[i].Backward(dOutput[i]);
92+
});
93+
return result;
94+
}
95+
}
96+
}

test/ParallelReverseAutoDiff.Test/GraphAttentionPaths/Transformer/Architecture/Transformer.json

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@
172172
},
173173
{
174174
"id": "fully_connected",
175-
"type": "BatchGpuMatrixMultiplyOperation",
175+
"type": "BatchMatrixMultiplyOperation",
176176
"inputs": [ "concatenated", "FW" ],
177177
"gradientResultTo": [ null, "DFW" ]
178178
},
@@ -184,7 +184,7 @@
184184
},
185185
{
186186
"id": "fully_connected_2",
187-
"type": "BatchGpuMatrixMultiplyOperation",
187+
"type": "BatchMatrixMultiplyOperation",
188188
"inputs": [ "concatenated", "F2W" ],
189189
"gradientResultTo": [ null, "DF2W" ]
190190
},

0 commit comments

Comments
 (0)