|
33 | 33 | (:B, :to, :A) => GraphConv(64 => 32, relu));
|
34 | 34 | @test length(layer.etypes) == 2
|
35 | 35 | end
|
| 36 | + |
| 37 | + @testset "Destination node aggregation" begin |
| 38 | + # deterministic setup to validate the aggregation |
| 39 | + d, n = 3, 5 |
| 40 | + g = GNNHeteroGraph(((:A, :to, :B) => ([1, 1, 2, 3], [1, 2, 2, 3]), |
| 41 | + (:B, :to, :A) => ([1, 1, 2, 3], [1, 2, 2, 3]), |
| 42 | + (:C, :to, :A) => ([1, 1, 2, 3], [1, 2, 2, 3])); num_nodes = Dict(:A => n, :B => n, :C => n)) |
| 43 | + model = HeteroGraphConv([ |
| 44 | + (:A, :to, :B) => GraphConv(d => d, init = ones, bias = false), |
| 45 | + (:B, :to, :A) => GraphConv(d => d, init = ones, bias = false), |
| 46 | + (:C, :to, :A) => GraphConv(d => d, init = ones, bias = false)]; aggr = +) |
| 47 | + x = (A = rand(Float32, d, n), B = rand(Float32, d, n), C = rand(Float32, d, n)) |
| 48 | + y = model(g, x) |
| 49 | + weights = ones(Float32, d, d) |
| 50 | + |
| 51 | + ### Test default summation aggregation |
| 52 | + # B2 has 2 edges from A and itself (sense check) |
| 53 | + expected = sum(weights * x.A[:, [1, 2]]; dims = 2) .+ weights * x.B[:, [2]] |
| 54 | + output = y.B[:, [2]] |
| 55 | + @test expected ≈ output |
| 56 | + |
| 57 | + # B5 has only itself |
| 58 | + @test weights * x.B[:, [5]] ≈ y.B[:, [5]] |
| 59 | + |
| 60 | + # A1 has 1 edge from B, 1 from C and twice itself |
| 61 | + expected = sum(weights * x.B[:, [1]] + weights * x.C[:, [1]]; dims = 2) .+ |
| 62 | + 2 * weights * x.A[:, [1]] |
| 63 | + output = y.A[:, [1]] |
| 64 | + @test expected ≈ output |
| 65 | + |
| 66 | + # A2 has 2 edges from B, 2 from C and twice itself |
| 67 | + expected = sum(weights * x.B[:, [1, 2]] + weights * x.C[:, [1, 2]]; dims = 2) .+ |
| 68 | + 2 * weights * x.A[:, [2]] |
| 69 | + output = y.A[:, [2]] |
| 70 | + @test expected ≈ output |
| 71 | + |
| 72 | + # A5 has only itself but twice |
| 73 | + @test 2 * weights * x.A[:, [5]] ≈ y.A[:, [5]] |
| 74 | + |
| 75 | + #### Test different aggregation function |
| 76 | + model2 = HeteroGraphConv([ |
| 77 | + (:A, :to, :B) => GraphConv(d => d, init = ones, bias = false), |
| 78 | + (:B, :to, :A) => GraphConv(d => d, init = ones, bias = false), |
| 79 | + (:C, :to, :A) => GraphConv(d => d, init = ones, bias = false)]; aggr = -) |
| 80 | + y2 = model2(g, x) |
| 81 | + # B no change |
| 82 | + @test y.B ≈ y2.B |
| 83 | + |
| 84 | + # A1 has 1 edge from B, 1 from C, itself cancels out |
| 85 | + expected = sum(weights * x.B[:, [1]] - weights * x.C[:, [1]]; dims = 2) |
| 86 | + output = y2.A[:, [1]] |
| 87 | + @test expected ≈ output |
| 88 | + |
| 89 | + # A2 has 2 edges from B, 2 from C, itself cancels out |
| 90 | + expected = sum(weights * x.B[:, [1, 2]] - weights * x.C[:, [1, 2]]; dims = 2) |
| 91 | + output = y2.A[:, [2]] |
| 92 | + @test expected ≈ output |
| 93 | + end |
36 | 94 | end
|
0 commit comments