Skip to content

Commit bd8a8dd

Browse files
committed
fix group_name and test bugs
1 parent be220d8 commit bd8a8dd

File tree

3 files changed

+21
-22
lines changed

3 files changed

+21
-22
lines changed

dhg/structure/hypergraphs/hypergraph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1488,7 +1488,7 @@ def e2v_aggregation_of_group(
14881488
P = self.H_of_group(group_name)
14891489
if aggr == "mean":
14901490
X = torch.sparse.mm(P, X)
1491-
X = torch.sparse.mm(self.D_v_neg_1_of_group[group_name], X)
1491+
X = torch.sparse.mm(self.D_v_neg_1_of_group(group_name), X)
14921492
elif aggr == "sum":
14931493
X = torch.sparse.mm(P, X)
14941494
elif aggr == "softmax_then_sum":
@@ -1499,12 +1499,12 @@ def e2v_aggregation_of_group(
14991499
else:
15001500
# init message path
15011501
assert (
1502-
e2v_weight.shape[0] == self.e2v_weight_of_group[group_name].shape[0]
1502+
e2v_weight.shape[0] == self.e2v_weight_of_group(group_name).shape[0]
15031503
), f"The size of e2v_weight must be equal to the size of self.e2v_weight_of_group('{group_name}')."
15041504
P = torch.sparse_coo_tensor(
1505-
self.H_of_group[group_name]._indices(),
1505+
self.H_of_group(group_name)._indices(),
15061506
e2v_weight,
1507-
self.H_of_group[group_name].shape,
1507+
self.H_of_group(group_name).shape,
15081508
device=self.device,
15091509
)
15101510
if drop_rate > 0.0:

tests/structure/test_graph.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,8 @@ def test_laplacian_symmetric():
292292
for _ in range(num_e):
293293
s = random.randrange(num_v)
294294
d = random.randrange(num_v)
295+
if s == d:
296+
continue
295297
g.add_edges((s, d))
296298
A[s, d] = 1
297299
A[d, s] = 1
@@ -318,6 +320,8 @@ def test_laplacian_random_walk():
318320
for _ in range(num_e):
319321
s = random.randrange(num_v)
320322
d = random.randrange(num_v)
323+
if s == d:
324+
continue
321325
g.add_edges((s, d))
322326
A[s, d] = 1
323327
A[d, s] = 1

tests/structure/test_hypergraph.py

Lines changed: 13 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -217,24 +217,19 @@ def test_add_hyperedges_from_graph_kHop(g1):
217217

218218

219219
def test_add_hyperedges_from_bigraph():
220-
g = BiGraph(3, 4, [[0, 0], [1, 0], [2, 0], [3, 0], [0, 1], [1, 1], [2, 2], [3, 2]])
221-
h = Hypergraph(6)
222-
h.add_hyperedges_from_bigraph(g, group_name="bigraph")
223-
assert h.num_e == 4
224-
assert (0, 1, 2, 3) in h.e_of_group("bigraph")[0]
225-
assert (0, 1) in h.e_of_group("bigraph")[0]
226-
assert (2, 3) in h.e_of_group("bigraph")[0]
227-
228-
h.add_hyperedges_from_bigraph(g, group_name="bigraph-u", U_as_vertex=True)
229-
assert h.num_e == 4
230-
assert (0, 1, 2, 3) in h.e_of_group("bigraph-u")[0]
231-
assert (0, 1) in h.e_of_group("bigraph-u")[0]
232-
assert (2, 3) in h.e_of_group("bigraph-u")[0]
233-
234-
h.add_hyperedges_from_bigraph(g, group_name="bigraph-v", U_as_vertex=False)
235-
assert h.num_e == 3
236-
assert (0, 1) in h.e_of_group("bigraph-v")[0]
237-
assert (0, 2) in h.e_of_group("bigraph-v")[0]
220+
g = BiGraph(4, 3, [[0, 0], [1, 0], [2, 0], [3, 0], [0, 1], [1, 1], [2, 2], [3, 2]])
221+
hg = Hypergraph(3)
222+
hg.add_hyperedges_from_bigraph(g, group_name="bigraph")
223+
assert hg.num_e == 2
224+
assert (0, 1) in hg.e_of_group("bigraph")[0]
225+
assert (0, 2) in hg.e_of_group("bigraph")[0]
226+
227+
hg = Hypergraph(4)
228+
hg.add_hyperedges_from_bigraph(g, group_name="bigraph-u", U_as_vertex=True)
229+
assert hg.num_e == 3
230+
assert (0, 1, 2, 3) in hg.e_of_group("bigraph-u")[0]
231+
assert (0, 1) in hg.e_of_group("bigraph-u")[0]
232+
assert (2, 3) in hg.e_of_group("bigraph-u")[0]
238233

239234

240235
def test_remove_hyperedges(g1):

0 commit comments

Comments
 (0)