@@ -562,7 +562,7 @@ def _fetch_H_of_group(self, direction: str, group_name: str):
562562 torch .ones (len (v_idx )),
563563 torch .Size ([self .num_v , num_e ]),
564564 device = self .device ,
565- ).coalesce ()
565+ ).coalesce (). to ( self . device )
566566 return H
567567
568568 def _fetch_R_of_group (self , direction : str , group_name : str ):
@@ -587,7 +587,7 @@ def _fetch_R_of_group(self, direction: str, group_name: str):
587587 w_list .extend (self ._raw_groups [group_name ][e ][f"w_{ direction } " ])
588588 R = torch .sparse_coo_tensor (
589589 torch .vstack ([v_idx , e_idx ]), torch .tensor (w_list ), torch .Size ([self .num_v , num_e ]), device = self .device ,
590- ).coalesce ()
590+ ).coalesce (). to ( self . device )
591591 return R
592592
593593 def _fetch_W_of_group (self , group_name : str ):
@@ -598,7 +598,7 @@ def _fetch_W_of_group(self, group_name: str):
598598 """
599599 assert group_name in self .group_names , f"The specified { group_name } is not in existing hyperedge groups."
600600 w_list = [content ["w_e" ] for content in self ._raw_groups [group_name ].values ()]
601- W = torch .tensor (w_list , device = self .device ).view ((- 1 , 1 ))
601+ W = torch .tensor (w_list , device = self .device ).view ((- 1 , 1 )). to ( self . device )
602602 return W
603603
604604 # some structure modification functions
@@ -798,7 +798,7 @@ def W_v(self) -> torch.Tensor:
798798 r"""Return the vertex weight matrix of the hypergraph.
799799 """
800800 if self .cache ["W_v" ] is None :
801- self .cache ["W_v" ] = torch .tensor (self .v_weight , dtype = torch .float , device = self .device ).view (- 1 , 1 )
801+ self .cache ["W_v" ] = torch .tensor (self .v_weight , dtype = torch .float , device = self .device ).view (- 1 , 1 ). to ( self . device )
802802 return self .cache ["W_v" ]
803803
804804 @property
0 commit comments