Skip to content

Commit 8346349

Browse files
committed
fix device error of graph,di_graph,bi_graph,hypergraph laplacian computing
1 parent 1ce68d4 commit 8346349

File tree

4 files changed

+23
-23
lines changed

4 files changed

+23
-23
lines changed

dhg/structure/graphs/bipartite_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,7 @@ def D_u(self) -> torch.Tensor:
509509
if self.cache.get("D_u", None) is None:
510510
_tmp = torch.sparse.sum(self.B, dim=1).to_dense().clone().view(-1)
511511
self.cache["D_u"] = torch.sparse_coo_tensor(
512-
indices=torch.arange(0, self.num_u).view(1, -1).repeat(2, 1),
512+
indices=torch.arange(0, self.num_u, device=self.device).view(1, -1).repeat(2, 1),
513513
values=_tmp,
514514
size=torch.Size([self.num_u, self.num_u]),
515515
device=self.device,
@@ -523,7 +523,7 @@ def D_v(self) -> torch.Tensor:
523523
if self.cache.get("D_v", None) is None:
524524
_tmp = torch.sparse.sum(self.B_T, dim=1).to_dense().clone().view(-1)
525525
self.cache["D_v"] = torch.sparse_coo_tensor(
526-
indices=torch.arange(0, self.num_v).view(1, -1).repeat(2, 1),
526+
indices=torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1),
527527
values=_tmp,
528528
size=torch.Size([self.num_v, self.num_v]),
529529
device=self.device,

dhg/structure/graphs/directed_graph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -443,7 +443,7 @@ def D_v_in(self) -> torch.Tensor:
443443
if self.cache.get("D_v_in", None) is None:
444444
_tmp = torch.sparse.sum(self.A_T, dim=1).to_dense().clone().view(-1)
445445
self.cache["D_v_in"] = torch.sparse_coo_tensor(
446-
indices=torch.arange(0, self.num_v).view(1, -1).repeat(2, 1),
446+
indices=torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1),
447447
values=_tmp,
448448
size=torch.Size([self.num_v, self.num_v]),
449449
device=self.device,
@@ -457,7 +457,7 @@ def D_v_out(self) -> torch.Tensor:
457457
if self.cache.get("D_v_out", None) is None:
458458
_tmp = torch.sparse.sum(self.A, dim=1).to_dense().clone().view(-1)
459459
self.cache["D_v_out"] = torch.sparse_coo_tensor(
460-
indices=torch.arange(0, self.num_v).view(1, -1).repeat(2, 1),
460+
indices=torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1),
461461
values=_tmp,
462462
size=torch.Size([self.num_v, self.num_v]),
463463
device=self.device,

dhg/structure/graphs/graph.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -592,8 +592,8 @@ def L_sym(self) -> torch.Tensor:
592592
_tmp_g.remove_selfloop()
593593
_L = _tmp_g.D_v_neg_1_2.mm(_tmp_g.A).mm(_tmp_g.D_v_neg_1_2).clone()
594594
self.cache["L_sym"] = torch.sparse_coo_tensor(
595-
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), _L._indices(),]),
596-
torch.hstack([torch.ones(self.num_v), -_L._values()]),
595+
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), _L._indices(),]),
596+
torch.hstack([torch.ones(self.num_v, device=self.device), -_L._values()]),
597597
torch.Size([self.num_v, self.num_v]),
598598
device=self.device,
599599
).coalesce()
@@ -611,8 +611,8 @@ def L_rw(self) -> torch.Tensor:
611611
_tmp_g.remove_selfloop()
612612
_L = _tmp_g.D_v_neg_1.mm(_tmp_g.A).clone()
613613
self.cache["L_rw"] = torch.sparse_coo_tensor(
614-
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), _L._indices(),]),
615-
torch.hstack([torch.ones(self.num_v), -_L._values()]),
614+
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), _L._indices(),]),
615+
torch.hstack([torch.ones(self.num_v, device=self.device), -_L._values()]),
616616
torch.Size([self.num_v, self.num_v]),
617617
device=self.device,
618618
).coalesce()

