@@ -270,12 +270,12 @@ function SparseArrays.blockdiag(g1::GNNGraph, g2::GNNGraph)
270
270
t = vcat (t1, nv1 .+ t2)
271
271
w = cat_features (get_edge_weight (g1), get_edge_weight (g2))
272
272
graph = (s, t, w)
273
- ind1 = isnothing (g1. graph_indicator) ? ones_like (s1, Int, nv1) : g1. graph_indicator
274
- ind2 = isnothing (g2. graph_indicator) ? ones_like (s2, Int, nv2) : g2. graph_indicator
273
+ ind1 = isnothing (g1. graph_indicator) ? ones_like (s1, nv1) : g1. graph_indicator
274
+ ind2 = isnothing (g2. graph_indicator) ? ones_like (s2, nv2) : g2. graph_indicator
275
275
elseif g1. graph isa ADJMAT_T
276
276
graph = blockdiag (g1. graph, g2. graph)
277
- ind1 = isnothing (g1. graph_indicator) ? ones_like (graph, Int, nv1) : g1. graph_indicator
278
- ind2 = isnothing (g2. graph_indicator) ? ones_like (graph, Int, nv2) : g2. graph_indicator
277
+ ind1 = isnothing (g1. graph_indicator) ? ones_like (graph, nv1) : g1. graph_indicator
278
+ ind2 = isnothing (g2. graph_indicator) ? ones_like (graph, nv2) : g2. graph_indicator
279
279
end
280
280
graph_indicator = vcat (ind1, g1. num_graphs .+ ind2)
281
281
@@ -358,27 +358,31 @@ julia> g12.ndata.x
358
358
1.0 1.0 1.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0
359
359
```
360
360
"""
361
- function Flux. batch (gs:: Vector{<:GNNGraph} )
362
- nodes = [g. num_nodes for g in gs]
363
-
364
- if all (y -> isa (y, COO_T), [g. graph for g in gs] )
365
- edge_indices = [edge_index (g) for g in gs]
366
- nodesum = cumsum ([0 , nodes... ])[1 : end - 1 ]
367
- s = cat_features ([ei[1 ] .+ nodesum[ii] for (ii, ei) in enumerate (edge_indices)])
368
- t = cat_features ([ei[2 ] .+ nodesum[ii] for (ii, ei) in enumerate (edge_indices)])
369
- w = reduce (vcat, [get_edge_weight (g) for g in gs])
370
- w = w isa Vector{Nothing} ? nothing : w
371
- graph = (s, t, w)
372
- graph_indicator = vcat ([ones_like (ei[1 ],Int,nodes[ii]) .+ (ii - 1 ) for (ii,ei) in enumerate (edge_indices)]. .. )
373
- elseif all (y -> isa (y, ADJMAT_T), [g. graph for g in gs] )
374
- graph = blockdiag ([g. graph for g in gs]. .. )
375
- graph_indicator = vcat ([ones_like (graph,Int,nodes[ii]) .+ (ii - 1 ) for ii in 1 : length (nodes)]. .. )
361
+ Flux. batch (gs:: AbstractVector{<:GNNGraph} ) = blockdiag (gs... )
362
+
363
+ function Flux. batch (gs:: AbstractVector{<:GNNGraph{T}} ) where T<: COO_T
364
+ v_num_nodes = [g. num_nodes for g in gs]
365
+ edge_indices = [edge_index (g) for g in gs]
366
+ nodesum = cumsum ([0 ; v_num_nodes])[1 : end - 1 ]
367
+ s = cat_features ([ei[1 ] .+ nodesum[ii] for (ii, ei) in enumerate (edge_indices)])
368
+ t = cat_features ([ei[2 ] .+ nodesum[ii] for (ii, ei) in enumerate (edge_indices)])
369
+ w = cat_features ([get_edge_weight (g) for g in gs])
370
+ graph = (s, t, w)
371
+
372
+ function materialize_graph_indicator (g)
373
+ g. graph_indicator === nothing ? ones_like (s, g. num_nodes) : g. graph_indicator
376
374
end
375
+
376
+ v_gi = materialize_graph_indicator .(gs)
377
+ v_num_graphs = [g. num_graphs for g in gs]
378
+ graphsum = cumsum ([0 ; v_num_graphs])[1 : end - 1 ]
379
+ v_gi = [ng .+ gi for (ng, gi) in zip (graphsum, v_gi)]
380
+ graph_indicator = cat_features (v_gi)
377
381
378
382
GNNGraph (graph,
379
- sum (nodes ),
383
+ sum (v_num_nodes ),
380
384
sum ([g. num_edges for g in gs]),
381
- sum ([g . num_graphs for g in gs] ),
385
+ sum (v_num_graphs ),
382
386
graph_indicator,
383
387
cat_features ([g. ndata for g in gs]),
384
388
cat_features ([g. edata for g in gs]),
0 commit comments