Skip to content

Commit e17fc95

Browse files
committed
add hypergraph vertex weight (v_weight and W_v)
1 parent 7786e2d commit e17fc95

File tree

2 files changed

+44
-7
lines changed

2 files changed

+44
-7
lines changed

dhg/structure/hypergraphs/hypergraph.py

Lines changed: 38 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ class Hypergraph(BaseHypergraph):
2222
``num_v`` (``int``): The number of vertices in the hypergraph.
2323
``e_list`` (``Union[List[int], List[List[int]]]``, optional): A list of hyperedges describes how the vertices point to the hyperedges. Defaults to ``None``.
2424
``e_weight`` (``Union[float, List[float]]``, optional): A list of weights for hyperedges. If set to ``None``, the value ``1`` is used for all hyperedges. Defaults to ``None``.
25+
``v_weight`` (``Union[List[float]]``, optional): A list of weights for vertices. If set to ``None``, the value ``1`` is used for all vertices. Defaults to ``None``.
2526
``merge_op`` (``str``): The operation to merge those conflicting hyperedges in the same hyperedge group, which can be ``'mean'``, ``'sum'`` or ``'max'``. Defaults to ``'mean'``.
2627
``device`` (``torch.device``, optional): The deivce to store the hypergraph. Defaults to ``torch.device('cpu')``.
2728
"""
@@ -31,10 +32,18 @@ def __init__(
3132
num_v: int,
3233
e_list: Optional[Union[List[int], List[List[int]]]] = None,
3334
e_weight: Optional[Union[float, List[float]]] = None,
35+
v_weight: Optional[List[float]] = None,
3436
merge_op: str = "mean",
3537
device: torch.device = torch.device("cpu"),
3638
):
3739
super().__init__(num_v, device=device)
40+
# init vertex weight
41+
if v_weight is None:
42+
self._v_weight = [1.0] * self.num_v
43+
else:
44+
assert len(v_weight) == self.num_v, "The length of vertex weight is not equal to the number of vertices."
45+
self._v_weight = v_weight
46+
# init hyperedges
3847
if e_list is not None:
3948
self.add_hyperedges(e_list, e_weight, merge_op=merge_op)
4049

@@ -495,6 +504,12 @@ def v(self) -> List[int]:
495504
r"""Return the list of vertices.
496505
"""
497506
return super().v
507+
508+
@property
509+
def v_weight(self) -> List[float]:
510+
r"""Return the list of vertex weights.
511+
"""
512+
return self._v_weight
498513

499514
@property
500515
def e(self) -> Tuple[List[List[int]], List[float]]:
@@ -634,7 +649,7 @@ def vars_for_DL(self) -> List[str]:
634649
Sparse Diagnal Matrices:
635650
636651
.. math::
637-
\mathbf{W}_e, \mathbf{D}_v, \mathbf{D}_v^{-1}, \mathbf{D}_v^{-\frac{1}{2}}, \mathbf{D}_e, \mathbf{D}_e^{-1},
652+
\mathbf{W}_v, \mathbf{W}_e, \mathbf{D}_v, \mathbf{D}_v^{-1}, \mathbf{D}_v^{-\frac{1}{2}}, \mathbf{D}_e, \mathbf{D}_e^{-1},
638653
639654
Vectors:
640655
@@ -649,6 +664,7 @@ def vars_for_DL(self) -> List[str]:
649664
"L_sym",
650665
"L_rw",
651666
"L_HGNN",
667+
"W_v",
652668
"W_e",
653669
"D_v",
654670
"D_v_neg_1",
@@ -754,14 +770,14 @@ def e2v_weight_of_group(self, group_name: str) -> torch.Tensor:
754770

755771
@property
756772
def H(self) -> torch.Tensor:
757-
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` with ``torch.Tensor`` format.
773+
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` with ``torch.sparse_coo_tensor`` format.
758774
"""
759775
if self.cache.get("H") is None:
760776
self.cache["H"] = self.H_v2e
761777
return self.cache["H"]
762778

