29
29
g123 = Flux. batch ([g1, g2, g3])
30
30
@test g123. graph_indicator == [fill (1 , 10 ); fill (2 , 4 ); fill (3 , 7 )]
31
31
32
+ # Allow wider eltype
33
+ g123 = Flux. batch (GNNGraph[g1, g2, g3])
34
+ @test g123. graph_indicator == [fill (1 , 10 ); fill (2 , 4 ); fill (3 , 7 )]
35
+
36
+
32
37
s, t = edge_index (g123)
33
38
@test s == [edge_index (g1)[1 ]; 10 .+ edge_index (g2)[1 ]; 14 .+ edge_index (g3)[1 ]]
34
39
@test t == [edge_index (g1)[2 ]; 10 .+ edge_index (g2)[2 ]; 14 .+ edge_index (g3)[2 ]]
141
146
gnew = add_edges (g, (snew, tnew, wnew))
142
147
@test get_edge_weight (gnew) == [w; wnew]
143
148
end
144
-
145
- @testset " heterograph" begin
146
- hg = rand_bipartite_heterograph ((2 , 2 ), (4 , 0 ), bidirected= false )
147
- hg = add_edges (hg, (:B ,:to ,:A ), [1 , 1 ], [1 ,2 ])
148
- @test hg. num_edges == Dict ((:A ,:to ,:B ) => 4 , (:B ,:to ,:A ) => 2 )
149
- @test has_edge (hg, (:B ,:to ,:A ), 1 , 1 )
150
- @test has_edge (hg, (:B ,:to ,:A ), 1 , 2 )
151
- @test ! has_edge (hg, (:B ,:to ,:A ), 2 , 1 )
152
- @test ! has_edge (hg, (:B ,:to ,:A ), 2 , 2 )
153
-
154
- @testset " new nodes" begin
155
- hg = rand_bipartite_heterograph ((2 , 2 ), 3 )
156
- hg = add_edges (hg, (:C ,:rel ,:B ) => ([1 , 3 ], [1 ,2 ]))
157
- @test hg. num_nodes == Dict (:A => 2 , :B => 2 , :C => 3 )
158
- @test hg. num_edges == Dict ((:A ,:to ,:B ) => 3 , (:B ,:to ,:A ) => 3 , (:C ,:rel ,:B ) => 2 )
159
- s, t = edge_index (hg, (:C ,:rel ,:B ))
160
- @test s == [1 , 3 ]
161
- @test t == [1 , 2 ]
162
-
163
- hg = add_edges (hg, (:D ,:rel ,:F ) => ([1 , 3 ], [1 ,2 ]))
164
- @test hg. num_nodes == Dict (:A => 2 , :B => 2 , :C => 3 , :D => 3 , :F => 2 )
165
- @test hg. num_edges == Dict ((:A ,:to ,:B ) => 3 , (:B ,:to ,:A ) => 3 , (:C ,:rel ,:B ) => 2 , (:D ,:rel ,:F ) => 2 )
166
- s, t = edge_index (hg, (:D ,:rel ,:F ))
167
- @test s == [1 , 3 ]
168
- @test t == [1 , 2 ]
169
- end
170
-
171
- @testset " also add weights" begin
172
- hg = GNNHeteroGraph ((:user , :rate , :movie ) => ([1 ,1 ,2 ,3 ], [7 ,13 ,5 ,7 ], [0.1 , 0.2 , 0.3 , 0.4 ]))
173
- hgnew = add_edges (hg, (:user , :like , :actor ) => ([1 , 2 ], [3 , 4 ], [0.5 , 0.6 ]))
174
- @test hgnew. num_nodes[:user ] == 3
175
- @test hgnew. num_nodes[:movie ] == 13
176
- @test hgnew. num_nodes[:actor ] == 4
177
- @test hgnew. num_edges == Dict ((:user , :rate , :movie ) => 4 , (:user , :like , :actor ) => 2 )
178
- @test get_edge_weight (hgnew, (:user , :rate , :movie )) == [0.1 , 0.2 , 0.3 , 0.4 ]
179
- @test get_edge_weight (hgnew, (:user , :like , :actor )) == [0.5 , 0.6 ]
180
-
181
- hgnew2 = add_edges (hgnew, (:user , :like , :actor ) => ([6 , 7 ], [8 , 10 ], [0.7 , 0.8 ]))
182
- @test hgnew2. num_nodes[:user ] == 7
183
- @test hgnew2. num_nodes[:movie ] == 13
184
- @test hgnew2. num_nodes[:actor ] == 10
185
- @test hgnew2. num_edges == Dict ((:user , :rate , :movie ) => 4 , (:user , :like , :actor ) => 4 )
186
- @test get_edge_weight (hgnew2, (:user , :rate , :movie )) == [0.1 , 0.2 , 0.3 , 0.4 ]
187
- @test get_edge_weight (hgnew2, (:user , :like , :actor )) == [0.5 , 0.6 , 0.7 , 0.8 ]
188
- end
189
- end
190
149
end
191
150
end
192
151
@@ -358,26 +317,81 @@ end
358
317
0.0 0.0 0.0 ]
359
318
end
360
319
361
- @testset " batch heterograph" begin
362
- gs = [rand_bipartite_heterograph ((10 , 15 ), 20 ) for _ in 1 : 5 ]
363
- g = Flux. batch (gs)
364
- @test g. num_nodes[:A ] == 50
365
- @test g. num_nodes[:B ] == 75
366
- @test g. num_edges[(:A ,:to ,:B )] == 100
367
- @test g. num_edges[(:B ,:to ,:A )] == 100
368
- @test g. num_graphs == 5
369
- @test g. graph_indicator == Dict (:A => vcat ([fill (i, 10 ) for i in 1 : 5 ]. .. ),
370
- :B => vcat ([fill (i, 15 ) for i in 1 : 5 ]. .. ))
371
-
372
- for gi in gs
373
- gi. ndata[:A ]. x = ones (2 , 10 )
374
- gi. ndata[:A ]. y = zeros (10 )
375
- gi. edata[(:A ,:to ,:B )]. e = fill (2 , 20 )
376
- gi. gdata. u = 7
320
+ @testset " HeteroGraphs" begin
321
+ @testset " batch" begin
322
+ gs = [rand_bipartite_heterograph ((10 , 15 ), 20 ) for _ in 1 : 5 ]
323
+ g = Flux. batch (gs)
324
+ @test g. num_nodes[:A ] == 50
325
+ @test g. num_nodes[:B ] == 75
326
+ @test g. num_edges[(:A ,:to ,:B )] == 100
327
+ @test g. num_edges[(:B ,:to ,:A )] == 100
328
+ @test g. num_graphs == 5
329
+ @test g. graph_indicator == Dict (:A => vcat ([fill (i, 10 ) for i in 1 : 5 ]. .. ),
330
+ :B => vcat ([fill (i, 15 ) for i in 1 : 5 ]. .. ))
331
+
332
+ for gi in gs
333
+ gi. ndata[:A ]. x = ones (2 , 10 )
334
+ gi. ndata[:A ]. y = zeros (10 )
335
+ gi. edata[(:A ,:to ,:B )]. e = fill (2 , 20 )
336
+ gi. gdata. u = 7
337
+ end
338
+ g = Flux. batch (gs)
339
+ @test g. ndata[:A ]. x == ones (2 , 50 )
340
+ @test g. ndata[:A ]. y == zeros (50 )
341
+ @test g. edata[(:A ,:to ,:B )]. e == fill (2 , 100 )
342
+ @test g. gdata. u == fill (7 , 5 )
343
+
344
+ # Allow for wider eltype
345
+ g = Flux. batch (GNNHeteroGraph[g for g in gs])
346
+ @test g. ndata[:A ]. x == ones (2 , 50 )
347
+ @test g. ndata[:A ]. y == zeros (50 )
348
+ @test g. edata[(:A ,:to ,:B )]. e == fill (2 , 100 )
349
+ @test g. gdata. u == fill (7 , 5 )
350
+ end
351
+
352
+ @testset " add_edges" begin
353
+ hg = rand_bipartite_heterograph ((2 , 2 ), (4 , 0 ), bidirected= false )
354
+ hg = add_edges (hg, (:B ,:to ,:A ), [1 , 1 ], [1 ,2 ])
355
+ @test hg. num_edges == Dict ((:A ,:to ,:B ) => 4 , (:B ,:to ,:A ) => 2 )
356
+ @test has_edge (hg, (:B ,:to ,:A ), 1 , 1 )
357
+ @test has_edge (hg, (:B ,:to ,:A ), 1 , 2 )
358
+ @test ! has_edge (hg, (:B ,:to ,:A ), 2 , 1 )
359
+ @test ! has_edge (hg, (:B ,:to ,:A ), 2 , 2 )
360
+
361
+ @testset " new nodes" begin
362
+ hg = rand_bipartite_heterograph ((2 , 2 ), 3 )
363
+ hg = add_edges (hg, (:C ,:rel ,:B ) => ([1 , 3 ], [1 ,2 ]))
364
+ @test hg. num_nodes == Dict (:A => 2 , :B => 2 , :C => 3 )
365
+ @test hg. num_edges == Dict ((:A ,:to ,:B ) => 3 , (:B ,:to ,:A ) => 3 , (:C ,:rel ,:B ) => 2 )
366
+ s, t = edge_index (hg, (:C ,:rel ,:B ))
367
+ @test s == [1 , 3 ]
368
+ @test t == [1 , 2 ]
369
+
370
+ hg = add_edges (hg, (:D ,:rel ,:F ) => ([1 , 3 ], [1 ,2 ]))
371
+ @test hg. num_nodes == Dict (:A => 2 , :B => 2 , :C => 3 , :D => 3 , :F => 2 )
372
+ @test hg. num_edges == Dict ((:A ,:to ,:B ) => 3 , (:B ,:to ,:A ) => 3 , (:C ,:rel ,:B ) => 2 , (:D ,:rel ,:F ) => 2 )
373
+ s, t = edge_index (hg, (:D ,:rel ,:F ))
374
+ @test s == [1 , 3 ]
375
+ @test t == [1 , 2 ]
376
+ end
377
+
378
+ @testset " also add weights" begin
379
+ hg = GNNHeteroGraph ((:user , :rate , :movie ) => ([1 ,1 ,2 ,3 ], [7 ,13 ,5 ,7 ], [0.1 , 0.2 , 0.3 , 0.4 ]))
380
+ hgnew = add_edges (hg, (:user , :like , :actor ) => ([1 , 2 ], [3 , 4 ], [0.5 , 0.6 ]))
381
+ @test hgnew. num_nodes[:user ] == 3
382
+ @test hgnew. num_nodes[:movie ] == 13
383
+ @test hgnew. num_nodes[:actor ] == 4
384
+ @test hgnew. num_edges == Dict ((:user , :rate , :movie ) => 4 , (:user , :like , :actor ) => 2 )
385
+ @test get_edge_weight (hgnew, (:user , :rate , :movie )) == [0.1 , 0.2 , 0.3 , 0.4 ]
386
+ @test get_edge_weight (hgnew, (:user , :like , :actor )) == [0.5 , 0.6 ]
387
+
388
+ hgnew2 = add_edges (hgnew, (:user , :like , :actor ) => ([6 , 7 ], [8 , 10 ], [0.7 , 0.8 ]))
389
+ @test hgnew2. num_nodes[:user ] == 7
390
+ @test hgnew2. num_nodes[:movie ] == 13
391
+ @test hgnew2. num_nodes[:actor ] == 10
392
+ @test hgnew2. num_edges == Dict ((:user , :rate , :movie ) => 4 , (:user , :like , :actor ) => 4 )
393
+ @test get_edge_weight (hgnew2, (:user , :rate , :movie )) == [0.1 , 0.2 , 0.3 , 0.4 ]
394
+ @test get_edge_weight (hgnew2, (:user , :like , :actor )) == [0.5 , 0.6 , 0.7 , 0.8 ]
395
+ end
377
396
end
378
- g = Flux. batch (gs)
379
- @test g. ndata[:A ]. x == ones (2 , 50 )
380
- @test g. ndata[:A ]. y == zeros (50 )
381
- @test g. edata[(:A ,:to ,:B )]. e == fill (2 , 100 )
382
- @test g. gdata. u == fill (7 , 5 )
383
397
end
0 commit comments