
Commit 1fcc17e

add BGNN-Adv and BGNN-MLP methods on bipartite graph
1 parent b9cfc3d commit 1fcc17e

6 files changed: +234 -15 lines changed


dhg/models/graphs/bgnn.py

Lines changed: 199 additions & 13 deletions
@@ -1,9 +1,12 @@
-from typing import Tuple
+from typing import Tuple, Callable
 
 import torch
 import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
 
 from dhg.structure.graphs import BiGraph
+from dhg.nn.convs.common import Discriminator
 
 
 class BGNN_Adv(nn.Module):
@@ -15,7 +18,7 @@ class BGNN_Adv(nn.Module):
         ``layer_depth`` (``int``): The depth of layers.
     """
 
-    def __init__(self, u_dim: int, v_dim: int, layer_depth: int = 3,) -> None:
+    def __init__(self, u_dim: int, v_dim: int, layer_depth: int = 3) -> None:
 
         super().__init__()
         self.layer_depth = layer_depth
@@ -45,11 +48,103 @@ def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[tor
                 last_X_v = g.u2v(_tmp, aggr="sum")
         return last_X_u
 
-    def train_with_cascaded(self):
-        pass
+    def train_one_layer(
+        self,
+        X_true: torch.Tensor,
+        X_other: torch.Tensor,
+        mp_func: Callable,
+        layer: nn.Module,
+        lr: float,
+        weight_decay: float,
+        max_epoch: int,
+        drop_rate: float = 0.5,
+        device: str = "cpu",
+    ):
+        netG = layer.to(device)
+        netD = Discriminator(X_true.shape[1], 16, 1, drop_rate=drop_rate).to(device)
 
-    def train_with_end2end(self):
-        pass
+        optimG = optim.Adam(netG.parameters(), lr=lr, weight_decay=weight_decay)
+        optimD = optim.Adam(netD.parameters(), lr=lr, weight_decay=weight_decay)
+
+        X_true, X_other = X_true.to(device), X_other.to(device)
+        lbl_real = torch.ones(X_true.shape[0]).to(device)
+        lbl_fake = torch.zeros(X_true.shape[0]).to(device)
+
+        for _ in range(max_epoch):
+            X_real = X_true
+            X_fake = mp_func(netG(X_other))
+
+            # step 1: train Discriminator
+            optimD.zero_grad()
+
+            pred_real = netD(X_real)
+            pred_fake = netD(X_fake.detach())
+
+            lossD = F.binary_cross_entropy(pred_real, lbl_real) + F.binary_cross_entropy(pred_fake, lbl_fake)
+            lossD.backward()
+            optimD.step()
+
+            # step 2: train Generator
+            optimG.zero_grad()
+
+            pred_fake = netD(X_fake)
+
+            lossG = F.binary_cross_entropy(pred_fake, lbl_real)
+            lossG.backward()
+            optimG.step()
+
+    def train_with_cascaded(
+        self,
+        X_u: torch.Tensor,
+        X_v: torch.Tensor,
+        g: BiGraph,
+        lr: float,
+        weight_decay: float,
+        max_epoch: int,
+        drop_rate: float = 0.5,
+        device: str = "cpu",
+    ):
+        r"""Train the model with cascaded strategy.
+
+        Args:
+            ``X_u`` (``torch.Tensor``): The feature matrix of vertices in set :math:`U`.
+            ``X_v`` (``torch.Tensor``): The feature matrix of vertices in set :math:`V`.
+            ``g`` (``BiGraph``): The bipartite graph.
+            ``lr`` (``float``): The learning rate.
+            ``weight_decay`` (``float``): The weight decay.
+            ``max_epoch`` (``int``): The maximum number of epochs.
+            ``drop_rate`` (``float``): The dropout rate. Default: ``0.5``.
+            ``device`` (``str``): The device to use. Default: ``"cpu"``.
+        """
+        last_X_u, last_X_v = X_u, X_v
+        for _idx in range(self.layer_depth):
+            if _idx % 2 == 0:
+                self.train_one_layer(
+                    last_X_u,
+                    last_X_v,
+                    lambda x: g.v2u(x, aggr="sum"),
+                    self.layers[_idx],
+                    lr,
+                    weight_decay,
+                    max_epoch,
+                    drop_rate,
+                    device,
+                )
+                last_X_u = g.v2u(self.layers[_idx](last_X_v), aggr="sum")
+            else:
+                self.train_one_layer(
+                    last_X_v,
+                    last_X_u,
+                    lambda x: g.u2v(x, aggr="sum"),
+                    self.layers[_idx],
+                    lr,
+                    weight_decay,
+                    max_epoch,
+                    drop_rate,
+                    device,
+                )
+                last_X_v = g.u2v(self.layers[_idx](last_X_u), aggr="sum")
+        return last_X_u
 
 
 class BGNN_MLP(nn.Module):
@@ -58,20 +153,24 @@ class BGNN_MLP(nn.Module):
     Args:
         ``u_dim`` (``int``): The dimension of the vertex feature in set :math:`U`.
         ``v_dim`` (``int``): The dimension of the vertex feature in set :math:`V`.
+        ``hid_dim`` (``int``): The dimension of the hidden layer.
         ``layer_depth`` (``int``): The depth of layers.
     """
 
-    def __init__(self, u_dim: int, v_dim: int, layer_depth: int = 3,) -> None:
+    def __init__(self, u_dim: int, v_dim: int, hid_dim: int, layer_depth: int = 3,) -> None:
 
         super().__init__()
         self.layer_depth = layer_depth
         self.layers = nn.ModuleList()
+        self.decoders = nn.ModuleList()
 
         for _idx in range(layer_depth):
             if _idx % 2 == 0:
-                self.layers.append(nn.Linear(v_dim, u_dim))
+                self.layers.append(nn.Linear(v_dim, hid_dim))
+                self.decoders.append(nn.Linear(hid_dim, u_dim))
             else:
-                self.layers.append(nn.Linear(u_dim, v_dim))
+                self.layers.append(nn.Linear(u_dim, hid_dim))
+                self.decoders.append(nn.Linear(hid_dim, v_dim))
 
     def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""The forward function.
@@ -91,8 +190,95 @@ def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[tor
                 last_X_v = g.u2v(_tmp, aggr="sum")
         return last_X_u
 
-    def train_with_cascaded(self):
-        pass
+    def train_one_layer(
+        self,
+        X_true: torch.Tensor,
+        X_other: torch.Tensor,
+        mp_func: Callable,
+        layer: nn.Module,
+        decoder: nn.Module,
+        lr: float,
+        weight_decay: float,
+        max_epoch: int,
+        device: str = "cpu",
+    ):
+        netG = layer.to(device)
+        netD = decoder.to(device)
+
+        optimizer = optim.Adam([*netG.parameters(), *netD.parameters()], lr=lr, weight_decay=weight_decay)
+
+        X_true, X_other = X_true.to(device), X_other.to(device)
+
+        for _ in range(max_epoch):
+            X_real = X_true
+            X_fake = netD(mp_func(netG(X_other)))
+
+            optimizer.zero_grad()
+            loss = F.mse_loss(X_fake, X_real)
+            loss.backward()
+            optimizer.step()
+
+    def train_with_cascaded(
+        self,
+        X_u: torch.Tensor,
+        X_v: torch.Tensor,
+        g: BiGraph,
+        lr: float,
+        weight_decay: float,
+        max_epoch: int,
+        device: str = "cpu",
+    ):
+        r"""Train the model with cascaded strategy.
+
+        Args:
+            ``X_u`` (``torch.Tensor``): The feature matrix of vertices in set :math:`U`.
+            ``X_v`` (``torch.Tensor``): The feature matrix of vertices in set :math:`V`.
+            ``g`` (``BiGraph``): The bipartite graph.
+            ``lr`` (``float``): The learning rate.
+            ``weight_decay`` (``float``): The weight decay.
+            ``max_epoch`` (``int``): The maximum number of epochs.
+            ``device`` (``str``): The device to use. Default: ``"cpu"``.
+        """
+        last_X_u, last_X_v = X_u, X_v
+        for _idx in range(self.layer_depth):
+            if _idx % 2 == 0:
+                self.train_one_layer(
+                    last_X_u,
+                    last_X_v,
+                    lambda x: g.v2u(x, aggr="sum"),
+                    self.layers[_idx],
+                    self.decoders[_idx],
+                    lr,
+                    weight_decay,
+                    max_epoch,
+                    device,
+                )
+                last_X_u = g.v2u(self.layers[_idx](last_X_v), aggr="sum")
+            else:
+                self.train_one_layer(
+                    last_X_v,
+                    last_X_u,
+                    lambda x: g.u2v(x, aggr="sum"),
+                    self.layers[_idx],
+                    self.decoders[_idx],
+                    lr,
+                    weight_decay,
+                    max_epoch,
+                    device,
+                )
+                last_X_v = g.u2v(self.layers[_idx](last_X_u), aggr="sum")
+        return last_X_u
+
+
+class Decoder(nn.Module):
+    def __init__(self, in_channels: int, hid_channels: int, out_channels: int, drop_rate: float = 0.5):
+        super(Decoder, self).__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(in_channels, hid_channels),
+            nn.ReLU(),
+            nn.Dropout(p=drop_rate, inplace=True),
+            nn.Linear(hid_channels, out_channels),
+            nn.Tanh(),
+        )
 
-    def train_with_end2end(self):
-        pass
+    def forward(self, X):
+        X = self.layers(X)
+        return X
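For orientation, a minimal sketch of how the two new cascaded-training entry points might be called. Only the BGNN_Adv / BGNN_MLP constructors and train_with_cascaded signatures come from the diff above; the dhg.models import path, the BiGraph constructor arguments, and the toy feature sizes are illustrative assumptions, not part of this commit.

import torch
from dhg.structure.graphs import BiGraph
from dhg.models import BGNN_Adv, BGNN_MLP  # assumed export path for the new models

# Toy bipartite graph; the (num_u, num_v, edge list) constructor is an assumption.
g = BiGraph(4, 3, [(0, 0), (1, 1), (2, 2), (3, 0)])
X_u = torch.rand(4, 16)  # features of vertices in set U (illustrative shapes)
X_v = torch.rand(3, 16)  # features of vertices in set V

# Adversarial variant: each layer is trained GAN-style against a Discriminator.
adv = BGNN_Adv(u_dim=16, v_dim=16, layer_depth=3)
emb_u = adv.train_with_cascaded(X_u, X_v, g, lr=0.01, weight_decay=5e-4, max_epoch=50)

# MLP variant: each layer is trained with a decoder and an MSE reconstruction loss.
mlp = BGNN_MLP(u_dim=16, v_dim=16, hid_dim=8, layer_depth=3)
emb_u = mlp.train_with_cascaded(X_u, X_v, g, lr=0.01, weight_decay=5e-4, max_epoch=50)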

dhg/nn/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -1,4 +1,4 @@
-from .convs.common import MLP, MultiHeadWrapper
+from .convs.common import MLP, MultiHeadWrapper, Discriminator
 from .convs.graphs import GCNConv, GATConv, GraphSAGEConv, GINConv
 from .convs.hypergraphs import (
     HGNNConv,
@@ -17,6 +17,7 @@
 __all__ = [
     "MLP",
     "MultiHeadWrapper",
+    "Discriminator",
     "GCNConv",
     "GATConv",
     "GraphSAGEConv",

dhg/nn/convs/common.py

Lines changed: 30 additions & 0 deletions
@@ -106,3 +106,33 @@ def forward(self, **kwargs) -> torch.Tensor:
         else:
             raise ValueError("Unknown readout type")
 
+
+class Discriminator(nn.Module):
+    r"""The Discriminator for Generative Adversarial Networks (GANs).
+
+    Args:
+        ``in_channels`` (``int``): The number of input channels.
+        ``hid_channels`` (``int``): The number of hidden channels.
+        ``out_channels`` (``int``): The number of output channels.
+        ``drop_rate`` (``float``): Dropout ratio. Defaults to ``0.5``.
+    """
+
+    def __init__(self, in_channels: int, hid_channels: int, out_channels: int, drop_rate: float = 0.5):
+
+        super(Discriminator, self).__init__()
+        self.layers = nn.Sequential(
+            nn.Linear(in_channels, hid_channels),
+            nn.LeakyReLU(),
+            nn.Dropout(p=drop_rate),
+            nn.Linear(hid_channels, out_channels),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, X):
+        """The forward function.
+
+        Args:
+            ``X`` (``torch.Tensor``): The input tensor.
+        """
+        X = self.layers(X)
+        return X
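A quick sketch of exercising the new Discriminator on its own. The import from dhg.nn matches the __init__.py re-export above; the tensor shapes are arbitrary illustrative values, not from this commit.

import torch
import torch.nn.functional as F
from dhg.nn import Discriminator  # re-exported by dhg/nn/__init__.py in this commit

netD = Discriminator(in_channels=32, hid_channels=16, out_channels=1, drop_rate=0.5)
X = torch.rand(8, 32)                                    # 8 samples, 32-dim features (illustrative)
pred = netD(X)                                           # sigmoid scores in (0, 1), shape (8, 1)
loss = F.binary_cross_entropy(pred, torch.ones(8, 1))    # treat all samples as "real"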

docs/source/_templates/model_template.rst

Lines changed: 1 addition & 1 deletion
@@ -7,4 +7,4 @@
 
 .. autoclass:: {{ name }}
     :show-inheritance:
-    :members: forward
+    :members: forward, train_with_cascaded

docs/source/api/models.rst

Lines changed: 1 addition & 0 deletions
@@ -17,6 +17,7 @@ Models on Graph
 
 Models on Bipartite Graph
 -----------------------------
+
 .. autosummary::
     :toctree: ../generated/
     :nosignatures:

docs/source/api/nn.rst

Lines changed: 1 addition & 0 deletions

@@ -13,6 +13,7 @@ Common Layers
 
     dhg.nn.MLP
     dhg.nn.MultiHeadWrapper
+    dhg.nn.Discriminator
 
 
 Layers on Graph