Skip to content

Commit 7f1df8c

Browse files
indexing tests
1 parent 249d4bf commit 7f1df8c

File tree

3 files changed

+64
-5
lines changed

3 files changed

+64
-5
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "MLDatasets"
22
uuid = "eb30cadb-4394-5ae3-aed4-317e484a6458"
3-
version = "0.5.10"
3+
version = "0.5.11"
44

55
[deps]
66
BinDeps = "9e28174c-4ba2-5203-b857-d8d62c4213ee"

src/TUDataset/TUDataset.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,7 @@ end
156156

157157
Base.getindex(data::TUDataset, i::Int) = getindex(data, [i])
158158

159-
function Base.getindex(data::TUDataset, i::Vector{Int})
159+
function Base.getindex(data::TUDataset, i::AbstractVector{Int})
160160
node_mask = data.graph_indicator .∈ Ref(i)
161161

162162
nodes = (1:data.num_nodes)[node_mask]

test/tst_tudataset.jl

Lines changed: 62 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ end
77

88
@test data.num_nodes == 43471
99
@test data.num_edges == 162088
10-
@test data.num_graphs === 1113
10+
@test data.num_graphs == 1113
1111

1212
@test length(data.source) == data.num_edges
1313
@test length(data.target) == data.num_edges
@@ -22,14 +22,38 @@ end
2222

2323
@test length(data.graph_indicator) == data.num_nodes
2424
@test all(sort(unique(data.graph_indicator)) .== 1:data.num_graphs)
25+
26+
@testset "indexing" begin
27+
d1, d2 = data[5], data[6]
28+
d12 = data[5:6]
29+
30+
@test d12.num_nodes == d1.num_nodes + d2.num_nodes
31+
@test d12.num_edges == d1.num_edges + d2.num_edges
32+
@test d12.num_graphs == d1.num_graphs + d2.num_graphs == 2
33+
34+
@test length(d12.source) == d12.num_edges
35+
@test length(d12.target) == d12.num_edges
36+
@test length(d1.source) == d1.num_edges
37+
@test length(d1.target) == d1.num_edges
38+
@test length(d2.source) == d2.num_edges
39+
@test length(d2.target) == d2.num_edges
40+
41+
@test length(d12.graph_indicator) == d12.num_nodes
42+
@test length(d1.graph_indicator) == d1.num_nodes
43+
@test length(d2.graph_indicator) == d2.num_nodes
44+
45+
@test all(d1.graph_indicator .== 1)
46+
@test all(d2.graph_indicator .== 1)
47+
@test sort(unique(d12.graph_indicator)) == [1,2]
48+
end
2549
end
2650

2751
@testset "TUDataset - QM9" begin
2852
data = TUDataset("QM9")
2953

3054
@test data.num_nodes == 2333625
3155
@test data.num_edges == 4823498
32-
@test data.num_graphs === 129433
56+
@test data.num_graphs == 129433
3357

3458
@test length(data.source) == data.num_edges
3559
@test length(data.target) == data.num_edges
@@ -43,5 +67,40 @@ end
4367
@test data.graph_labels === nothing
4468

4569
@test length(data.graph_indicator) == data.num_nodes
46-
@test all(sort(unique(data.graph_indicator)) .== 1:data.num_graphs)
70+
@test all(sort(unique(data.graph_indicator)) .== 1:data.num_graphs)
71+
72+
73+
@testset "indexing" begin
74+
d1, d2 = data[5], data[6]
75+
d12 = data[5:6]
76+
77+
@test d12.num_nodes == d1.num_nodes + d2.num_nodes
78+
@test d12.num_edges == d1.num_edges + d2.num_edges
79+
@test d12.num_graphs == d1.num_graphs + d2.num_graphs == 2
80+
81+
@test length(d12.source) == d12.num_edges
82+
@test length(d12.target) == d12.num_edges
83+
@test length(d1.source) == d1.num_edges
84+
@test length(d1.target) == d1.num_edges
85+
@test length(d2.source) == d2.num_edges
86+
@test length(d2.target) == d2.num_edges
87+
88+
@test length(d12.graph_indicator) == d12.num_nodes
89+
@test length(d1.graph_indicator) == d1.num_nodes
90+
@test length(d2.graph_indicator) == d2.num_nodes
91+
92+
@test all(d1.graph_indicator .== 1)
93+
@test all(d2.graph_indicator .== 1)
94+
@test sort(unique(d12.graph_indicator)) == [1,2]
95+
96+
@test size(d12.node_attributes) == (16, d12.num_nodes)
97+
@test size(d12.edge_attributes) == (4, d12.num_edges)
98+
@test size(d12.graph_attributes) == (19, d12.num_graphs)
99+
@test size(d1.node_attributes) == (16, d1.num_nodes)
100+
@test size(d1.edge_attributes) == (4, d1.num_edges)
101+
@test size(d1.graph_attributes) == (19, d1.num_graphs)
102+
@test size(d2.node_attributes) == (16, d2.num_nodes)
103+
@test size(d2.edge_attributes) == (4, d2.num_edges)
104+
@test size(d2.graph_attributes) == (19, d2.num_graphs)
105+
end
47106
end

0 commit comments

Comments
 (0)