@@ -112,6 +112,129 @@ end
112
112
@test sum (count .([d. graph_data. train_mask, d. graph_data. test_mask, d. graph_data. val_mask])) == length (d)
113
113
end
114
114
115
@testset "ml-100k" begin
    # The dataset is expected to contain exactly one observation: a heterogeneous graph.
    data = MovieLens("100k")
    @test length(data) == 1

    g = data[1]
    # Indexing with `:` must return the same graph as indexing with 1.
    @test g == data[:]
    @test g isa MLDatasets.HeteroGraph

    # Expected node count for each node type.
    num_nodes = Dict(
        "movie" => 1682,
        "user" => 943,
    )
    # Expected edge count for each edge type (keyed by a 3-tuple of strings).
    num_edges = Dict(
        ("user", "rating", "movie") => 200000,
    )

    for type in keys(num_nodes)
        @test type ∈ g.node_types
        @test g.num_nodes[type] == num_nodes[type]
        # Every node type listed above must come with node features.
        node_data = get(g.node_data, type, nothing)
        @test !isnothing(node_data)
        for (key, val) in node_data
            @test key ∈ [:release_date, :genres, :age, :occupation, :zipcode, :gender]
            # The last dimension of each feature array must match the node count.
            @test size(val)[end] == num_nodes[type]
        end
    end

    for type in keys(num_edges)
        @test type ∈ g.edge_types
        @test g.num_edges[type] == num_edges[type]
        # Both components of the edge index must hold one entry per edge.
        @test length(g.edge_indices[type][1]) == num_edges[type]
        @test length(g.edge_indices[type][2]) == num_edges[type]
        edge_data = g.edge_data[type]
        for (key, val) in edge_data
            @test key in [:timestamp, :rating]
            # Edge features are expected to be vectors with one entry per edge.
            @test ndims(val) == 1
            @test size(val)[end] == num_edges[type]
        end
    end
end
155
+
156
@testset "ml-latest-small" begin
    # The dataset is expected to contain exactly one observation: a heterogeneous graph.
    data = MovieLens("latest-small")
    @test length(data) == 1

    g = data[1]
    # Indexing with `:` must return the same graph as indexing with 1.
    @test g == data[:]
    @test g isa MLDatasets.HeteroGraph

    # Expected node count for each node type.
    num_nodes = Dict(
        "tag" => 3683,
        "movie" => 9742,
        "user" => 610,
    )
    # Expected edge count for each edge type (keyed by a 3-tuple of strings).
    num_edges = Dict(
        ("user", "rating", "movie") => 201672,
        ("user", "tag", "movie") => 7366,
    )

    for type in keys(num_nodes)
        @test type ∈ g.node_types
        @test g.num_nodes[type] == num_nodes[type]
        # Node features are optional here: a node type without an entry in
        # `g.node_data` is skipped rather than treated as a failure.
        node_data = get(g.node_data, type, nothing)
        if !isnothing(node_data)
            for (key, val) in node_data
                # The last dimension of each feature array must match the node count.
                @test size(val)[end] == num_nodes[type]
            end
        end
    end

    for type in keys(num_edges)
        @test type ∈ g.edge_types
        @test g.num_edges[type] == num_edges[type]
        # Both components of the edge index must hold one entry per edge.
        @test length(g.edge_indices[type][1]) == num_edges[type]
        @test length(g.edge_indices[type][2]) == num_edges[type]
        edge_data = g.edge_data[type]
        for (key, val) in edge_data
            @test key in [:timestamp, :tag_name, :rating]
            # Edge features are expected to be vectors with one entry per edge.
            @test ndims(val) == 1
            @test size(val)[end] == num_edges[type]
        end
    end
end
196
+
197
@testset "ml-1m" begin
    # The dataset is expected to contain exactly one observation: a heterogeneous graph.
    data = MovieLens("1m")
    @test length(data) == 1

    g = data[1]
    # Indexing with `:` must return the same graph as indexing with 1.
    @test g == data[:]
    @test g isa MLDatasets.HeteroGraph

    # Expected node count for each node type.
    num_nodes = Dict(
        "movie" => 3883,
        "user" => 6040,
    )
    # Expected edge count for each edge type (keyed by a 3-tuple of strings).
    num_edges = Dict(
        ("user", "rating", "movie") => 2000418,
    )

    for type in keys(num_nodes)
        @test type ∈ g.node_types
        @test g.num_nodes[type] == num_nodes[type]
        # Every node type listed above must come with node features.
        node_data = get(g.node_data, type, nothing)
        @test !isnothing(node_data)
        for (key, val) in node_data
            @test key ∈ [:genres, :age, :occupation, :zipcode, :gender]
            # The last dimension of each feature array must match the node count.
            @test size(val)[end] == num_nodes[type]
        end
    end

    for type in keys(num_edges)
        @test type ∈ g.edge_types
        @test g.num_edges[type] == num_edges[type]
        # Both components of the edge index must hold one entry per edge.
        @test length(g.edge_indices[type][1]) == num_edges[type]
        @test length(g.edge_indices[type][2]) == num_edges[type]
        edge_data = g.edge_data[type]
        for (key, val) in edge_data
            @test key in [:timestamp, :rating]
            # Edge features are expected to be vectors with one entry per edge.
            @test ndims(val) == 1
            @test size(val)[end] == num_edges[type]
        end
    end
end
237
+
115
238
@testset " ml-10m" begin
116
239
data = MovieLens (" 10m" )
117
240
@test length (data) == 1
135
258
@test g. num_nodes[type] == num_nodes[type]
136
259
node_data = get (g. node_data, type, nothing )
137
260
isnothing (node_data) || for (key, val) in node_data
138
- @test size (val)[end ] == num_nodes[type]
261
+ @test size (val)[end ] == num_nodes[type]
262
+ end
139
263
end
140
264
141
265
for type in keys (num_edges)
0 commit comments