Skip to content

Commit c2bf5fe

Browse files
don't add reverse edges in movielens (#193)
* don't add reverse edges in movielens * cleanup
1 parent 90e996c commit c2bf5fe

File tree

3 files changed

+58
-58
lines changed

3 files changed

+58
-58
lines changed

src/datasets/graphs/movielens.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -429,9 +429,9 @@ function generate_movielens_graph(user_data::Dict, movie_data::Dict, rating_data
429429

430430
user_rates_movie = rating_data["user_movie"]
431431
user_ids, movie_ids = user_rates_movie[:, 1], user_rates_movie[:, 2]
432-
edge_indices = Dict(("user", "rating", "movie") => ([user_ids; movie_ids], [movie_ids; user_ids]))
432+
edge_indices = Dict(("user", "rating", "movie") => (user_ids, movie_ids))
433433

434-
edge_data = Dict(("user", "rating", "movie") => Dict(Symbol(k) => maybesqueeze([v;v]) for (k, v) in rating_data if k ["user_movie", "metadata"]))
434+
edge_data = Dict(("user", "rating", "movie") => Dict(Symbol(k) => maybesqueeze(v) for (k, v) in rating_data if k ["user_movie", "metadata"]))
435435

436436
return HeteroGraph(; num_nodes, edge_indices, node_data, edge_data)
437437
end
@@ -442,17 +442,17 @@ function generate_movielens_graph(movie_data::Dict, rating_data::Dict, user_tag_
442442
user_rates_movie = rating_data["user_movie"]
443443
user_ids, movie_ids = user_rates_movie[:, 1], user_rates_movie[:, 2]
444444
num_users = user_ids |> unique |> length # Calculate the number of users
445-
edge_indices[("user", "rating", "movie")] = ([user_ids; movie_ids], [movie_ids; user_ids])
445+
edge_indices[("user", "rating", "movie")] = (user_ids, movie_ids)
446446

447447
user_tags_movie = user_tag_data["user_movie"]
448448
user_ids, movie_ids = user_tags_movie[:, 1], user_tags_movie[:, 2]
449449
num_users = max(num_users, user_ids |> unique |> length)
450-
edge_indices[("user", "tag", "movie")] = ([user_ids; movie_ids], [movie_ids; user_ids])
450+
edge_indices[("user", "tag", "movie")] = (user_ids, movie_ids)
451451

452452
if !isempty(genome_tag_data)
453453
movie_score_tag = genome_tag_data["movie_tag"]
454454
movie_ids, tag_ids = movie_score_tag[:, 1], movie_score_tag[:, 1]
455-
edge_indices[("movie", "score", "tag")] = ([movie_ids; tag_ids], [movie_ids; tag_ids])
455+
edge_indices[("movie", "score", "tag")] = (movie_ids, tag_ids)
456456
end
457457

458458
# ideally the HeteroGraph function should be able to compute the number of egdes,
@@ -462,11 +462,11 @@ function generate_movielens_graph(movie_data::Dict, rating_data::Dict, user_tag_
462462

463463
_edge_data = Dict()
464464
_edge_data[("user", "rating", "movie")] = Dict(
465-
Symbol(k) => maybesqueeze([v;v]) for (k, v) in rating_data if k ["user_movie", "metadata"])
465+
Symbol(k) => maybesqueeze(v) for (k, v) in rating_data if k ["user_movie", "metadata"])
466466
_edge_data[("user", "tag", "movie")] = Dict(
467-
Symbol(k) => maybesqueeze([v;v]) for (k, v) in user_tag_data if k ["user_movie", "metadata"])
467+
Symbol(k) => maybesqueeze(v) for (k, v) in user_tag_data if k ["user_movie", "metadata"])
468468
isempty(genome_tag_data) || (_edge_data[("movie", "score", "tag")] = Dict(
469-
Symbol(k) => maybesqueeze([v;v]) for (k, v) in genome_tag_data if k ["movie_tag", "metadata", "num_tags"]))
469+
Symbol(k) => maybesqueeze(v) for (k, v) in genome_tag_data if k ["movie_tag", "metadata", "num_tags"]))
470470

471471
edge_data = Dict(k=>v for (k,v) in _edge_data if !isempty(v))
472472

test/datasets/graphs.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,3 +212,43 @@ end
212212
end
213213
end
214214

