Skip to content

Commit dac5a8e

Browse files
authored
Support multiple labels for nodes, edges and graphs (#196)
* Support multiple labels for nodes, edges and graphs * Use as feature matrix and not as vector * Add test for Cuneiform
1 parent b9a6f71 commit dac5a8e

File tree

2 files changed

+37
-3
lines changed

2 files changed

+37
-3
lines changed

src/datasets/graphs/tudataset.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -72,13 +72,13 @@ function TUDataset(name; dir=nothing)
7272
# LOAD OPTIONAL FILES IF EXIST
7373

7474
node_labels = isfile(joinpath(d, "$(name)_node_labels.txt")) ?
75-
readdlm(joinpath(d, "$(name)_node_labels.txt"), Int) |> vec :
75+
readdlm(joinpath(d, "$(name)_node_labels.txt"), ',', Int)' |> collect |> maybesqueeze :
7676
nothing
7777
edge_labels = isfile(joinpath(d, "$(name)_edge_labels.txt")) ?
78-
readdlm(joinpath(d, "$(name)_edge_labels.txt"), Int) |> vec :
78+
readdlm(joinpath(d, "$(name)_edge_labels.txt"), ',', Int)' |> collect |> maybesqueeze :
7979
nothing
8080
graph_labels = isfile(joinpath(d, "$(name)_graph_labels.txt")) ?
81-
readdlm(joinpath(d, "$(name)_graph_labels.txt"), Int) |> vec :
81+
readdlm(joinpath(d, "$(name)_graph_labels.txt"), ',', Int)' |> collect |> maybesqueeze :
8282
nothing
8383

8484
node_attributes = isfile(joinpath(d, "$(name)_node_attributes.txt")) ?

test/datasets/graphs.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -252,3 +252,37 @@ end
252252
end
253253
end
254254
end
255+
256+
@testset "TUDataset - Cuneiform" begin
257+
data = TUDataset("Cuneiform")
258+
259+
@test data.num_nodes == 5680
260+
@test data.num_edges == 23922
261+
@test data.num_graphs == 267
262+
263+
@test data.num_nodes == sum(g->g.num_nodes, data.graphs)
264+
@test data.num_edges == sum(g->g.num_edges, data.graphs)
265+
@test data.num_edges == sum(g->length(g.edge_index[1]), data.graphs)
266+
@test data.num_edges == sum(g->length(g.edge_index[2]), data.graphs)
267+
@test data.num_graphs == length(data) == length(data.graphs)
268+
269+
i = rand(1:length(data))
270+
di = data[i]
271+
@test di isa NamedTuple
272+
g, targets = di.graphs, di.targets
273+
@test targets isa Int
274+
@test g isa Graph
275+
@test all(1 .<= g.edge_index[1] .<= g.num_nodes)
276+
@test all(1 .<= g.edge_index[2] .<= g.num_nodes)
277+
278+
# graph data
279+
@test size(data.graph_data.targets) == (data.num_graphs, )
280+
281+
# node data
282+
@test size(g.node_data.features) == (3, g.num_nodes)
283+
@test size(g.node_data.targets) == (2, g.num_nodes)
284+
285+
# edge data
286+
@test size(g.edge_data.features) == (2, g.num_edges)
287+
@test size(g.edge_data.targets) == (g.num_edges, )
288+
end

0 commit comments

Comments
 (0)