@@ -814,7 +814,7 @@ def W_v(self) -> torch.Tensor:
814814 _tmp = torch .Tensor (self .v_weight )
815815 _num_v = _tmp .size (0 )
816816 self .cache ["W_v" ] = torch .sparse_coo_tensor (
817- torch .arange (0 , _num_v ).view (1 , - 1 ).repeat (2 , 1 ),
817+ torch .arange (0 , _num_v , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ),
818818 _tmp ,
819819 torch .Size ([_num_v , _num_v ]),
820820 device = self .device ,
@@ -830,7 +830,7 @@ def W_e(self) -> torch.Tensor:
830830 _tmp = torch .cat (_tmp , dim = 0 ).view (- 1 )
831831 _num_e = _tmp .size (0 )
832832 self .cache ["W_e" ] = torch .sparse_coo_tensor (
833- torch .arange (0 , _num_e ).view (1 , - 1 ).repeat (2 , 1 ),
833+ torch .arange (0 , _num_e , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ),
834834 _tmp ,
835835 torch .Size ([_num_e , _num_e ]),
836836 device = self .device ,
@@ -848,7 +848,7 @@ def W_e_of_group(self, group_name: str) -> torch.Tensor:
848848 _tmp = self ._fetch_W_of_group (group_name ).view (- 1 )
849849 _num_e = _tmp .size (0 )
850850 self .group_cache [group_name ]["W_e" ] = torch .sparse_coo_tensor (
851- torch .arange (0 , _num_e ).view (1 , - 1 ).repeat (2 , 1 ),
851+ torch .arange (0 , _num_e , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ),
852852 _tmp ,
853853 torch .Size ([_num_e , _num_e ]),
854854 device = self .device ,
@@ -863,7 +863,7 @@ def D_v(self) -> torch.Tensor:
863863 _tmp = [self .D_v_of_group (name )._values ().clone () for name in self .group_names ]
864864 _tmp = torch .vstack (_tmp ).sum (dim = 0 ).view (- 1 )
865865 self .cache ["D_v" ] = torch .sparse_coo_tensor (
866- torch .arange (0 , self .num_v ).view (1 , - 1 ).repeat (2 , 1 ),
866+ torch .arange (0 , self .num_v , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ),
867867 _tmp ,
868868 torch .Size ([self .num_v , self .num_v ]),
869869 device = self .device ,
@@ -885,7 +885,7 @@ def D_v_of_group(self, group_name: str) -> torch.Tensor:
885885 _tmp = torch .sparse .sum (H_ , dim = 1 ).to_dense ().clone ().view (- 1 )
886886 _num_v = _tmp .size (0 )
887887 self .group_cache [group_name ]["D_v" ] = torch .sparse_coo_tensor (
888- torch .arange (0 , _num_v ).view (1 , - 1 ).repeat (2 , 1 ),
888+ torch .arange (0 , _num_v , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ),
889889 _tmp ,
890890 torch .Size ([_num_v , _num_v ]),
891891 device = self .device ,
@@ -959,7 +959,7 @@ def D_e(self) -> torch.Tensor:
959959 _tmp = torch .cat (_tmp , dim = 0 ).view (- 1 )
960960 _num_e = _tmp .size (0 )
961961 self .cache ["D_e" ] = torch .sparse_coo_tensor (
962- torch .arange (0 , _num_e ).view (1 , - 1 ).repeat (2 , 1 ),
962+ torch .arange (0 , _num_e , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ),
963963 _tmp ,
964964 torch .Size ([_num_e , _num_e ]),
965965 device = self .device ,
@@ -977,7 +977,7 @@ def D_e_of_group(self, group_name: str) -> torch.Tensor:
977977 _tmp = torch .sparse .sum (self .H_T_of_group (group_name ), dim = 1 ).to_dense ().clone ().view (- 1 )
978978 _num_e = _tmp .size (0 )
979979 self .group_cache [group_name ]["D_e" ] = torch .sparse_coo_tensor (
980- torch .arange (0 , _num_e ).view (1 , - 1 ).repeat (2 , 1 ),
980+ torch .arange (0 , _num_e , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ),
981981 _tmp ,
982982 torch .Size ([_num_e , _num_e ]),
983983 device = self .device ,
@@ -1090,8 +1090,8 @@ def L_sym(self) -> torch.Tensor:
10901090 if self .cache .get ("L_sym" ) is None :
10911091 L_HGNN = self .L_HGNN .clone ()
10921092 self .cache ["L_sym" ] = torch .sparse_coo_tensor (
1093- torch .hstack ([torch .arange (0 , self .num_v ).view (1 , - 1 ).repeat (2 , 1 ), L_HGNN ._indices (),]),
1094- torch .hstack ([torch .ones (self .num_v ), - L_HGNN ._values ()]),
1093+ torch .hstack ([torch .arange (0 , self .num_v , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ), L_HGNN ._indices (),]),
1094+ torch .hstack ([torch .ones (self .num_v , device = self . device ), - L_HGNN ._values ()]),
10951095 torch .Size ([self .num_v , self .num_v ]),
10961096 device = self .device ,
10971097 ).coalesce ()
@@ -1110,8 +1110,8 @@ def L_sym_of_group(self, group_name: str) -> torch.Tensor:
11101110 if self .group_cache [group_name ].get ("L_sym" ) is None :
11111111 L_HGNN = self .L_HGNN_of_group (group_name ).clone ()
11121112 self .group_cache [group_name ]["L_sym" ] = torch .sparse_coo_tensor (
1113- torch .hstack ([torch .arange (0 , self .num_v ).view (1 , - 1 ).repeat (2 , 1 ), L_HGNN ._indices (),]),
1114- torch .hstack ([torch .ones (self .num_v ), - L_HGNN ._values ()]),
1113+ torch .hstack ([torch .arange (0 , self .num_v , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ), L_HGNN ._indices (),]),
1114+ torch .hstack ([torch .ones (self .num_v , device = self . device ), - L_HGNN ._values ()]),
11151115 torch .Size ([self .num_v , self .num_v ]),
11161116 device = self .device ,
11171117 ).coalesce ()
@@ -1128,8 +1128,8 @@ def L_rw(self) -> torch.Tensor:
11281128 _tmp = self .D_v_neg_1 .mm (self .H ).mm (self .W_e ).mm (self .D_e_neg_1 ).mm (self .H_T )
11291129 self .cache ["L_rw" ] = (
11301130 torch .sparse_coo_tensor (
1131- torch .hstack ([torch .arange (0 , self .num_v ).view (1 , - 1 ).repeat (2 , 1 ), _tmp ._indices (),]),
1132- torch .hstack ([torch .ones (self .num_v ), - _tmp ._values ()]),
1131+ torch .hstack ([torch .arange (0 , self .num_v , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ), _tmp ._indices (),]),
1132+ torch .hstack ([torch .ones (self .num_v , device = self . device ), - _tmp ._values ()]),
11331133 torch .Size ([self .num_v , self .num_v ]),
11341134 device = self .device ,
11351135 )
@@ -1158,8 +1158,8 @@ def L_rw_of_group(self, group_name: str) -> torch.Tensor:
11581158 )
11591159 self .group_cache [group_name ]["L_rw" ] = (
11601160 torch .sparse_coo_tensor (
1161- torch .hstack ([torch .arange (0 , self .num_v ).view (1 , - 1 ).repeat (2 , 1 ), _tmp ._indices (),]),
1162- torch .hstack ([torch .ones (self .num_v ), - _tmp ._values ()]),
1161+ torch .hstack ([torch .arange (0 , self .num_v , device = self . device ).view (1 , - 1 ).repeat (2 , 1 ), _tmp ._indices (),]),
1162+ torch .hstack ([torch .ones (self .num_v , device = self . device ), - _tmp ._values ()]),
11631163 torch .Size ([self .num_v , self .num_v ]),
11641164 device = self .device ,
11651165 )
0 commit comments