763779
def H_of_group(self, group_name: str) -> torch.Tensor:
764-
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` of the specified hyperedge group with ``torch.Tensor`` format.
780+
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` of the specified hyperedge group with ``torch.sparse_coo_tensor`` format.
765781
766782
Args:
767783
``group_name`` (``str``): The name of the specified hyperedge group.
@@ -773,14 +789,14 @@ def H_of_group(self, group_name: str) -> torch.Tensor:
773789

774790
@property
775791
def H_T(self) -> torch.Tensor:
776-
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` with ``torch.Tensor`` format.
792+
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` with ``torch.sparse_coo_tensor`` format.
777793
"""
778794
if self.cache.get("H_T") is None:
779795
self.cache["H_T"] = self.H.t()
780796
return self.cache["H_T"]
781797

782798
def H_T_of_group(self, group_name: str) -> torch.Tensor:
783-
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` of the specified hyperedge group with ``torch.Tensor`` format.
799+
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` of the specified hyperedge group with ``torch.sparse_coo_tensor`` format.
784800
785801
Args:
786802
``group_name`` (``str``): The name of the specified hyperedge group.
@@ -789,10 +805,25 @@ def H_T_of_group(self, group_name: str) -> torch.Tensor:
789805
if self.group_cache[group_name].get("H_T") is None:
790806
self.group_cache[group_name]["H_T"] = self.H_of_group(group_name).t()
791807
return self.group_cache[group_name]["H_T"]
808+
809+
@property
810+
def W_v(self) -> torch.Tensor:
811+
r"""Return the weight matrix :math:`\mathbf{W}_v` of vertices with ``torch.sparse_coo_tensor`` format.
812+
"""
813+
if self.cache.get("W_v") is None:
814+
_tmp = torch.Tensor(self.v_weight)
815+
_num_v = _tmp.size(0)
816+
self.cache["W_v"] = torch.sparse_coo_tensor(
817+
torch.arange(0, _num_v).view(1, -1).repeat(2, 1),
818+
_tmp,
819+
torch.Size([_num_v, _num_v]),
820+
device=self.device,
821+
).coalesce()
822+
return self.cache["W_v"]
792823

793824
@property
794825
def W_e(self) -> torch.Tensor:
795-
r"""Return the weight matrix :math:`\mathbf{W}_e` of hyperedges with ``torch.Tensor`` format.
826+
r"""Return the weight matrix :math:`\mathbf{W}_e` of hyperedges with ``torch.sparse_coo_tensor`` format.
796827
"""
797828
if self.cache.get("W_e") is None:
798829
_tmp = [self.W_e_of_group(name)._values().clone() for name in self.group_names]
@@ -807,7 +838,7 @@ def W_e(self) -> torch.Tensor:
807838
return self.cache["W_e"]
808839

809840
def W_e_of_group(self, group_name: str) -> torch.Tensor:
810-
r"""Return the weight matrix :math:`\mathbf{W}_e` of hyperedges of the specified hyperedge group with ``torch.Tensor`` format.
841+
r"""Return the weight matrix :math:`\mathbf{W}_e` of hyperedges of the specified hyperedge group with ``torch.sparse_coo_tensor`` format.
811842
812843
Args:
813844
``group_name`` (``str``): The name of the specified hyperedge group.

tests/structure/test_hypergraph.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -449,6 +449,12 @@ def test_H_T_group(g1):
449449
assert (g1.H_T_of_group("knn").to_dense().cpu() == torch.tensor([[1, 0, 0, 0, 1, 1]])).all()
450450

451451

452+
def test_W_v(g2):
453+
assert (g2.W_v.cpu()._values() == torch.tensor([1, 1, 1, 1, 1])).all()
454+
hg = Hypergraph(5, [[1, 2], [0, 2, 3, 4]], v_weight=[0.1, 1, 2, 1, 1])
455+
assert (hg.W_v.cpu()._values() == torch.tensor([0.1, 1, 2, 1, 1])).all()
456+
457+
452458
def test_W_e(g2):
453459
assert (g2.W_e.cpu()._values() == torch.tensor([0.5, 1, 0.5, 1, 0.5])).all()
454460

0 commit comments

Comments
 (0)