File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -154,14 +154,18 @@ function datadir_tudataset(name, dir = nothing)
154
154
return d
155
155
end
156
156
157
- function Base. getindex (data:: TUDataset , i)
157
+ Base. getindex (data:: TUDataset , i:: Int ) = getindex (data, [i])
158
+
159
+ function Base. getindex (data:: TUDataset , i:: Vector{Int} )
158
160
node_mask = data. graph_indicator .∈ Ref (i)
159
- graph_indicator = data. graph_indicator[node_mask]
160
161
161
162
nodes = (1 : data. num_nodes)[node_mask]
162
163
node_labels = isnothing (data. node_labels) ? nothing : data. node_labels[node_mask]
163
- nodemap = Dict (v => i for (i , v) in enumerate (nodes))
164
+ nodemap = Dict (v => vnew for (vnew , v) in enumerate (nodes))
164
165
166
+ graphmap = Dict (i => inew for (inew, i) in enumerate (i))
167
+ graph_indicator = [graphmap[i] for i in data. graph_indicator[node_mask]]
168
+
165
169
edge_mask = data. source .∈ Ref (nodes)
166
170
source = [nodemap[i] for i in data. source[edge_mask]]
167
171
target = [nodemap[i] for i in data. target[edge_mask]]
You can’t perform that action at this time.
0 commit comments