Skip to content

Commit d4db332

Browse files
fix getgraph
1 parent d968d90 commit d4db332

File tree

1 file changed

+8
-9
lines changed

1 file changed

+8
-9
lines changed

src/gnngraph.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -541,25 +541,24 @@ function getgraph(g::GNNGraph, i::AbstractVector{Int}; nmap=false)
541541
graphmap = Dict(i => inew for (inew, i) in enumerate(i))
542542
graph_indicator = [graphmap[i] for i in g.graph_indicator[node_mask]]
543543

544+
s, t = edge_index(g)
545+
w = edge_weight(g)
546+
edge_mask = s .∈ Ref(nodes)
547+
544548
if g.graph isa COO_T
545-
s, t = edge_index(g)
546-
w = edge_weight(g)
547-
edge_mask = s .∈ Ref(nodes)
548549
s = [nodemap[i] for i in s[edge_mask]]
549550
t = [nodemap[i] for i in t[edge_mask]]
550551
w = isnothing(w) ? nothing : w[edge_mask]
551552
graph = (s, t, w)
552-
num_edges = length(s)
553-
edata = getobs(g.edata, edge_mask)
554553
elseif g.graph isa ADJMAT_T
555554
graph = g.graph[nodes, nodes]
556-
num_edges = count(>=(0), graph)
557-
@assert g.edata == (;) # TODO
558-
edata = (;)
559555
end
556+
560557
ndata = getobs(g.ndata, node_mask)
558+
edata = getobs(g.edata, edge_mask)
561559
gdata = getobs(g.gdata, i)
562-
560+
561+
num_edges = sum(edge_mask)
563562
num_nodes = length(graph_indicator)
564563
num_graphs = length(i)
565564

0 commit comments

Comments
 (0)