Skip to content

Commit 118a714

Browse files
committed
update node/hyedge count functions
1 parent cc475cc commit 118a714

File tree

6 files changed

+17
-17
lines changed

6 files changed

+17
-17
lines changed

.idea/HyperG_package.iml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

.idea/misc.xml

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

SuperMoon/hyedge/utils/count.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1-
def count_node(H):
2-
return H[0].max().item() + 1
1+
def count_node(H, node_num=None):
2+
return H[0].max().item() + 1 if node_num is None else node_num
33

44

5-
def count_hyedge(H):
6-
return H[1].max().item() + 1
5+
def count_hyedge(H, hyedge_num=None):
6+
return H[1].max().item() + 1 if hyedge_num is None else hyedge_num

SuperMoon/hyedge/utils/degree.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,18 @@
33
from .count import count_hyedge, count_node
44

55

6-
def degree_node(H):
6+
def degree_node(H, node_num=None):
77
node_idx, edge_idx = H
8-
node_num = count_node(H)
8+
node_num = count_node(H, node_num)
99
src = torch.ones_like(node_idx).float().to(H.device)
1010
out = torch.zeros(node_num).to(H.device)
1111
return out.scatter_add(0, node_idx, src).long()
1212
# return torch.zeros(node_num).scatter_add(0, node_idx, torch.ones_like(node_idx).float()).long()
1313

1414

15-
def degree_hyedge(H: torch.Tensor):
15+
def degree_hyedge(H: torch.Tensor, hyedge_num=None):
1616
node_idx, hyedge_idx = H
17-
edge_num = count_hyedge(H)
17+
hyedge_num = count_hyedge(H, hyedge_num=hyedge_num)
1818
src = torch.ones_like(hyedge_idx).float().to(H.device)
19-
out = torch.zeros(edge_num).to(H.device)
19+
out = torch.zeros(hyedge_num).to(H.device)
2020
return out.scatter_add(0, hyedge_idx, src).long()

SuperMoon/hyedge/utils/self_loop.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,9 @@
55
from .verify import contiguous_hyedge_idx
66

77

8-
def self_loop_remove(H):
8+
def self_loop_remove(H, hyedge_num=None):
99
node_idx, hyedge_idx = H
10-
DE = degree_hyedge(H)
10+
DE = degree_hyedge(H, hyedge_num)
1111
loop_edge_idx = torch.where(DE == 1)[0]
1212

1313
mask = torch.ones_like(hyedge_idx).bool()
@@ -18,9 +18,9 @@ def self_loop_remove(H):
1818
return contiguous_hyedge_idx(H)
1919

2020

21-
def self_loop_add(H):
21+
def self_loop_add(H, node_num=None):
2222
H = self_loop_remove(H)
23-
node_num = count_node(H)
23+
node_num = count_node(H, node_num=node_num)
2424

2525
loop_node_idx = torch.arange(node_num)
2626
loop_hyedge_idx = torch.arange(node_num)

SuperMoon/hyedge/utils/verify.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,9 @@
33
from .degree import degree_hyedge
44

55

6-
def contiguous_hyedge_idx(H):
6+
def contiguous_hyedge_idx(H, hyedge_num=None):
77
node_idx, hyedge_idx = H
8-
DE = degree_hyedge(H)
8+
DE = degree_hyedge(H, hyedge_num)
99
zero_idx = torch.where(DE == 0)[0]
1010

1111
bias = torch.zeros_like(hyedge_idx)

0 commit comments

Comments
 (0)