@@ -42,10 +42,10 @@ def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
         for _idx in range(self.layer_depth):
             if _idx % 2 == 0:
                 _tmp = self.layers[_idx](last_X_v)
-                last_X_u = g.v2u(_tmp, aggr="sum")
+                last_X_u = torch.tanh(g.v2u(_tmp, aggr="sum"))
             else:
                 _tmp = self.layers[_idx](last_X_u)
-                last_X_v = g.u2v(_tmp, aggr="sum")
+                last_X_v = torch.tanh(g.u2v(_tmp, aggr="sum"))
         return last_X_u

     def train_one_layer(
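As a reading aid (not part of the commit): with `g.v2u`/`g.u2v` performing sum aggregation over the bipartite structure and `self.layers[k]` a linear map, the bounded propagation introduced here amounts to, in notation of my own choosing,

    X_u^{(k+1)} = \tanh\big(A_{uv}\,\mathrm{Linear}_k(X_v^{(k)})\big), \qquad
    X_v^{(k+1)} = \tanh\big(A_{uv}^{\top}\,\mathrm{Linear}_k(X_u^{(k)})\big),

where A_{uv} denotes the U-to-V incidence matrix. The tanh keeps every cascaded feature in (-1, 1), matching the Tanh output of the Decoder added further down in this commit.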
@@ -63,36 +63,36 @@ def train_one_layer(
         netG = layer.to(device)
         netD = Discriminator(X_true.shape[1], 16, 1, drop_rate=drop_rate).to(device)

-        optimG = optim.Adam(netG.parameters(), lr=lr, weight_decay=weight_decay)
-        optimD = optim.Adam(netD.parameters(), lr=lr, weight_decay=weight_decay)
+        optimizer_G = optim.Adam(netG.parameters(), lr=lr, weight_decay=weight_decay)
+        optimizer_D = optim.Adam(netD.parameters(), lr=lr, weight_decay=weight_decay)

-        X_true, X_other = X_true.to(device), X_other.to(device)
-        lbl_real = torch.ones(X_true.shape[0]).to(device)
-        lbl_fake = torch.zeros(X_true.shape[0]).to(device)
+        X_true, X_other = X_true.detach().to(device), X_other.detach().to(device)
+        lbl_real = torch.ones(X_true.shape[0], 1, requires_grad=False).to(device)
+        lbl_fake = torch.zeros(X_true.shape[0], 1, requires_grad=False).to(device)

         netG.train(), netD.train()
         for _ in range(max_epoch):
             X_real = X_true
-            X_fake = mp_func(netG(X_other))
+            X_fake = torch.tanh(mp_func(netG(X_other)))

             # step 1: train Discriminator
-            optimD.zero_grad()
+            optimizer_D.zero_grad()

             pred_real = netD(X_real)
             pred_fake = netD(X_fake.detach())

-            lossD = F.binary_cross_entropy(pred_real, lbl_real) + F.binary_cross_entropy(pred_fake, lbl_fake)
-            lossD.backward()
-            optimD.step()
+            loss_D = F.binary_cross_entropy(pred_real, lbl_real) + F.binary_cross_entropy(pred_fake, lbl_fake)
+            loss_D.backward()
+            optimizer_D.step()

             # step 2: train Generator
-            optimG.zero_grad()
+            optimizer_G.zero_grad()

             pred_fake = netD(X_fake)

-            lossG = F.binary_cross_entropy(pred_fake, lbl_real)
-            lossG.backward()
-            optimG.step()
+            loss_G = F.binary_cross_entropy(pred_fake, lbl_real)
+            loss_G.backward()
+            optimizer_G.step()

     def train_with_cascaded(
         self,
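For readers less familiar with the adversarial pattern in the hunk above, here is a minimal, self-contained sketch of the same discriminator/generator alternation with stand-in modules. The module shapes and hyperparameters are illustrative only and are not taken from the repository; the point is the two-step update and the (N, 1) labels matching the discriminator's output.

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

netG = nn.Linear(8, 8)                                 # stand-in generator
netD = nn.Sequential(nn.Linear(8, 1), nn.Sigmoid())    # stand-in discriminator
optimizer_G = optim.Adam(netG.parameters(), lr=1e-3)
optimizer_D = optim.Adam(netD.parameters(), lr=1e-3)

X_real = torch.randn(32, 8)
lbl_real = torch.ones(32, 1)                           # shape (N, 1) matches netD's output
lbl_fake = torch.zeros(32, 1)

for _ in range(10):
    X_fake = torch.tanh(netG(torch.randn(32, 8)))      # bounded, like the propagated features

    # step 1: update D on real vs. detached fake features
    optimizer_D.zero_grad()
    loss_D = F.binary_cross_entropy(netD(X_real), lbl_real) + \
             F.binary_cross_entropy(netD(X_fake.detach()), lbl_fake)
    loss_D.backward()
    optimizer_D.step()

    # step 2: update G so that D labels its output as real
    optimizer_G.zero_grad()
    loss_G = F.binary_cross_entropy(netD(X_fake), lbl_real)
    loss_G.backward()
    optimizer_G.step()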
@@ -117,7 +117,8 @@ def train_with_cascaded(
             ``drop_rate`` (``float``): The dropout rate. Default: ``0.5``.
             ``device`` (``str``): The device to use. Default: ``"cpu"``.
         """
-        last_X_u, last_X_v = X_u, X_v
+        self = self.to(device)
+        last_X_u, last_X_v = X_u.to(device), X_v.to(device)
         for _idx in range(self.layer_depth):
             if _idx % 2 == 0:
                 self.train_one_layer(
@@ -131,7 +132,8 @@ def train_with_cascaded(
                     drop_rate,
                     device,
                 )
-                last_X_u = g.v2u(self.layers[_idx](last_X_v), aggr="sum")
+                with torch.no_grad():
+                    last_X_u = torch.tanh(g.v2u(self.layers[_idx](last_X_v), aggr="sum"))
             else:
                 self.train_one_layer(
                     last_X_v,
@@ -144,7 +146,8 @@ def train_with_cascaded(
                     drop_rate,
                     device,
                 )
-                last_X_v = g.u2v(self.layers[_idx](last_X_u), aggr="sum")
+                with torch.no_grad():
+                    last_X_v = torch.tanh(g.u2v(self.layers[_idx](last_X_u), aggr="sum"))
         return last_X_u

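The `with torch.no_grad():` blocks added above reflect the cascaded design: each layer is trained in isolation, so the features handed to the next stage should carry no autograd history from the previous one. A tiny illustration of the effect (toy tensors and a toy layer, not from the source):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
x = torch.rand(3, 4)

with torch.no_grad():
    nxt = torch.tanh(layer(x))    # computed only as input for the next stage

print(nxt.requires_grad)          # False: the next stage trains from a detached input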
@@ -155,10 +158,14 @@ class BGNN_MLP(nn.Module):
         ``u_dim`` (``int``): The dimension of the vertex feature in set :math:`U`.
         ``v_dim`` (``int``): The dimension of the vertex feature in set :math:`V`.
         ``hid_dim`` (``int``): The dimension of the hidden layer.
-        ``layer_depth`` (``int``): The depth of layers.
+        ``decoder_hid_dim`` (``int``): The dimension of the hidden layer in the decoder.
+        ``drop_rate`` (``float``): The dropout rate. Default: ``0.5``.
+        ``layer_depth`` (``int``): The depth of layers. Default: ``3``.
     """

-    def __init__(self, u_dim: int, v_dim: int, hid_dim: int, layer_depth: int = 3,) -> None:
+    def __init__(
+        self, u_dim: int, v_dim: int, hid_dim: int, decoder_hid_dim: int, drop_rate: float = 0.5, layer_depth: int = 3,
+    ) -> None:

         super().__init__()
         self.layer_depth = layer_depth
@@ -168,10 +175,10 @@ def __init__(self, u_dim: int, v_dim: int, hid_dim: int, layer_depth: int = 3,) -> None:
         for _idx in range(layer_depth):
             if _idx % 2 == 0:
                 self.layers.append(nn.Linear(v_dim, hid_dim))
-                self.decoders.append(nn.Linear(hid_dim, u_dim))
+                self.decoders.append(Decoder(hid_dim, decoder_hid_dim, u_dim, drop_rate=drop_rate))
             else:
                 self.layers.append(nn.Linear(u_dim, hid_dim))
-                self.decoders.append(nn.Linear(hid_dim, v_dim))
+                self.decoders.append(Decoder(hid_dim, decoder_hid_dim, v_dim, drop_rate=drop_rate))

     def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""The forward function.
@@ -185,10 +192,10 @@ def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
         for _idx in range(self.layer_depth):
             if _idx % 2 == 0:
                 _tmp = self.layers[_idx](last_X_v)
-                last_X_u = g.v2u(_tmp, aggr="sum")
+                last_X_u = self.decoders[_idx](torch.tanh(g.v2u(_tmp, aggr="sum")))
             else:
                 _tmp = self.layers[_idx](last_X_u)
-                last_X_v = g.u2v(_tmp, aggr="sum")
+                last_X_v = self.decoders[_idx](torch.tanh(g.u2v(_tmp, aggr="sum")))
         return last_X_u

     def train_one_layer(
@@ -208,12 +215,12 @@ def train_one_layer(

         optimizer = optim.Adam([*netG.parameters(), *netD.parameters()], lr=lr, weight_decay=weight_decay)

-        X_true, X_other = X_true.to(device), X_other.to(device)
+        X_true, X_other = X_true.detach().to(device), X_other.detach().to(device)

         netG.train(), netD.train()
         for _ in range(max_epoch):
             X_real = X_true
-            X_fake = netD(mp_func(netG(X_other)))
+            X_fake = netD(torch.tanh(mp_func(netG(X_other))))

             optimizer.zero_grad()
             loss = F.mse_loss(X_fake, X_real)
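Unlike BGNN_Adv, this variant fits the layer and its decoder jointly with a plain reconstruction loss: a single Adam instance over both parameter lists replaces the two optimizers and BCE objectives used above. A minimal sketch of the pattern with stand-in modules (names and sizes are mine, not the repository's):

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

netG, netD = nn.Linear(8, 8), nn.Linear(8, 8)          # stand-in layer and decoder
optimizer = optim.Adam([*netG.parameters(), *netD.parameters()], lr=1e-3, weight_decay=5e-4)

X_true, X_other = torch.rand(32, 8), torch.rand(32, 8)
for _ in range(10):
    X_fake = netD(torch.tanh(netG(X_other)))           # generate, bound, decode
    optimizer.zero_grad()
    loss = F.mse_loss(X_fake, X_true)                  # match the real features directly
    loss.backward()
    optimizer.step()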
@@ -241,42 +248,49 @@ def train_with_cascaded(
             ``max_epoch`` (``int``): The maximum number of epochs.
             ``device`` (``str``): The device to use. Default: ``"cpu"``.
         """
-        last_X_u, last_X_v = X_u, X_v
+        self = self.to(device)
+        last_X_u, last_X_v = X_u.to(device), X_v.to(device)
         for _idx in range(self.layer_depth):
             if _idx % 2 == 0:
                 self.train_one_layer(
                     last_X_u,
                     last_X_v,
                     lambda x: g.v2u(x, aggr="sum"),
                     self.layers[_idx],
+                    self.decoders[_idx],
                     lr,
                     weight_decay,
                     max_epoch,
                     device,
                 )
-                last_X_u = g.v2u(self.layers[_idx](last_X_v), aggr="sum")
+                with torch.no_grad():
+                    self.decoders[_idx].eval()
+                    last_X_u = self.decoders[_idx](torch.tanh(g.v2u(self.layers[_idx](last_X_v), aggr="sum")))
             else:
                 self.train_one_layer(
                     last_X_v,
                     last_X_u,
                     lambda x: g.u2v(x, aggr="sum"),
                     self.layers[_idx],
+                    self.decoders[_idx],
                     lr,
                     weight_decay,
                     max_epoch,
                     device,
                 )
-                last_X_v = g.u2v(self.layers[_idx](last_X_u), aggr="sum")
+                with torch.no_grad():
+                    self.decoders[_idx].eval()
+                    last_X_v = self.decoders[_idx](torch.tanh(g.u2v(self.layers[_idx](last_X_u), aggr="sum")))
         return last_X_u


 class Decoder(nn.Module):
     def __init__(self, in_channels: int, hid_channels: int, out_channels: int, drop_rate: float = 0.5):
-        super(Decoder, self).__init__()
+        super().__init__()
         self.layers = nn.Sequential(
             nn.Linear(in_channels, hid_channels),
             nn.ReLU(),
-            nn.Dropout(p=drop_rate, inplace=True),
+            nn.Dropout(p=drop_rate),
             nn.Linear(hid_channels, out_channels),
             nn.Tanh(),
         )
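Finally, a rough end-to-end usage sketch for the updated BGNN_MLP. Everything below is illustrative and not part of the commit: the BiGraph construction is an assumption about dhg's public API, the feature sizes are arbitrary, and the argument names follow the signatures visible in this diff; check the released package before relying on it.

import torch
from dhg import BiGraph
from dhg.models import BGNN_MLP

# toy bipartite graph with 40 U-vertices, 60 V-vertices, and a few edges (assumed constructor)
g = BiGraph(40, 60, [(0, 1), (2, 3), (5, 8)])
X_u, X_v = torch.rand(40, 16), torch.rand(60, 24)       # vertex features for U and V

model = BGNN_MLP(u_dim=16, v_dim=24, hid_dim=32, decoder_hid_dim=32, drop_rate=0.5, layer_depth=3)
model.train_with_cascaded(X_u, X_v, g, lr=0.01, weight_decay=5e-4, max_epoch=100, device="cpu")
emb_u = model(X_u, X_v, g)                               # embeddings for the U side after layer-wise training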