Skip to content

Commit 249d4bf

Browse files
fix graph indicator when indexing
1 parent b044190 commit 249d4bf

File tree

1 file changed

+7
-3
lines changed

1 file changed

+7
-3
lines changed

src/TUDataset/TUDataset.jl

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,14 +154,18 @@ function datadir_tudataset(name, dir = nothing)
154154
return d
155155
end
156156

157-
function Base.getindex(data::TUDataset, i)
157+
Base.getindex(data::TUDataset, i::Int) = getindex(data, [i])
158+
159+
function Base.getindex(data::TUDataset, i::Vector{Int})
158160
node_mask = data.graph_indicator .∈ Ref(i)
159-
graph_indicator = data.graph_indicator[node_mask]
160161

161162
nodes = (1:data.num_nodes)[node_mask]
162163
node_labels = isnothing(data.node_labels) ? nothing : data.node_labels[node_mask]
163-
nodemap = Dict(v => i for (i, v) in enumerate(nodes))
164+
nodemap = Dict(v => vnew for (vnew, v) in enumerate(nodes))
164165

166+
graphmap = Dict(i => inew for (inew, i) in enumerate(i))
167+
graph_indicator = [graphmap[i] for i in data.graph_indicator[node_mask]]
168+
165169
edge_mask = data.source .∈ Ref(nodes)
166170
source = [nodemap[i] for i in data.source[edge_mask]]
167171
target = [nodemap[i] for i in data.target[edge_mask]]

0 commit comments

Comments
 (0)