Skip to content

Commit cc475cc

Browse files
committed
add operator to convert dense to sparse and convert sparse to dense
1 parent b50963b commit cc475cc

File tree

3 files changed

+40
-0
lines changed

3 files changed

+40
-0
lines changed

SuperMoon/hyedge/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from .utils.self_loop import self_loop_add, self_loop_remove
66
from .gather_neighbor import neighbor_grid, neighbor_distance, gather_patch_ft
77
from .utils.count import count_hyedge, count_node
8+
from .structure_convert import matrix2index, index2matrix
89

910
__all__ = ['pairwise_euclidean_distance',
1011
'count_hyedge', 'count_node',
1112
'degree_node', 'degree_hyedge',
1213
'self_loop_add', 'self_loop_remove',
1314
'contiguous_hyedge_idx', 'filter_node_index', 'remove_negative_index',
1415
'neighbor_grid', 'neighbor_distance', 'gather_patch_ft',
16+
'matrix2index', 'index2matrix',
1517
]
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import torch
2+
3+
4+
def index2matrix(index):
5+
assert index.size(0) == 2
6+
7+
index = index.long()
8+
v_len = index.size(1)
9+
v = torch.ones(v_len).float()
10+
matrix = torch.sparse_coo_tensor(index, v).to_dense()
11+
return matrix
12+
13+
14+
def matrix2index(matrix: torch.tensor):
15+
i_v = matrix.to_sparse()
16+
index = i_v.indices()
17+
return index;
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import torch
2+
3+
from SuperMoon.hyedge import matrix2index, index2matrix
4+
5+
6+
def test_matrix2index():
7+
matrix = torch.tensor([[0, 1, 0, 0],
8+
[1, 1, 0, 1],
9+
[0, 0, 1, 0]])
10+
index = matrix2index(matrix)
11+
assert torch.all(index == torch.tensor([[0, 1, 1, 1, 2],
12+
[1, 0, 1, 3, 2]]))
13+
14+
15+
def test_index2matrix():
16+
index = torch.tensor([[0, 1, 1, 1, 2],
17+
[1, 0, 1, 3, 2]])
18+
matrix = index2matrix(index)
19+
assert torch.all(matrix == torch.tensor([[0, 1, 0, 0],
20+
[1, 1, 0, 1],
21+
[0, 0, 1, 0]]).float())

0 commit comments

Comments
 (0)