Skip to content

Commit 7786e2d

Browse files
committed
fix hypergraph D_v bugs ( h[v, e] -> w[e]*h[v, e] )
1 parent 7293527 commit 7786e2d

File tree

2 files changed

+9
-5
lines changed

2 files changed

+9
-5
lines changed

dhg/structure/hypergraphs/hypergraph.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,11 @@ def D_v_of_group(self, group_name: str) -> torch.Tensor:
847847
"""
848848
assert group_name in self.group_names, f"The specified {group_name} is not in existing hyperedge groups."
849849
if self.group_cache[group_name].get("D_v") is None:
850-
_tmp = torch.sparse.sum(self.H_of_group(group_name), dim=1).to_dense().clone().view(-1)
850+
H = self.H_of_group(group_name).clone()
851+
w_e = self.W_e_of_group(group_name)._values().clone()
852+
val = w_e[H._indices()[1]] * H._values()
853+
H_ = torch.sparse_coo_tensor(H._indices(), val, size=H.shape, device=self.device).coalesce()
854+
_tmp = torch.sparse.sum(H_, dim=1).to_dense().clone().view(-1)
851855
_num_v = _tmp.size(0)
852856
self.group_cache[group_name]["D_v"] = torch.sparse_coo_tensor(
853857
torch.arange(0, _num_v).view(1, -1).repeat(2, 1),

tests/structure/test_hypergraph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -290,7 +290,7 @@ def test_add_and_remove_group(g1):
290290
def test_deg(g1, g2):
291291
assert g1.deg_v == [2, 2, 2, 1, 1, 1]
292292
assert g1.deg_e == [4, 2, 3]
293-
assert g2.deg_v == [2, 3, 3, 4, 1]
293+
assert g2.deg_v == [1.5, 2, 2, 3, 1]
294294
assert g2.deg_e == [3, 3, 2, 3, 2]
295295

296296

@@ -463,7 +463,7 @@ def test_W_e_group(g2):
463463
def test_D(g1, g2):
464464
assert (g1.D_v.cpu()._values() == torch.tensor([2, 2, 2, 1, 1, 1])).all()
465465
assert (g1.D_e.cpu()._values() == torch.tensor([4, 2, 3])).all()
466-
assert (g2.D_v.cpu()._values() == torch.tensor([2, 3, 3, 4, 1])).all()
466+
assert (g2.D_v.cpu()._values() == torch.tensor([1.5, 2, 2, 3, 1])).all()
467467
assert (g2.D_e.cpu()._values() == torch.tensor([3, 3, 2, 3, 2])).all()
468468

469469

@@ -483,11 +483,11 @@ def test_D_neg(g1, g2):
483483
# -1
484484
assert (g1.D_v_neg_1.cpu()._values() == torch.tensor([2, 2, 2, 1, 1, 1]) ** (-1.0)).all()
485485
assert (g1.D_e_neg_1.cpu()._values() == torch.tensor([4, 2, 3]) ** (-1.0)).all()
486-
assert (g2.D_v_neg_1.cpu()._values() == torch.tensor([2, 3, 3, 4, 1]) ** (-1.0)).all()
486+
assert (g2.D_v_neg_1.cpu()._values() == torch.tensor([1.5, 2, 2, 3, 1]) ** (-1.0)).all()
487487
assert (g2.D_e_neg_1.cpu()._values() == torch.tensor([3, 3, 2, 3, 2]) ** (-1.0)).all()
488488
# -1/2
489489
assert (g1.D_v_neg_1_2.cpu()._values() == torch.tensor([2, 2, 2, 1, 1, 1]) ** (-0.5)).all()
490-
assert (g2.D_v_neg_1_2.cpu()._values() == torch.tensor([2, 3, 3, 4, 1]) ** (-0.5)).all()
490+
assert (g2.D_v_neg_1_2.cpu()._values() == torch.tensor([1.5, 2, 2, 3, 1]) ** (-0.5)).all()
491491
# isolated vertex
492492
g3 = Hypergraph(3, [0, 1])
493493
assert (g3.D_v_neg_1.cpu()._values() == torch.tensor([1, 1, 0])).all()

0 commit comments

Comments
 (0)