Skip to content

Commit 55df18c

Browse files
committed
fix hypergraph device bugs
1 parent aa25822 commit 55df18c

File tree

2 files changed

+6
-6
lines changed

2 files changed

+6
-6
lines changed

dhg/structure/base.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -562,7 +562,7 @@ def _fetch_H_of_group(self, direction: str, group_name: str):
562562
torch.ones(len(v_idx)),
563563
torch.Size([self.num_v, num_e]),
564564
device=self.device,
565-
).coalesce()
565+
).coalesce().to(self.device)
566566
return H
567567

568568
def _fetch_R_of_group(self, direction: str, group_name: str):
@@ -587,7 +587,7 @@ def _fetch_R_of_group(self, direction: str, group_name: str):
587587
w_list.extend(self._raw_groups[group_name][e][f"w_{direction}"])
588588
R = torch.sparse_coo_tensor(
589589
torch.vstack([v_idx, e_idx]), torch.tensor(w_list), torch.Size([self.num_v, num_e]), device=self.device,
590-
).coalesce()
590+
).coalesce().to(self.device)
591591
return R
592592

593593
def _fetch_W_of_group(self, group_name: str):
@@ -598,7 +598,7 @@ def _fetch_W_of_group(self, group_name: str):
598598
"""
599599
assert group_name in self.group_names, f"The specified {group_name} is not in existing hyperedge groups."
600600
w_list = [content["w_e"] for content in self._raw_groups[group_name].values()]
601-
W = torch.tensor(w_list, device=self.device).view((-1, 1))
601+
W = torch.tensor(w_list, device=self.device).view((-1, 1)).to(self.device)
602602
return W
603603

604604
# some structure modification functions
@@ -798,7 +798,7 @@ def W_v(self) -> torch.Tensor:
798798
r"""Return the vertex weight matrix of the hypergraph.
799799
"""
800800
if self.cache["W_v"] is None:
801-
self.cache["W_v"] = torch.tensor(self.v_weight, dtype=torch.float, device=self.device).view(-1, 1)
801+
self.cache["W_v"] = torch.tensor(self.v_weight, dtype=torch.float, device=self.device).view(-1, 1).to(self.device)
802802
return self.cache["W_v"]
803803

804804
@property

dhg/structure/hypergraphs/hypergraph.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ def H(self) -> torch.Tensor:
773773
r"""Return the hypergraph incidence matrix :math:`\mathbf{H}` with ``torch.sparse_coo_tensor`` format.
774774
"""
775775
if self.cache.get("H") is None:
776-
self.cache["H"] = self.H_v2e
776+
self.cache["H"] = self.H_v2e.to(self.device)
777777
return self.cache["H"]
778778

779779
def H_of_group(self, group_name: str) -> torch.Tensor:
@@ -792,7 +792,7 @@ def H_T(self) -> torch.Tensor:
792792
r"""Return the transpose of the hypergraph incidence matrix :math:`\mathbf{H}^\top` with ``torch.sparse_coo_tensor`` format.
793793
"""
794794
if self.cache.get("H_T") is None:
795-
self.cache["H_T"] = self.H.t()
795+
self.cache["H_T"] = self.H.t().to(self.device)
796796
return self.cache["H_T"]
797797

798798
def H_T_of_group(self, group_name: str) -> torch.Tensor:

0 commit comments

Comments
 (0)