215+
@testset "ml-latest-small" begin
216+
data = MovieLens("latest-small")
217+
@test length(data) == 1
218+
219+
g = data[1]
220+
@test g == data[:]
221+
@test g isa MLDatasets.HeteroGraph
222+
223+
num_nodes = Dict(
224+
"tag" => 3683,
225+
"movie" => 9742,
226+
"user" => 610
227+
)
228+
num_edges = Dict(
229+
("user", "rating", "movie") => 100836,
230+
("user", "tag", "movie") => 3683
231+
)
232+
233+
for type in keys(num_nodes)
234+
@test type g.node_types
235+
@test g.num_nodes[type] == num_nodes[type]
236+
node_data = get(g.node_data, type, nothing)
237+
isnothing(node_data) || for (key, val) in node_data
238+
@test size(val)[end] == num_nodes[type]
239+
end
240+
end
241+
242+
for type in keys(num_edges)
243+
@test type g.edge_types
244+
@test g.num_edges[type] == num_edges[type]
245+
@test length(g.edge_indices[type][1]) == num_edges[type]
246+
@test length(g.edge_indices[type][2]) == num_edges[type]
247+
edge_data = g.edge_data[type]
248+
for (key, val) in edge_data
249+
@test key in [:timestamp, :tag_name, :rating]
250+
@test ndims(val) == 1
251+
@test size(val)[end] == num_edges[type]
252+
end
253+
end
254+
end

test/datasets/graphs_no_ci.jl

Lines changed: 10 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,44 +1,4 @@
11

2-
@testset "ml-latest-small" begin
3-
data = MovieLens("latest-small")
4-
@test length(data) == 1
5-
6-
g = data[1]
7-
@test g == data[:]
8-
@test g isa MLDatasets.HeteroGraph
9-
10-
num_nodes = Dict(
11-
"tag" => 3683,
12-
"movie" => 9742,
13-
"user" => 610
14-
)
15-
num_edges = Dict(
16-
("user", "rating", "movie") => 201672,
17-
("user", "tag", "movie") => 7366
18-
)
19-
20-
for type in keys(num_nodes)
21-
@test type g.node_types
22-
@test g.num_nodes[type] == num_nodes[type]
23-
node_data = get(g.node_data, type, nothing)
24-
isnothing(node_data) || for (key, val) in node_data
25-
@test size(val)[end] == num_nodes[type]
26-
end
27-
end
28-
29-
for type in keys(num_edges)
30-
@test type g.edge_types
31-
@test g.num_edges[type] == num_edges[type]
32-
@test length(g.edge_indices[type][1]) == num_edges[type]
33-
@test length(g.edge_indices[type][2]) == num_edges[type]
34-
edge_data = g.edge_data[type]
35-
for (key, val) in edge_data
36-
@test key in [:timestamp, :tag_name, :rating]
37-
@test ndims(val) == 1
38-
@test size(val)[end] == num_edges[type]
39-
end
40-
end
41-
end
422

433
@testset "ml-100k" begin
444
data = MovieLens("100k")
@@ -53,7 +13,7 @@ end
5313
"user" => 943,
5414
)
5515
num_edges = Dict(
56-
("user", "rating", "movie") => 200000
16+
("user", "rating", "movie") => 100000
5717
)
5818

5919
for type in keys(num_nodes)
@@ -94,7 +54,7 @@ end
9454
"user" => 6040
9555
)
9656
num_edges = Dict(
97-
("user", "rating", "movie") => 2000418
57+
("user", "rating", "movie") => 1000209
9858
)
9959

10060
for type in keys(num_nodes)
@@ -136,8 +96,8 @@ end
13696
"user" => 69878
13797
)
13898
num_edges = Dict(
139-
("user", "tag", "movie") => 191160,
140-
("user", "rating", "movie") => 20000108
99+
("user", "tag", "movie") => 95580,
100+
("user", "rating", "movie") => 10000054
141101
)
142102

143103
for type in keys(num_nodes)
@@ -177,9 +137,9 @@ end
177137
"user" => 138493
178138
)
179139
num_edges = Dict(
180-
("movie", "score", "tag") => 23419536,
181-
("user", "tag", "movie") => 931128,
182-
("user", "rating", "movie") => 40000526
140+
("movie", "score", "tag") => 11709768,
141+
("user", "tag", "movie") => 465564,
142+
("user", "rating", "movie") => 20000263
183143
)
184144

185145
for type in keys(num_nodes)
@@ -216,9 +176,9 @@ end
216176
"user" => 162541
217177
)
218178
num_edges = Dict(
219-
("movie", "score", "tag") => 31168896,
220-
("user", "tag", "movie") => 2186720,
221-
("user", "rating", "movie") => 50000190
179+
("movie", "score", "tag") => 15584448,
180+
("user", "tag", "movie") => 1093360,
181+
("user", "rating", "movie") => 25000095
222182
)
223183

224184
for type in keys(num_nodes)

0 commit comments

Comments
 (0)