Skip to content

Commit eca4185

Browse files
committed
Merge branch '0.9.2' of https://github.com/iMoonLab/DeepHypergraph into 0.9.2
2 parents 126f844 + 661c0f0 commit eca4185

File tree

7 files changed

+91
-75
lines changed

7 files changed

+91
-75
lines changed

dhg/utils/structure.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def remap_edge_list(
6565
``bipartite_graph`` (``bool``): Whether the structure is bipartite graph. Defaults to ``False``.
6666
``ret_map`` (``bool``): Whether to return the map dictionary of raw marker to new index. Defaults to ``False``.
6767
"""
68+
e_list = [[str(v) for v in e] for e in e_list]
6869
if bipartite_graph:
6970
u_set, v_set = set(), set()
7071
for u, v in e_list:
@@ -87,7 +88,7 @@ def remap_edge_list(
8788
v_set.add(v)
8889
v_list = sorted(v_set)
8990
v_map = {raw_v: new_v for new_v, raw_v in enumerate(v_list)}
90-
e_list = [set([v_map[v] for v in e]) for e in e_list]
91+
e_list = [tuple([v_map[v] for v in e]) for e in e_list]
9192
if ret_map:
9293
return e_list, v_map
9394
else:
@@ -107,6 +108,7 @@ def remap_edge_lists(
107108
``bipartite_graph`` (``bool``): Whether the structure is bipartite graph. Defaults to ``False``.
108109
``ret_map`` (``bool``): Whether to return the map dictionary of raw marker to new index. Defaults to ``False``.
109110
"""
111+
e_lists = [[[str(v) for v in e] for e in e_list] for e_list in e_lists]
110112
if bipartite_graph:
111113
u_set, v_set = set(), set()
112114
for e_list in e_lists:
@@ -131,7 +133,7 @@ def remap_edge_lists(
131133
v_set.add(v)
132134
v_list = sorted(v_set)
133135
v_map = {raw_v: new_v for new_v, raw_v in enumerate(v_list)}
134-
e_list = [[set([v_map[v] for v in e]) for e in e_list] for e_list in e_lists]
136+
e_list = [[tuple([v_map[v] for v in e]) for e in e_list] for e_list in e_lists]
135137
if ret_map:
136138
return e_list, v_map
137139
else:
@@ -151,6 +153,7 @@ def remap_adj_list(
151153
``bipartite_graph`` (``bool``): Whether the structure is bipartite graph. Defaults to ``False``.
152154
``ret_map`` (``bool``): Whether to return the map dictionary of raw marker to new index. Defaults to ``False``.
153155
"""
156+
adj_list = [[str(v) for v in line] for line in adj_list]
154157
if bipartite_graph:
155158
u_set, v_set = set(), set()
156159
for line in adj_list:
@@ -202,6 +205,7 @@ def remap_adj_lists(
202205
``bipartite_graph`` (``bool``): Whether the structure is bipartite graph. Defaults to ``False``.
203206
``ret_map`` (``bool``): Whether to return the map dictionary of raw marker to new index. Defaults to ``False``.
204207
"""
208+
adj_lists = [[[str(v) for v in line] for line in adj_list] for adj_list in adj_lists]
205209
if bipartite_graph:
206210
u_set, v_set = set(), set()
207211
for adj_list in adj_lists:

tests/datapipe/test_loaders.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,12 @@
44

55

66
def test_load_from_txt(tmp_path):
7-
origin = [
8-
[0, 1, 2],
9-
[3, 4],
10-
[5, 6, 7, 8]
11-
]
12-
with open(tmp_path, 'w') as f:
7+
tmp_file_name = tmp_path / "test_load_from_txt.txt"
8+
origin = [[0, 1, 2], [3, 4], [5, 6, 7, 8]]
9+
with open(tmp_file_name, "w") as f:
1310
for ori in origin:
14-
f.write(' '.join(map(str, ori)) + '\n')
15-
data = load_from_txt(tmp_path)
11+
f.write(" ".join(map(str, ori)) + "\n")
12+
data = load_from_txt(tmp_file_name, "int")
1613
for ori, dat in zip(origin, data):
1714
for a, b in zip(ori, dat):
1815
assert a == b

tests/metrics/test_classification.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22

33
import torch
4+
import numpy as np
45
from sklearn.metrics import f1_score, confusion_matrix
56
import dhg.metrics.classification as dm
67

tests/structure/test_hypergraph.py

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -236,6 +236,7 @@ def test_add_hyperedges_from_bigraph():
236236
assert (0, 1) in h.e_of_group("bigraph-v")[0]
237237
assert (0, 2) in h.e_of_group("bigraph-v")[0]
238238

239+
239240
def test_remove_hyperedges(g1):
240241
assert g1.e[0] == [(0, 1, 2, 5), (0, 1), (2, 3, 4)]
241242
assert g1.e[1] == [1, 1, 1]
@@ -580,13 +581,12 @@ def test_smoothing():
580581
assert pytest.approx(g.smoothing(x, L, lbd)) == x + lbd * L @ x
581582

582583

583-
584584
def test_L_sym(g1):
585585
H = g1.H.to_dense().cpu()
586586
D_v_neg_1_2 = torch.diag(H.sum(dim=1).view(-1) ** (-0.5))
587587
D_e_neg_1 = torch.diag(H.sum(dim=0).view(-1) ** (-1))
588588
W_e = g1.W_e.to_dense()
589-
L_sym = torch.eye(H.shape[0]) - D_v_neg_1_2 @ H @ W_e @ D_e_neg_1 @ H.t() @ D_v_neg_1_2
589+
L_sym = torch.eye(H.shape[0]) - D_v_neg_1_2 @ H @ W_e @ D_e_neg_1 @ H.t() @ D_v_neg_1_2
590590
assert (L_sym == g1.L_sym.to_dense().cpu()).all()
591591

592592

@@ -597,32 +597,32 @@ def test_L_sym_group(g1):
597597
D_v_neg_1_2 = torch.diag(H.sum(dim=1).view(-1) ** (-0.5))
598598
D_e_neg_1 = torch.diag(H.sum(dim=0).view(-1) ** (-1))
599599
W_e = g1.W_e.to_dense()
600-
L_sym = torch.eye(H.shape[0]) - D_v_neg_1_2 @ H @ W_e @ D_e_neg_1 @ H.t() @ D_v_neg_1_2
600+
L_sym = torch.eye(H.shape[0]) - D_v_neg_1_2 @ H @ W_e @ D_e_neg_1 @ H.t() @ D_v_neg_1_2
601601
assert (L_sym == g1.L_sym.to_dense().cpu()).all()
602602
# main group
603603
H = g1.H_of_group("main").to_dense().cpu()
604604
D_v_neg_1_2 = torch.diag(H.sum(dim=1).view(-1) ** (-0.5))
605605
D_e_neg_1 = torch.diag(H.sum(dim=0).view(-1) ** (-1))
606606
W_e = g1.W_e_of_group("main").to_dense()
607-
L_sym = torch.eye(H.shape[0]) - D_v_neg_1_2 @ H @ W_e @ D_e_neg_1 @ H.t() @ D_v_neg_1_2
608-
assert (L_sym == g1.L_sym_of_group('main').to_dense().cpu()).all()
607+
L_sym = torch.eye(H.shape[0]) - D_v_neg_1_2 @ H @ W_e @ D_e_neg_1 @ H.t() @ D_v_neg_1_2
608+
assert (L_sym == g1.L_sym_of_group("main").to_dense().cpu()).all()
609609
# knn group
610610
H = g1.H_of_group("knn").to_dense().cpu()
611611
D_v_neg_1_2 = H.sum(dim=1).view(-1) ** (-0.5)
612612
D_v_neg_1_2[torch.isinf(D_v_neg_1_2)] = 0
613613
D_v_neg_1_2 = torch.diag(D_v_neg_1_2)
614614
D_e_neg_1 = torch.diag(H.sum(dim=0).view(-1) ** (-1))
615615
W_e = g1.W_e_of_group("knn").to_dense()
616-
L_sym = torch.eye(H.shape[0]) - D_v_neg_1_2 @ H @ W_e @ D_e_neg_1 @ H.t() @ D_v_neg_1_2
617-
assert (L_sym == g1.L_sym_of_group('knn').to_dense().cpu()).all()
616+
L_sym = torch.eye(H.shape[0]) - D_v_neg_1_2 @ H @ W_e @ D_e_neg_1 @ H.t() @ D_v_neg_1_2
617+
assert (L_sym == g1.L_sym_of_group("knn").to_dense().cpu()).all()
618618

619619

620620
def test_L_rw(g1):
621621
H = g1.H.to_dense().cpu()
622622
D_v_neg_1 = torch.diag(H.sum(dim=1).view(-1) ** (-1))
623623
D_e_neg_1 = torch.diag(H.sum(dim=0).view(-1) ** (-1))
624624
W_e = g1.W_e.to_dense()
625-
L_rw = torch.eye(H.shape[0]) - D_v_neg_1 @ H @ W_e @ D_e_neg_1 @ H.t()
625+
L_rw = torch.eye(H.shape[0]) - D_v_neg_1 @ H @ W_e @ D_e_neg_1 @ H.t()
626626
assert (L_rw == g1.L_rw.to_dense().cpu()).all()
627627

628628

@@ -633,24 +633,24 @@ def test_L_rw_group(g1):
633633
D_v_neg_1 = torch.diag(H.sum(dim=1).view(-1) ** (-1))
634634
D_e_neg_1 = torch.diag(H.sum(dim=0).view(-1) ** (-1))
635635
W_e = g1.W_e.to_dense()
636-
L_rw = torch.eye(H.shape[0]) - D_v_neg_1 @ H @ W_e @ D_e_neg_1 @ H.t()
636+
L_rw = torch.eye(H.shape[0]) - D_v_neg_1 @ H @ W_e @ D_e_neg_1 @ H.t()
637637
assert (L_rw == g1.L_rw.to_dense().cpu()).all()
638638
# main group
639639
H = g1.H_of_group("main").to_dense().cpu()
640640
D_v_neg_1 = torch.diag(H.sum(dim=1).view(-1) ** (-1))
641641
D_e_neg_1 = torch.diag(H.sum(dim=0).view(-1) ** (-1))
642642
W_e = g1.W_e_of_group("main").to_dense()
643-
L_rw = torch.eye(H.shape[0]) - D_v_neg_1 @ H @ W_e @ D_e_neg_1 @ H.t()
644-
assert (L_rw == g1.L_rw_of_group('main').to_dense().cpu()).all()
643+
L_rw = torch.eye(H.shape[0]) - D_v_neg_1 @ H @ W_e @ D_e_neg_1 @ H.t()
644+
assert (L_rw == g1.L_rw_of_group("main").to_dense().cpu()).all()
645645
# knn group
646646
H = g1.H_of_group("knn").to_dense().cpu()
647647
D_v_neg_1 = H.sum(dim=1).view(-1) ** (-1)
648648
D_v_neg_1[torch.isinf(D_v_neg_1)] = 0
649649
D_v_neg_1 = torch.diag(D_v_neg_1)
650650
D_e_neg_1 = torch.diag(H.sum(dim=0).view(-1) ** (-1))
651651
W_e = g1.W_e_of_group("knn").to_dense()
652-
L_rw = torch.eye(H.shape[0]) - D_v_neg_1 @ H @ W_e @ D_e_neg_1 @ H.t()
653-
assert (L_rw == g1.L_rw_of_group('knn').to_dense().cpu()).all()
652+
L_rw = torch.eye(H.shape[0]) - D_v_neg_1 @ H @ W_e @ D_e_neg_1 @ H.t()
653+
assert (L_rw == g1.L_rw_of_group("knn").to_dense().cpu()).all()
654654

655655

656656
def test_smoothing_with_HGNN(g1):

tests/utils/test_sparse.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66

77
def test_sparse_dropout():
8-
a = torch.rand(10, 20)
8+
a = (torch.rand(10, 20) > 0.7).float()
99

1010
idx = torch.nonzero(a).T
1111
data = a[idx[0], idx[1]]
@@ -15,7 +15,7 @@ def test_sparse_dropout():
1515

1616
assert coo.size() == dropped.size()
1717

18-
assert dropped._nnz() == pytest.approx(coo._nnz() * 0.7, 0.1)
18+
assert (dropped._values()!=0).sum() == pytest.approx(coo._nnz() * 0.7, 0.15)
1919

2020
for i in range(10):
2121
for j in range(20):

tests/utils/test_split.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -26,20 +26,34 @@ def test_split_by_ratio():
2626
v_label = np.random.randint(0, 10, n)
2727
r_train, r_val, r_test = 0.5, 0.2, 0.1
2828
m_train, m_val, m_test = split_by_ratio(n, v_label, r_train, r_val, r_test)
29-
assert m_train.sum() == 500
30-
assert m_val.sum() == 200
31-
assert m_test.sum() == 100
29+
assert pytest.approx(m_train.sum(), 5) == 500
30+
assert pytest.approx(m_val.sum(), 5) == 200
31+
assert pytest.approx(m_test.sum(), 5) == 100
3232

3333
m_train, m_val, m_test = split_by_ratio(n, v_label, r_train, r_val)
34-
assert m_train.sum() == 500
35-
assert m_val.sum() == 200
36-
assert m_test.sum() == 300
34+
assert pytest.approx(m_train.sum(), 5) == 500
35+
assert pytest.approx(m_val.sum(), 5) == 200
36+
assert pytest.approx(m_test.sum(), 5) == 300
3737

3838

3939
def test_split_by_num_for_UI_bigraph():
4040
e_list = [
41-
[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 1], [1, 2], [1, 3], [1, 4], [2, 2], [2, 3], [2, 4], [3, 3], [3, 4], [4, 4]
42-
]
41+
[0, 0],
42+
[0, 1],
43+
[0, 2],
44+
[0, 3],
45+
[0, 4],
46+
[1, 1],
47+
[1, 2],
48+
[1, 3],
49+
[1, 4],
50+
[2, 2],
51+
[2, 3],
52+
[2, 4],
53+
[3, 3],
54+
[3, 4],
55+
[4, 4],
56+
]
4357
g = dhg.BiGraph(5, 5, e_list)
4458
train_num = 3
4559
train_adj, test_adj = split_by_num_for_UI_bigraph(g, train_num)
@@ -54,20 +68,15 @@ def test_split_by_num_for_UI_bigraph():
5468
assert len(test_adj[1]) == 2
5569

5670

57-
5871
def test_split_by_ratio_for_UI_bigraph():
59-
e_list = [
60-
[0, 0], [0, 1], [0, 2], [0, 3], [0, 4], [1, 1], [1, 2], [1, 3], [1, 4], [2, 2], [2, 3], [2, 4], [3, 3], [3, 4], [4, 4]
61-
]
62-
g = dhg.BiGraph(5, 5, e_list)
72+
e_list = []
73+
for idx in range(100):
74+
e_list.append((0, idx))
75+
e_list.append((1, idx))
76+
g = dhg.BiGraph(2, 100, e_list)
6377
train_ratio = 0.6
6478
train_adj, test_adj = split_by_ratio_for_UI_bigraph(g, train_ratio)
65-
assert len(train_adj) == 5
66-
assert len(test_adj) == 3
67-
assert len(train_adj[0]) == 4
68-
assert len(train_adj[1]) == 3
69-
assert len(train_adj[2]) == 2
70-
assert len(train_adj[3]) == 2
71-
assert len(test_adj[0]) == 3
72-
assert len(test_adj[1]) == 2
73-
assert len(test_adj[2]) == 2
79+
assert (len(train_adj[0]) - 1) / len(g.nbr_v(0)) == pytest.approx(train_ratio, 0.1)
80+
assert (len(train_adj[1]) - 1) / len(g.nbr_v(1)) == pytest.approx(train_ratio, 0.1)
81+
assert (len(test_adj[0]) - 1) / len(g.nbr_v(0)) == pytest.approx(1 - train_ratio, 0.1)
82+
assert (len(test_adj[1]) - 1) / len(g.nbr_v(1)) == pytest.approx(1 - train_ratio, 0.1)

tests/utils/test_structure.py

Lines changed: 33 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -3,62 +3,67 @@
33

44

55
def test_remap_edge_list():
6-
e_list = [(1, 3), ('A', 100), (4.5, 'A')]
6+
e_list = [(1, 3), (100, "A"), (4.5, "A")]
77

88
e1, m = utils.remap_edge_list(e_list, ret_map=True)
99
for i, (u, v) in enumerate(e_list):
10-
assert e1[i][0] == m[u]
11-
assert e1[i][1] == m[v]
10+
assert e1[i][0] == m[str(u)]
11+
assert e1[i][1] == m[str(v)]
1212

13-
e2, m = utils.remap_edge_list(e_list, ret_map=True, bipartite_graph=True)
13+
e2, m_u, m_v = utils.remap_edge_list(e_list, ret_map=True, bipartite_graph=True)
1414
for i, (u, v) in enumerate(e_list):
15-
assert e2[i][0] == m[u]
16-
assert e2[i][1] == m[v]
15+
assert e2[i][0] == m_u[str(u)]
16+
assert e2[i][1] == m_v[str(v)]
1717

1818

1919
def test_remap_edge_lists():
20-
e_list = [[(1, 3), ('A', 100), (4.5, 'A')], [(1, 5), ('B', 101), (4.1, 'A')]]
20+
e_lists = [[(1, 3), (100, "A"), (4.5, "A")], [(1, 5), ("B", 101), (4.1, "A")]]
2121

22-
e1, m = utils.remap_edge_lists(e_list, ret_map=True)
23-
for i, ee in enumerate(e_list):
24-
for j, (u, v) in enumerate(ee):
25-
assert e1[i][j][0] == m[u]
26-
assert e1[i][j][1] == m[v]
22+
e1, m = utils.remap_edge_lists(*e_lists, ret_map=True)
23+
for i, e_list in enumerate(e_lists):
24+
for j, (u, v) in enumerate(e_list):
25+
assert e1[i][j][0] == m[str(u)]
26+
assert e1[i][j][1] == m[str(v)]
2727

28-
e2, m = utils.remap_edge_list(e_list, ret_map=True, bipartite_graph=True)
29-
for i, ee in enumerate(e_list):
30-
for j, (u, v) in enumerate(ee):
31-
assert e2[i][j][0] == m[u]
32-
assert e2[i][j][1] == m[v]
28+
e2, m_u, m_v = utils.remap_edge_lists(*e_lists, ret_map=True, bipartite_graph=True)
29+
for i, e_list in enumerate(e_lists):
30+
for j, (u, v) in enumerate(e_list):
31+
assert e2[i][j][0] == m_u[str(u)]
32+
assert e2[i][j][1] == m_v[str(v)]
3333

3434

3535
def test_remap_adj_list():
36-
adj_list = [[0, 'A', 1.5], ['A', 1, 2], [1.5, 1, 0]]
36+
adj_list = [[0, "A", 1.5], ["A", 1, 2], [1.5, 1, 0]]
3737
e1, m = utils.remap_adj_list(adj_list, ret_map=True)
3838
for i, adj in enumerate(adj_list):
3939
for j, a in enumerate(adj):
40-
assert e1[i][j] == m[a]
40+
assert e1[i][j] == m[str(a)]
4141

42-
e2, m = utils.remap_adj_list(adj_list, ret_map=True, bipartite_graph=True)
42+
e2, m_u, m_v = utils.remap_adj_list(adj_list, ret_map=True, bipartite_graph=True)
4343
for i, adj in enumerate(adj_list):
4444
for j, a in enumerate(adj):
45-
assert e2[i][j] == m[a]
45+
if j == 0:
46+
assert e2[i][j] == m_u[str(a)]
47+
else:
48+
assert e2[i][j] == m_v[str(a)]
4649

4750

4851
def test_remap_adj_lists():
49-
adj_lists = [[[0, 'A', 1.5], ['A', 1, 2], [1.5, 1, 0]], [[0, 3, 'A', 'B'], [1, 2, 'A', 'B'], ['A', 'B', 0, 1]]]
50-
e1, m = utils.remap_adj_lists(adj_lists, ret_map=True)
52+
adj_lists = [[[0, "A", 1.5], ["A", 1, 2], [1.5, 1, 0]], [[0, 3, "A", "B"], [1, 2, "A", "B"], ["A", "B", 0, 1]]]
53+
e1, m = utils.remap_adj_lists(*adj_lists, ret_map=True)
5154
for i, adj_list in enumerate(adj_lists):
5255
for j, adj in enumerate(adj_list):
5356
for k, a in enumerate(adj):
54-
assert e1[i][j][k] == m[a]
57+
assert e1[i][j][k] == m[str(a)]
5558

56-
e2, m = utils.remap_adj_list(adj_lists, ret_map=True, bipartite_graph=True)
59+
e2, m_u, m_v = utils.remap_adj_lists(*adj_lists, ret_map=True, bipartite_graph=True)
5760
for i, adj_list in enumerate(adj_lists):
5861
for j, adj in enumerate(adj_list):
5962
for k, a in enumerate(adj):
60-
assert e2[i][j][k] == m[a]
61-
63+
if k == 0:
64+
assert e2[i][j][k] == m_u[str(a)]
65+
else:
66+
assert e2[i][j][k] == m_v[str(a)]
6267

6368

6469
def test_edge_list_to_adj_list():
@@ -95,4 +100,4 @@ def test_adj_list_to_edge_list():
95100
assert (1, 3) in e_list
96101
assert (1, 4) in e_list
97102
assert (2, 3) in e_list
98-
assert len(e_list) == 7 # 8 ???
103+
assert len(e_list) == 8 # 8 ???

0 commit comments

Comments
 (0)