Skip to content

Commit 773f653

Browse files
committed
Update edge attention and readout
1 parent e254e01 commit 773f653

File tree

6 files changed

+12
-7
lines changed

6 files changed

+12
-7
lines changed

examples/gnn/ParallelReverseAutoDiff.GnnExample/GraphAttentionPaths/EdgeAttention/Architecture/EdgeAttention.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,8 @@
106106
{
107107
"id": "concatenated",
108108
"type": "BatchMatrixConcatenateOperation",
109-
"inputs": [ "attention_weights_values_array" ]
109+
"inputs": [ "attention_weights_values_array" ],
110+
"switchFirstTwoDimensions": true
110111
},
111112
{
112113
"id": "reduce",

examples/gnn/ParallelReverseAutoDiff.GnnExample/GraphAttentionPaths/Readout/Architecture/Readout.json

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,8 @@
154154
{
155155
"id": "concatenated",
156156
"type": "BatchMatrixConcatenateOperation",
157-
"inputs": [ "attention_weights_values_array" ]
157+
"inputs": [ "attention_weights_values_array" ],
158+
"switchFirstTwoDimensions": true
158159
},
159160
{
160161
"id": "fully_connected",

examples/gnn/ParallelReverseAutoDiff.GnnExample/ParallelReverseAutoDiff.GnnExample.csproj

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@
8484

8585
<ItemGroup>
8686
<PackageReference Include="Newtonsoft.Json" Version="13.0.3" />
87-
<PackageReference Include="ParallelReverseAutoDiff" Version="1.1.5" />
87+
<PackageReference Include="ParallelReverseAutoDiff" Version="1.1.6" />
8888
<PackageReference Include="StyleCop.Analyzers" Version="1.2.0-beta.435">
8989
<PrivateAssets>all</PrivateAssets>
9090
<IncludeAssets>runtime; build; native; contentfiles; analyzers; buildtransitive</IncludeAssets>

test/ParallelReverseAutoDiff.Test/GraphAttentionPaths/minibatch2.json

Lines changed: 1 addition & 0 deletions
Large diffs are not rendered by default.

test/ParallelReverseAutoDiff.Test/GraphAttentionPathsNeuralNetworkTest.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public async Task GivenGraphAttentionPathsNeuralNetworkMiniBatch_ProcessesMiniBa
1414
CudaBlas.Instance.Initialize();
1515
try
1616
{
17-
var json = EmbeddedResource.ReadAllJson("ParallelReverseAutoDiff.Test.GraphAttentionPaths", "minibatch");
17+
var json = EmbeddedResource.ReadAllJson("ParallelReverseAutoDiff.Test.GraphAttentionPaths", "minibatch2");
1818
var graphs = JsonConvert.DeserializeObject<List<GapGraph>>(json);
1919

2020
for (int i = 0; i < graphs.Count; ++i)
@@ -24,9 +24,9 @@ public async Task GivenGraphAttentionPathsNeuralNetworkMiniBatch_ProcessesMiniBa
2424

2525
int batchSize = 4;
2626

27-
GraphAttentionPathsNeuralNetwork neuralNetwork = new GraphAttentionPathsNeuralNetwork(graphs, batchSize, 16, 115, 10, 2, 4, 0.001d, 4d);
27+
GraphAttentionPathsNeuralNetwork neuralNetwork = new GraphAttentionPathsNeuralNetwork(graphs, batchSize, 16, 115, 5, 2, 4, 0.001d, 4d);
2828
await neuralNetwork.Initialize();
29-
DeepMatrix gradientOfLoss = await neuralNetwork.Forward();
29+
DeepMatrix gradientOfLoss = neuralNetwork.Forward();
3030
await neuralNetwork.Backward(gradientOfLoss);
3131
}
3232
finally
@@ -63,7 +63,7 @@ public async Task GivenGraphAttentionPathsNeuralNetwork_UsesCudaOperationsSucces
6363

6464
GraphAttentionPathsNeuralNetwork neuralNetwork = new GraphAttentionPathsNeuralNetwork(graphs, batchSize, 10, 100, 10, 2, 4, 0.001d, 4d);
6565
await neuralNetwork.Initialize();
66-
DeepMatrix gradientOfLoss = await neuralNetwork.Forward();
66+
DeepMatrix gradientOfLoss = neuralNetwork.Forward();
6767
await neuralNetwork.Backward(gradientOfLoss);
6868
}
6969
finally

test/ParallelReverseAutoDiff.Test/ParallelReverseAutoDiff.Test.csproj

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
<None Remove="GraphAttentionPaths\GCN\Architecture\MessagePassing.json" />
1818
<None Remove="GraphAttentionPaths\LSTM\Architecture\NodeProcessing.json" />
1919
<None Remove="GraphAttentionPaths\minibatch.json" />
20+
<None Remove="GraphAttentionPaths\minibatch2.json" />
2021
<None Remove="GraphAttentionPaths\Readout\Architecture\Readout.json" />
2122
<None Remove="GraphAttentionPaths\Transformer\Architecture\Transformer.json" />
2223
</ItemGroup>
@@ -44,6 +45,7 @@
4445
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
4546
</EmbeddedResource>
4647
<EmbeddedResource Include="GraphAttentionPaths\minibatch.json" />
48+
<EmbeddedResource Include="GraphAttentionPaths\minibatch2.json" />
4749
<EmbeddedResource Include="GraphAttentionPaths\Readout\Architecture\Readout.json">
4850
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
4951
</EmbeddedResource>

0 commit comments

Comments
 (0)