Skip to content

Commit d1020e4

Browse files
authored
fixes add_edges + adds tests (#335)
* fixes add_edges + adds tests Fixes #334 * remove docs typo * fix typo in the test
1 parent a0797e5 commit d1020e4

File tree

2 files changed

+48
-2
lines changed

2 files changed

+48
-2
lines changed

src/GNNGraphs/transform.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ end
149149
add_edges(g::GNNHeteroGraph, edge_t, s, t; [edata, num_nodes])
150150
add_edges(g::GNNHeteroGraph, edge_t => (s, t); [edata, num_nodes])
151151
152-
Add to heterograph `g` the releation of type `edge_t` with source node vector `s` and target node vector `t`.
152+
Add to heterograph `g` the relation of type `edge_t` with source node vector `s` and target node vector `t`.
153153
Optionally, pass the features `edata` for the new edges.
154154
`edge_t` is a triplet of symbols `(srctype, etype, dsttype)`.
155155
@@ -192,7 +192,7 @@ function add_edges(g::GNNHeteroGraph{<:COO_T},
192192
if node_t ntypes
193193
push!(ntypes, node_t)
194194
if haskey(num_nodes, node_t)
195-
_num_nodes[node_t] == num_nodes[node_t]
195+
_num_nodes[node_t] = num_nodes[node_t]
196196
else
197197
_num_nodes[node_t] = maximum(st)
198198
end

test/GNNGraphs/gnnheterograph.jl

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,52 @@ end
116116
@test size(g[:B].y) == (d, 2*n)
117117
end
118118

119+
@testset "add_edges" begin
120+
d, n = 3, 5
121+
g = rand_bipartite_heterograph(n, 2 * n, 15)
122+
s, t = [1, 2, 3], [3, 2, 1]
123+
## Keep the same ntypes - construct with args
124+
g1 = add_edges(g, (:A, :rel1, :B), s, t)
125+
@test num_node_types(g1) == 2
126+
@test num_edge_types(g1) == 3
127+
for i in eachindex(s, t)
128+
@test has_edge(g1, (:A, :rel1, :B), s[i], t[i])
129+
end
130+
# no change to num_nodes
131+
@test g1.num_nodes[:A] == n
132+
@test g1.num_nodes[:B] == 2n
133+
134+
## Keep the same ntypes - construct with a pair
135+
g2 = add_edges(g, (:A, :rel1, :B) => (s, t))
136+
@test num_node_types(g2) == 2
137+
@test num_edge_types(g2) == 3
138+
for i in eachindex(s, t)
139+
@test has_edge(g2, (:A, :rel1, :B), s[i], t[i])
140+
end
141+
# no change to num_nodes
142+
@test g2.num_nodes[:A] == n
143+
@test g2.num_nodes[:B] == 2n
144+
145+
## New ntype with num_nodes (applies only to the new ntype) and edata
146+
edata = rand(Float32, d, length(s))
147+
g3 = add_edges(g,
148+
(:A, :rel1, :C) => (s, t);
149+
num_nodes = Dict(:A => 1, :B => 1, :C => 10),
150+
edata)
151+
@test num_node_types(g3) == 3
152+
@test num_edge_types(g3) == 3
153+
for i in eachindex(s, t)
154+
@test has_edge(g3, (:A, :rel1, :C), s[i], t[i])
155+
end
156+
# added edata
157+
@test g3.edata[(:A, :rel1, :C)].e == edata
158+
# no change to existing num_nodes
159+
@test g3.num_nodes[:A] == n
160+
@test g3.num_nodes[:B] == 2n
161+
# new num_nodes added as per kwarg
162+
@test g3.num_nodes[:C] == 10
163+
end
164+
119165
## Cannot test this because DataStore is not an ordered collection
120166
## Uncomment when/if it will be based on OrderedDict
121167
# @testset "show" begin

0 commit comments

Comments
 (0)