dhg/structure/hypergraphs/hypergraph.py

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -814,7 +814,7 @@ def W_v(self) -> torch.Tensor:
814814
_tmp = torch.Tensor(self.v_weight)
815815
_num_v = _tmp.size(0)
816816
self.cache["W_v"] = torch.sparse_coo_tensor(
817-
torch.arange(0, _num_v).view(1, -1).repeat(2, 1),
817+
torch.arange(0, _num_v, device=self.device).view(1, -1).repeat(2, 1),
818818
_tmp,
819819
torch.Size([_num_v, _num_v]),
820820
device=self.device,
@@ -830,7 +830,7 @@ def W_e(self) -> torch.Tensor:
830830
_tmp = torch.cat(_tmp, dim=0).view(-1)
831831
_num_e = _tmp.size(0)
832832
self.cache["W_e"] = torch.sparse_coo_tensor(
833-
torch.arange(0, _num_e).view(1, -1).repeat(2, 1),
833+
torch.arange(0, _num_e, device=self.device).view(1, -1).repeat(2, 1),
834834
_tmp,
835835
torch.Size([_num_e, _num_e]),
836836
device=self.device,
@@ -848,7 +848,7 @@ def W_e_of_group(self, group_name: str) -> torch.Tensor:
848848
_tmp = self._fetch_W_of_group(group_name).view(-1)
849849
_num_e = _tmp.size(0)
850850
self.group_cache[group_name]["W_e"] = torch.sparse_coo_tensor(
851-
torch.arange(0, _num_e).view(1, -1).repeat(2, 1),
851+
torch.arange(0, _num_e, device=self.device).view(1, -1).repeat(2, 1),
852852
_tmp,
853853
torch.Size([_num_e, _num_e]),
854854
device=self.device,
@@ -863,7 +863,7 @@ def D_v(self) -> torch.Tensor:
863863
_tmp = [self.D_v_of_group(name)._values().clone() for name in self.group_names]
864864
_tmp = torch.vstack(_tmp).sum(dim=0).view(-1)
865865
self.cache["D_v"] = torch.sparse_coo_tensor(
866-
torch.arange(0, self.num_v).view(1, -1).repeat(2, 1),
866+
torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1),
867867
_tmp,
868868
torch.Size([self.num_v, self.num_v]),
869869
device=self.device,
@@ -885,7 +885,7 @@ def D_v_of_group(self, group_name: str) -> torch.Tensor:
885885
_tmp = torch.sparse.sum(H_, dim=1).to_dense().clone().view(-1)
886886
_num_v = _tmp.size(0)
887887
self.group_cache[group_name]["D_v"] = torch.sparse_coo_tensor(
888-
torch.arange(0, _num_v).view(1, -1).repeat(2, 1),
888+
torch.arange(0, _num_v, device=self.device).view(1, -1).repeat(2, 1),
889889
_tmp,
890890
torch.Size([_num_v, _num_v]),
891891
device=self.device,
@@ -959,7 +959,7 @@ def D_e(self) -> torch.Tensor:
959959
_tmp = torch.cat(_tmp, dim=0).view(-1)
960960
_num_e = _tmp.size(0)
961961
self.cache["D_e"] = torch.sparse_coo_tensor(
962-
torch.arange(0, _num_e).view(1, -1).repeat(2, 1),
962+
torch.arange(0, _num_e, device=self.device).view(1, -1).repeat(2, 1),
963963
_tmp,
964964
torch.Size([_num_e, _num_e]),
965965
device=self.device,
@@ -977,7 +977,7 @@ def D_e_of_group(self, group_name: str) -> torch.Tensor:
977977
_tmp = torch.sparse.sum(self.H_T_of_group(group_name), dim=1).to_dense().clone().view(-1)
978978
_num_e = _tmp.size(0)
979979
self.group_cache[group_name]["D_e"] = torch.sparse_coo_tensor(
980-
torch.arange(0, _num_e).view(1, -1).repeat(2, 1),
980+
torch.arange(0, _num_e, device=self.device).view(1, -1).repeat(2, 1),
981981
_tmp,
982982
torch.Size([_num_e, _num_e]),
983983
device=self.device,
@@ -1090,8 +1090,8 @@ def L_sym(self) -> torch.Tensor:
10901090
if self.cache.get("L_sym") is None:
10911091
L_HGNN = self.L_HGNN.clone()
10921092
self.cache["L_sym"] = torch.sparse_coo_tensor(
1093-
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), L_HGNN._indices(),]),
1094-
torch.hstack([torch.ones(self.num_v), -L_HGNN._values()]),
1093+
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), L_HGNN._indices(),]),
1094+
torch.hstack([torch.ones(self.num_v, device=self.device), -L_HGNN._values()]),
10951095
torch.Size([self.num_v, self.num_v]),
10961096
device=self.device,
10971097
).coalesce()
@@ -1110,8 +1110,8 @@ def L_sym_of_group(self, group_name: str) -> torch.Tensor:
11101110
if self.group_cache[group_name].get("L_sym") is None:
11111111
L_HGNN = self.L_HGNN_of_group(group_name).clone()
11121112
self.group_cache[group_name]["L_sym"] = torch.sparse_coo_tensor(
1113-
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), L_HGNN._indices(),]),
1114-
torch.hstack([torch.ones(self.num_v), -L_HGNN._values()]),
1113+
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), L_HGNN._indices(),]),
1114+
torch.hstack([torch.ones(self.num_v, device=self.device), -L_HGNN._values()]),
11151115
torch.Size([self.num_v, self.num_v]),
11161116
device=self.device,
11171117
).coalesce()
@@ -1128,8 +1128,8 @@ def L_rw(self) -> torch.Tensor:
11281128
_tmp = self.D_v_neg_1.mm(self.H).mm(self.W_e).mm(self.D_e_neg_1).mm(self.H_T)
11291129
self.cache["L_rw"] = (
11301130
torch.sparse_coo_tensor(
1131-
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), _tmp._indices(),]),
1132-
torch.hstack([torch.ones(self.num_v), -_tmp._values()]),
1131+
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), _tmp._indices(),]),
1132+
torch.hstack([torch.ones(self.num_v, device=self.device), -_tmp._values()]),
11331133
torch.Size([self.num_v, self.num_v]),
11341134
device=self.device,
11351135
)
@@ -1158,8 +1158,8 @@ def L_rw_of_group(self, group_name: str) -> torch.Tensor:
11581158
)
11591159
self.group_cache[group_name]["L_rw"] = (
11601160
torch.sparse_coo_tensor(
1161-
torch.hstack([torch.arange(0, self.num_v).view(1, -1).repeat(2, 1), _tmp._indices(),]),
1162-
torch.hstack([torch.ones(self.num_v), -_tmp._values()]),
1161+
torch.hstack([torch.arange(0, self.num_v, device=self.device).view(1, -1).repeat(2, 1), _tmp._indices(),]),
1162+
torch.hstack([torch.ones(self.num_v, device=self.device), -_tmp._values()]),
11631163
torch.Size([self.num_v, self.num_v]),
11641164
device=self.device,
11651165
)

0 commit comments

Comments
 (0)