Skip to content

Commit 3b5eea8

Browse files
committed
enhance nbr and N_v of graph with specified k-hop
1 parent 0ac4fd3 commit 3b5eea8

File tree

5 files changed

+30
-11
lines changed

5 files changed

+30
-11
lines changed

dhg/_global.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ def get_dhg_cache_root():
1212
# global paths
1313
CACHE_ROOT = get_dhg_cache_root()
1414
DATASETS_ROOT = CACHE_ROOT / "datasets"
15-
REMOTE_ROOT = "https://data.deephypergraph.com/"
15+
# REMOTE_ROOT = "https://data.deephypergraph.com/"
16+
REMOTE_ROOT = "https://download.moon-lab.tech:28501/"
1617
REMOTE_DATASETS_ROOT = REMOTE_ROOT + "datasets/"
1718
# REMOTE_DATASETS_ROOT = "https://data.shrec22.moon-lab.tech:18443/DHG/datasets/"

dhg/data/dblp.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,9 @@
66

77

88
class DBLP8k(BaseData):
9-
r"""The DBLP-8k dataset is a citation network dataset for link prediction task.
10-
The dataset is a part of the dataset crawled according to DBLP API,
11-
and we have selected each item based on some conditions,
12-
such as the venue and publication year (from 2018 to 2022). It contains 6498 authors and 2603 papers.
9+
r"""The DBLP-8k dataset is a citation network dataset for link prediction task.
10+
The dataset is a part of the dataset crawled according to DBLP API, and we have selected each item based on some conditions, such as the venue and publication year (from 2018 to 2022). It contains 6498 authors and 2603 papers.
11+
1312
The content of the DBLP-8k dataset includes the following:
1413
1514
- ``num_vertices``: The number of vertices: :math:`8,657`.

dhg/structure/graphs/graph.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -421,13 +421,14 @@ def deg_v(self) -> List[int]:
421421
"""
422422
return self.D_v._values().cpu().numpy().tolist()
423423

424-
def nbr_v(self, v_idx: int) -> Tuple[List[int], List[float]]:
425-
r""" Return a vertex list of the neighbors of the vertex ``v_idx``.
424+
def nbr_v(self, v_idx: int, hop: int = 1) -> List[int]:
425+
r""" Return a vertex list of the ``k``-hop neighbors of the vertex ``v_idx``.
426426
427427
Args:
428428
``v_idx`` (``int``): The index of the vertex.
429+
``hop`` (``int``): The number of the hop.
429430
"""
430-
return self.N_v(v_idx).cpu().numpy().tolist()
431+
return self.N_v(v_idx, hop).cpu().numpy().tolist()
431432

432433
# =====================================================================================
433434
# properties for deep learning
@@ -521,13 +522,25 @@ def D_v_neg_1_2(self,) -> torch.Tensor:
521522
).coalesce()
522523
return self.cache["D_v_neg_1_2"]
523524

524-
def N_v(self, v_idx: int) -> Tuple[List[int], List[float]]:
525-
r""" Return the neighbors of the vertex ``v_idx`` with ``torch.Tensor`` format.
525+
def N_v(self, v_idx: int, hop: int = 1) -> List[int]:
526+
r""" Return the ``k``-hop neighbors of the vertex ``v_idx`` with ``torch.Tensor`` format.
526527
527528
Args:
528529
``v_idx`` (``int``): The index of the vertex.
530+
``hop`` (``int``): The number of the hop.
529531
"""
530-
sub_v_set = self.A[v_idx]._indices()[0].clone()
532+
assert hop >= 1, "``hop`` must be a number larger than or equal to 1."
533+
if hop == 1:
534+
A_k = self.A
535+
else:
536+
if self.cache.get(f"A_{hop}") is None:
537+
A_1, A_k = self.A.clone(), self.A.clone()
538+
for _ in range(hop - 1):
539+
A_k = torch.sparse.mm(A_k, A_1)
540+
self.cache[f"A_{hop}"] = A_k
541+
else:
542+
A_k = self.cache[f"A_{hop}"]
543+
sub_v_set = A_k[v_idx]._indices()[0].clone()
531544
return sub_v_set
532545

533546
@property

docs/source/api/data.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ Hypergraph Datasets
6161
dhg.data.WalmartTrips
6262
dhg.data.HouseCommittees
6363
dhg.data.News20
64+
dhg.data.DBLP8k
6465

6566

6667
**Welcome to contribute datasets!**

tests/structure/test_graph.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,11 @@ def test_nbr(g1):
174174
assert g1.nbr_v(0) == [1, 2, 3]
175175
g1.remove_edges((0, 2))
176176
assert g1.nbr_v(2) == []
177+
# hop k
178+
g3 = Graph(5, [(0, 1), (0, 3), (1, 4), (2, 3)])
179+
assert sorted(g3.nbr_v(3, 1)) == [0, 2]
180+
assert sorted(g3.nbr_v(3, 2)) == [1, 3]
181+
assert sorted(g3.nbr_v(3, 3)) == [0, 2, 4]
177182

178183

179184
# test deep learning

0 commit comments

Comments
 (0)