File tree Expand file tree Collapse file tree 1 file changed +8
-9
lines changed Expand file tree Collapse file tree 1 file changed +8
-9
lines changed Original file line number Diff line number Diff line change @@ -541,25 +541,24 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
541
541
graphmap = Dict (i => inew for (inew, i) in enumerate (i))
542
542
graph_indicator = [graphmap[i] for i in g. graph_indicator[node_mask]]
543
543
544
+ s, t = edge_index (g)
545
+ w = edge_weight (g)
546
+ edge_mask = s .∈ Ref (nodes)
547
+
544
548
if g. graph isa COO_T
545
- s, t = edge_index (g)
546
- w = edge_weight (g)
547
- edge_mask = s .∈ Ref (nodes)
548
549
s = [nodemap[i] for i in s[edge_mask]]
549
550
t = [nodemap[i] for i in t[edge_mask]]
550
551
w = isnothing (w) ? nothing : w[edge_mask]
551
552
graph = (s, t, w)
552
- num_edges = length (s)
553
- edata = getobs (g. edata, edge_mask)
554
553
elseif g. graph isa ADJMAT_T
555
554
graph = g. graph[nodes, nodes]
556
- num_edges = count (>= (0 ), graph)
557
- @assert g. edata == (;) # TODO
558
- edata = (;)
559
555
end
556
+
560
557
ndata = getobs (g. ndata, node_mask)
558
+ edata = getobs (g. edata, edge_mask)
561
559
gdata = getobs (g. gdata, i)
562
-
560
+
561
+ num_edges = sum (edge_mask)
563
562
num_nodes = length (graph_indicator)
564
563
num_graphs = length (i)
565
564
You can’t perform that action at this time.
0 commit comments