-from typing import Tuple
+from typing import Tuple, Callable

 import torch
 import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F

 from dhg.structure.graphs import BiGraph
+from dhg.nn.convs.common import Discriminator


 class BGNN_Adv(nn.Module):
@@ -15,7 +18,7 @@ class BGNN_Adv(nn.Module):
         ``layer_depth`` (``int``): The depth of layers.
     """

-    def __init__(self, u_dim: int, v_dim: int, layer_depth: int = 3,) -> None:
+    def __init__(self, u_dim: int, v_dim: int, layer_depth: int = 3) -> None:

         super().__init__()
         self.layer_depth = layer_depth
@@ -45,11 +48,103 @@ def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[tor
         last_X_v = g.u2v(_tmp, aggr="sum")
         return last_X_u

-    def train_with_cascaded(self):
-        pass
+    def train_one_layer(
+        self,
+        X_true: torch.Tensor,
+        X_other: torch.Tensor,
+        mp_func: Callable,
+        layer: nn.Module,
+        lr: float,
+        weight_decay: float,
+        max_epoch: int,
+        drop_rate: float = 0.5,
+        device: str = "cpu",
+    ):
+        netG = layer.to(device)
+        netD = Discriminator(X_true.shape[1], 16, 1, drop_rate=drop_rate).to(device)

-    def train_with_end2end(self):
-        pass
+        optimG = optim.Adam(netG.parameters(), lr=lr, weight_decay=weight_decay)
+        optimD = optim.Adam(netD.parameters(), lr=lr, weight_decay=weight_decay)
+
+        X_true, X_other = X_true.to(device), X_other.to(device)
+        lbl_real = torch.ones(X_true.shape[0]).to(device)
+        lbl_fake = torch.zeros(X_true.shape[0]).to(device)
+
+        for _ in range(max_epoch):
+            X_real = X_true
+            X_fake = mp_func(netG(X_other))
+
+            # step 1: train Discriminator
+            optimD.zero_grad()
+
+            pred_real = netD(X_real)
+            pred_fake = netD(X_fake.detach())
+
+            lossD = F.binary_cross_entropy(pred_real, lbl_real) + F.binary_cross_entropy(pred_fake, lbl_fake)
+            lossD.backward()
+            optimD.step()
+
+            # step 2: train Generator
+            optimG.zero_grad()
+
+            pred_fake = netD(X_fake)
+
+            lossG = F.binary_cross_entropy(pred_fake, lbl_real)
+            lossG.backward()
+            optimG.step()
+
+    def train_with_cascaded(
+        self,
+        X_u: torch.Tensor,
+        X_v: torch.Tensor,
+        g: BiGraph,
+        lr: float,
+        weight_decay: float,
+        max_epoch: int,
+        drop_rate: float = 0.5,
+        device: str = "cpu",
+    ):
107+ r"""Train the model with cascaded strategy.
108+
109+ Args:
110+ ``X_u`` (``torch.Tensor``): The feature matrix of vertices in set :math:`U`.
111+ ``X_v`` (``torch.Tensor``): The feature matrix of vertices in set :math:`V`.
112+ ``g`` (``BiGraph``): The bipartite graph.
113+ ``lr`` (``float``): The learning rate.
114+ ``weight_decay`` (``float``): The weight decay.
115+ ``max_epoch`` (``int``): The maximum number of epochs.
116+ ``drop_rate`` (``float``): The dropout rate. Default: ``0.5``.
117+ ``device`` (``str``): The device to use. Default: ``"cpu"``.
118+ """
+        last_X_u, last_X_v = X_u, X_v
+        for _idx in range(self.layer_depth):
+            if _idx % 2 == 0:
+                self.train_one_layer(
+                    last_X_u,
+                    last_X_v,
+                    lambda x: g.v2u(x, aggr="sum"),
+                    self.layers[_idx],
+                    lr,
+                    weight_decay,
+                    max_epoch,
+                    drop_rate,
+                    device,
+                )
+                last_X_u = g.v2u(self.layers[_idx](last_X_v), aggr="sum")
+            else:
+                self.train_one_layer(
+                    last_X_v,
+                    last_X_u,
+                    lambda x: g.u2v(x, aggr="sum"),
+                    self.layers[_idx],
+                    lr,
+                    weight_decay,
+                    max_epoch,
+                    drop_rate,
+                    device,
+                )
+                last_X_v = g.u2v(self.layers[_idx](last_X_u), aggr="sum")
+        return last_X_u


 class BGNN_MLP(nn.Module):
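For reference, here is a hypothetical usage sketch of the adversarial variant above (BGNN_Adv). The toy bipartite graph, feature dimensions, and hyperparameter values are illustrative assumptions and not part of this commit; the sketch assumes a BiGraph can be built from (num_u, num_v, edge list).

    # assumed toy data: 4 vertices in U with 8-dim features, 3 vertices in V with 6-dim features
    import torch
    from dhg.structure.graphs import BiGraph

    g = BiGraph(4, 3, [(0, 0), (1, 1), (2, 2), (3, 0)])
    X_u, X_v = torch.randn(4, 8), torch.randn(3, 6)

    model = BGNN_Adv(8, 6, layer_depth=3)
    # cascaded strategy: each layer is pre-trained adversarially against a fresh
    # Discriminator before its output is propagated to the next layer
    emb_u = model.train_with_cascaded(X_u, X_v, g, lr=1e-3, weight_decay=5e-4, max_epoch=100)
    # the trained layers can then be applied with a plain forward pass
    emb_u = model(X_u, X_v, g)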
@@ -58,20 +153,24 @@ class BGNN_MLP(nn.Module):
     Args:
         ``u_dim`` (``int``): The dimension of the vertex feature in set :math:`U`.
         ``v_dim`` (``int``): The dimension of the vertex feature in set :math:`V`.
+        ``hid_dim`` (``int``): The dimension of the hidden layer.
         ``layer_depth`` (``int``): The depth of layers.
     """

-    def __init__(self, u_dim: int, v_dim: int, layer_depth: int = 3,) -> None:
+    def __init__(self, u_dim: int, v_dim: int, hid_dim: int, layer_depth: int = 3,) -> None:

         super().__init__()
         self.layer_depth = layer_depth
         self.layers = nn.ModuleList()
+        self.decoders = nn.ModuleList()

         for _idx in range(layer_depth):
             if _idx % 2 == 0:
-                self.layers.append(nn.Linear(v_dim, u_dim))
+                self.layers.append(nn.Linear(v_dim, hid_dim))
+                self.decoders.append(nn.Linear(hid_dim, u_dim))
             else:
-                self.layers.append(nn.Linear(u_dim, v_dim))
+                self.layers.append(nn.Linear(u_dim, hid_dim))
+                self.decoders.append(nn.Linear(hid_dim, v_dim))

     def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""The forward function.
@@ -91,8 +190,95 @@ def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[tor
         last_X_v = g.u2v(_tmp, aggr="sum")
         return last_X_u

-    def train_with_cascaded(self):
-        pass
+    def train_one_layer(
+        self,
+        X_true: torch.Tensor,
+        X_other: torch.Tensor,
+        mp_func: Callable,
+        layer: nn.Module,
+        decoder: nn.Module,
+        lr: float,
+        weight_decay: float,
+        max_epoch: int,
+        device: str = "cpu",
+    ):
+        netG = layer.to(device)
+        netD = decoder.to(device)
+
+        optimizer = optim.Adam([*netG.parameters(), *netD.parameters()], lr=lr, weight_decay=weight_decay)
+
+        X_true, X_other = X_true.to(device), X_other.to(device)
+
+        for _ in range(max_epoch):
+            X_real = X_true
+            X_fake = netD(mp_func(netG(X_other)))
+
+            optimizer.zero_grad()
+            loss = F.mse_loss(X_fake, X_real)
+            loss.backward()
+            optimizer.step()
+
+    def train_with_cascaded(
+        self,
+        X_u: torch.Tensor,
+        X_v: torch.Tensor,
+        g: BiGraph,
+        lr: float,
+        weight_decay: float,
+        max_epoch: int,
+        device: str = "cpu",
+    ):
+        r"""Train the model with the cascaded strategy.
+
+        Args:
+            ``X_u`` (``torch.Tensor``): The feature matrix of vertices in set :math:`U`.
+            ``X_v`` (``torch.Tensor``): The feature matrix of vertices in set :math:`V`.
+            ``g`` (``BiGraph``): The bipartite graph.
+            ``lr`` (``float``): The learning rate.
+            ``weight_decay`` (``float``): The weight decay.
+            ``max_epoch`` (``int``): The maximum number of epochs.
+            ``device`` (``str``): The device to use. Default: ``"cpu"``.
+        """
+        last_X_u, last_X_v = X_u, X_v
+        for _idx in range(self.layer_depth):
+            if _idx % 2 == 0:
+                self.train_one_layer(
+                    last_X_u,
+                    last_X_v,
+                    lambda x: g.v2u(x, aggr="sum"),
+                    self.layers[_idx],
+                    self.decoders[_idx],
+                    lr,
+                    weight_decay,
+                    max_epoch,
+                    device,
+                )
+                last_X_u = g.v2u(self.layers[_idx](last_X_v), aggr="sum")
+            else:
+                self.train_one_layer(
+                    last_X_v,
+                    last_X_u,
+                    lambda x: g.u2v(x, aggr="sum"),
+                    self.layers[_idx],
+                    self.decoders[_idx],
+                    lr,
+                    weight_decay,
+                    max_epoch,
+                    device,
+                )
+                last_X_v = g.u2v(self.layers[_idx](last_X_u), aggr="sum")
+        return last_X_u
+
+
+class Decoder(nn.Module):
+    def __init__(self, in_channels: int, hid_channels: int, out_channels: int, drop_rate: float = 0.5):
+        super(Decoder, self).__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(in_channels, hid_channels),
+            nn.ReLU(),
+            nn.Dropout(p=drop_rate, inplace=True),
+            nn.Linear(hid_channels, out_channels),
+            nn.Tanh(),
+        )

-    def train_with_end2end(self):
-        pass
+    def forward(self, X):
+        X = self.layers(X)
+        return X
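As a companion to the adversarial example, a hypothetical sketch of driving the MLP variant (BGNN_MLP) added above. The graph, shapes, and hyperparameters are again assumptions; equal feature and hidden sizes are used so that stacked layers line up dimensionally, since the propagation steps shown in the diff pass hid_dim-sized outputs on to the next layer.

    # assumed toy data: feature and hidden dimensions are all 16 for simplicity
    import torch
    from dhg.structure.graphs import BiGraph

    g = BiGraph(4, 3, [(0, 0), (1, 1), (2, 2), (3, 0)])
    X_u, X_v = torch.randn(4, 16), torch.randn(3, 16)

    model = BGNN_MLP(16, 16, hid_dim=16, layer_depth=3)
    # cascaded strategy: each Linear "encoder" layer is fitted jointly with its
    # per-layer decoder via the MSE reconstruction loss, one layer after another
    emb_u = model.train_with_cascaded(X_u, X_v, g, lr=1e-3, weight_decay=5e-4, max_epoch=100)
    emb_u = model(X_u, X_v, g)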