Commit d83ab94

update bgnn-adv and bgnn-mlp
1 parent 1f3e267 commit d83ab94

File tree: 9 files changed, +190 -40 lines


dhg/data/citeseer.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ class CiteseerBiGraph(BaseData):
     - ``num_v_vertices``: The number of vertices in set :math:`V` : :math:`742`.
     - ``num_edges``: The number of edges: :math:`1,665`.
     - ``dim_u_features``: The dimension of features in set :math:`U` : :math:`3,703`.
-    - ``dim_v_features``: The dimension of features: :math:`3,703`.
+    - ``dim_v_features``: The dimension of features in set :math:`V` : :math:`3,703`.
     - ``u_features``: The vertex feature matrix in set :math:`U`. ``torch.Tensor`` with size :math:`(1,237 \times 3,703)`.
     - ``v_features``: The vertex feature matrix in set :math:`V` . ``torch.Tensor`` with size :math:`(742 \times 3,703)`.
     - ``edge_list``: The edge list. ``List`` with length :math:`(1,665 \times 2)`.
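The same fix, spelling out that the dimension refers to set :math:`V`, is applied to the three sibling datasets below. For context, a quick sketch of how these documented attributes are typically read; the constructor and item-style access are assumed here from dhg's ``BaseData`` convention and are not part of this diff:

    # Usage sketch; item-style access assumed from dhg's BaseData convention.
    from dhg.data import CiteseerBiGraph

    data = CiteseerBiGraph()
    X_u = data["u_features"]   # torch.Tensor of size (1,237 x 3,703)
    X_v = data["v_features"]   # torch.Tensor of size (742 x 3,703)
    edges = data["edge_list"]  # list of 1,665 (u, v) index pairs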

dhg/data/cora.py

Lines changed: 1 addition & 1 deletion
@@ -79,7 +79,7 @@ class CoraBiGraph(BaseData):
     - ``num_v_vertices``: The number of vertices in set :math:`V` : :math:`789`.
     - ``num_edges``: The number of edges: :math:`2,314`.
     - ``dim_u_features``: The dimension of features in set :math:`U` : :math:`1,433`.
-    - ``dim_v_features``: The dimension of features: :math:`1,433`.
+    - ``dim_v_features``: The dimension of features in set :math:`V` : :math:`1,433`.
     - ``u_features``: The vertex feature matrix in set :math:`U`. ``torch.Tensor`` with size :math:`(1,312 \times 1,433)`.
     - ``v_features``: The vertex feature matrix in set :math:`V` . ``torch.Tensor`` with size :math:`(789 \times 1,433)`.
     - ``edge_list``: The edge list. ``List`` with length :math:`(2,314 \times 2)`.

dhg/data/pubmed.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ class PubmedBiGraph(BaseData):
     - ``num_v_vertices``: The number of vertices in set :math:`V` : :math:`3,435`.
     - ``num_edges``: The number of edges: :math:`18,782`.
     - ``dim_u_features``: The dimension of features in set :math:`U` : :math:`400`.
-    - ``dim_v_features``: The dimension of features: :math:`500`.
+    - ``dim_v_features``: The dimension of features in set :math:`V` : :math:`500`.
     - ``u_features``: The vertex feature matrix in set :math:`U`. ``torch.Tensor`` with size :math:`(13,424 \times 400)`.
     - ``v_features``: The vertex feature matrix in set :math:`V` . ``torch.Tensor`` with size :math:`(3,435 \times 500)`.
     - ``edge_list``: The edge list. ``List`` with length :math:`(2,314 \times 2)`.

dhg/data/tencent.py

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ class TencentBiGraph(BaseData):
     - ``num_v_vertices``: The number of vertices in set :math:`V` : :math:`90,044`.
     - ``num_edges``: The number of edges: :math:`144,501`.
     - ``dim_u_features``: The dimension of features in set :math:`U` : :math:`8`.
-    - ``dim_v_features``: The dimension of features: :math:`16`.
+    - ``dim_v_features``: The dimension of features in set :math:`V` : :math:`16`.
     - ``u_features``: The vertex feature matrix in set :math:`U`. ``torch.Tensor`` with size :math:`(619,030 \times 8)`.
     - ``v_features``: The vertex feature matrix in set :math:`V` . ``torch.Tensor`` with size :math:`(90,044 \times 16)`.
     - ``edge_list``: The edge list. ``List`` with length :math:`(991,713 \times 2)`.

dhg/datapipe/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -5,11 +5,12 @@
     to_long_tensor,
 )
 from .loader import load_from_pickle, load_from_json, load_from_txt
-from .normalize import norm_ft
+from .normalize import norm_ft, min_max_scaler


 __all__ = [
     "compose_pipes",
     "norm_ft",
+    "min_max_scaler",
     "to_tensor",
     "to_bool_tensor",
     "to_long_tensor",

dhg/datapipe/normalize.py

Lines changed: 31 additions & 3 deletions
@@ -2,9 +2,7 @@
 import torch


-def norm_ft(
-    X: torch.Tensor, ord: Optional[Union[int, float]] = None
-) -> torch.Tensor:
+def norm_ft(X: torch.Tensor, ord: Optional[Union[int, float]] = None) -> torch.Tensor:
     r"""Normalize the input feature matrix with specified ``ord`` refer to pytorch's `torch.linalg.norm <https://pytorch.org/docs/stable/generated/torch.linalg.norm.html#torch.linalg.norm>`_ function.

     .. note::
@@ -40,3 +38,33 @@ def norm_ft(
         "The input feature matrix is expected to be a 1D vector or a 2D tensor with shape (num_samples, num_features)."
     )

+
+def min_max_scaler(X: torch.Tensor, ft_min: float, ft_max: float) -> torch.Tensor:
+    r"""Normalize the input feature matrix with min-max scaling.
+
+    Args:
+        ``X`` (``torch.Tensor``): The input feature.
+        ``ft_min`` (``float``): The minimum value of the output feature.
+        ``ft_max`` (``float``): The maximum value of the output feature.
+
+    Examples:
+        >>> import dhg.datapipe as dd
+        >>> import torch
+        >>> X = torch.tensor([
+                [0.1, 0.2, 0.5],
+                [0.5, 0.2, 0.3],
+                [0.3, 0.2, 0.0]
+            ])
+        >>> dd.min_max_scaler(X, -1, 1)
+        tensor([[-0.6000, -0.2000,  1.0000],
+                [ 1.0000, -0.2000,  0.2000],
+                [ 0.2000, -0.2000, -1.0000]])
+    """
+    assert ft_min < ft_max, "The minimum value of the feature should be less than the maximum value."
+    X_min, X_max = X.min().item(), X.max().item()
+    X_range = X_max - X_min
+    scale_ = (ft_max - ft_min) / X_range
+    min_ = ft_min - X_min * scale_
+    X = X * scale_ + min_
+    return X
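Note that the new ``min_max_scaler`` rescales against the global minimum and maximum of the whole tensor, not per column. A standalone sketch of the same transform in plain PyTorch, algebraically identical to the ``scale_`` / ``min_`` form above and reproducing the doctest:

    import torch

    def min_max_scale(X: torch.Tensor, ft_min: float, ft_max: float) -> torch.Tensor:
        # Linearly map [X.min(), X.max()] onto [ft_min, ft_max].
        x_min, x_max = X.min(), X.max()
        return (X - x_min) / (x_max - x_min) * (ft_max - ft_min) + ft_min

    X = torch.tensor([[0.1, 0.2, 0.5],
                      [0.5, 0.2, 0.3],
                      [0.3, 0.2, 0.0]])
    print(min_max_scale(X, -1.0, 1.0))
    # tensor([[-0.6000, -0.2000,  1.0000],
    #         [ 1.0000, -0.2000,  0.2000],
    #         [ 0.2000, -0.2000, -1.0000]])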

dhg/models/graphs/bgnn.py

Lines changed: 46 additions & 32 deletions
@@ -42,10 +42,10 @@ def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[tor
         for _idx in range(self.layer_depth):
             if _idx % 2 == 0:
                 _tmp = self.layers[_idx](last_X_v)
-                last_X_u = g.v2u(_tmp, aggr="sum")
+                last_X_u = torch.tanh(g.v2u(_tmp, aggr="sum"))
             else:
                 _tmp = self.layers[_idx](last_X_u)
-                last_X_v = g.u2v(_tmp, aggr="sum")
+                last_X_v = torch.tanh(g.u2v(_tmp, aggr="sum"))
         return last_X_u

     def train_one_layer(
@@ -63,36 +63,36 @@ def train_one_layer(
         netG = layer.to(device)
         netD = Discriminator(X_true.shape[1], 16, 1, drop_rate=drop_rate).to(device)

-        optimG = optim.Adam(netG.parameters(), lr=lr, weight_decay=weight_decay)
-        optimD = optim.Adam(netD.parameters(), lr=lr, weight_decay=weight_decay)
+        optimizer_G = optim.Adam(netG.parameters(), lr=lr, weight_decay=weight_decay)
+        optimizer_D = optim.Adam(netD.parameters(), lr=lr, weight_decay=weight_decay)

-        X_true, X_other = X_true.to(device), X_other.to(device)
-        lbl_real = torch.ones(X_true.shape[0]).to(device)
-        lbl_fake = torch.zeros(X_true.shape[0]).to(device)
+        X_true, X_other = X_true.detach().to(device), X_other.detach().to(device)
+        lbl_real = torch.ones(X_true.shape[0], 1, requires_grad=False).to(device)
+        lbl_fake = torch.zeros(X_true.shape[0], 1, requires_grad=False).to(device)

         netG.train(), netD.train()
         for _ in range(max_epoch):
             X_real = X_true
-            X_fake = mp_func(netG(X_other))
+            X_fake = torch.tanh(mp_func(netG(X_other)))

             # step 1: train Discriminator
-            optimD.zero_grad()
+            optimizer_D.zero_grad()

             pred_real = netD(X_real)
             pred_fake = netD(X_fake.detach())

-            lossD = F.binary_cross_entropy(pred_real, lbl_real) + F.binary_cross_entropy(pred_fake, lbl_fake)
-            lossD.backward()
-            optimD.step()
+            loss_D = F.binary_cross_entropy(pred_real, lbl_real) + F.binary_cross_entropy(pred_fake, lbl_fake)
+            loss_D.backward()
+            optimizer_D.step()

             # step 2: train Generator
-            optimG.zero_grad()
+            optimizer_G.zero_grad()

             pred_fake = netD(X_fake)

-            lossG = F.binary_cross_entropy(pred_fake, lbl_real)
-            lossG.backward()
-            optimG.step()
+            loss_G = F.binary_cross_entropy(pred_fake, lbl_real)
+            loss_G.backward()
+            optimizer_G.step()

     def train_with_cascaded(
         self,
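For readers following the renamed loop: each BGNN-Adv layer is trained GAN-style, updating the discriminator on real versus detached generated features, then updating the generator to fool it. A self-contained sketch of that two-step pattern with toy stand-ins (the repo's ``Discriminator`` class is not shown in this diff, so a tiny sigmoid head substitutes for it):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    # Toy stand-ins: a linear "generator" layer and a minimal discriminator
    # ending in sigmoid so binary_cross_entropy receives probabilities.
    netG = nn.Linear(8, 16)
    netD = nn.Sequential(nn.Linear(16, 1), nn.Sigmoid())
    optimizer_G = optim.Adam(netG.parameters(), lr=1e-3)
    optimizer_D = optim.Adam(netD.parameters(), lr=1e-3)

    X_true, X_other = torch.randn(32, 16), torch.randn(32, 8)
    lbl_real, lbl_fake = torch.ones(32, 1), torch.zeros(32, 1)

    for _ in range(100):
        X_fake = torch.tanh(netG(X_other))  # bounded, as in the patched code

        # Step 1: update D on real vs. detached fake features.
        optimizer_D.zero_grad()
        loss_D = (F.binary_cross_entropy(netD(X_true), lbl_real)
                  + F.binary_cross_entropy(netD(X_fake.detach()), lbl_fake))
        loss_D.backward()
        optimizer_D.step()

        # Step 2: update G so that D labels its output as real.
        optimizer_G.zero_grad()
        loss_G = F.binary_cross_entropy(netD(X_fake), lbl_real)
        loss_G.backward()
        optimizer_G.step()

This also shows why the patch adds the extra dimension to the labels: the discriminator returns shape ``(N, 1)``, and ``binary_cross_entropy`` requires the target shape to match.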
@@ -117,7 +117,8 @@ def train_with_cascaded(
             ``drop_rate`` (``float``): The dropout rate. Default: ``0.5``.
             ``device`` (``str``): The device to use. Default: ``"cpu"``.
         """
-        last_X_u, last_X_v = X_u, X_v
+        self = self.to(device)
+        last_X_u, last_X_v = X_u.to(device), X_v.to(device)
         for _idx in range(self.layer_depth):
             if _idx % 2 == 0:
                 self.train_one_layer(
@@ -131,7 +132,8 @@ def train_with_cascaded(
                     drop_rate,
                     device,
                 )
-                last_X_u = g.v2u(self.layers[_idx](last_X_v), aggr="sum")
+                with torch.no_grad():
+                    last_X_u = torch.tanh(g.v2u(self.layers[_idx](last_X_v), aggr="sum"))
             else:
                 self.train_one_layer(
                     last_X_v,
@@ -144,7 +146,8 @@ def train_with_cascaded(
                     drop_rate,
                     device,
                 )
-                last_X_v = g.u2v(self.layers[_idx](last_X_u), aggr="sum")
+                with torch.no_grad():
+                    last_X_v = torch.tanh(g.u2v(self.layers[_idx](last_X_u), aggr="sum"))
         return last_X_u

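The pattern here is layer-wise (cascaded) training: fit one layer, then push features through it under ``torch.no_grad()`` so the next layer trains on fixed inputs and no autograd graph is kept across stages. A minimal self-contained sketch of that loop; ``train_one_layer`` and ``propagate`` are placeholders for this illustration, not the repo's functions:

    import torch
    import torch.nn as nn

    layers = nn.ModuleList([nn.Linear(16, 16) for _ in range(2)])

    def propagate(x: torch.Tensor) -> torch.Tensor:
        # Stand-in for the g.v2u / g.u2v message passing in the repo.
        return x

    def train_one_layer(layer: nn.Module, X: torch.Tensor) -> None:
        # Placeholder for the per-layer training routine shown above.
        pass

    last_X = torch.randn(4, 16)
    for layer in layers:
        train_one_layer(layer, last_X)
        # Propagate under no_grad: the next stage sees detached inputs.
        with torch.no_grad():
            last_X = torch.tanh(propagate(layer(last_X)))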

@@ -155,10 +158,14 @@ class BGNN_MLP(nn.Module):
         ``u_dim`` (``int``): The dimension of the vertex feature in set :math:`U`.
         ``v_dim`` (``int``): The dimension of the vertex feature in set :math:`V`.
         ``hid_dim`` (``int``): The dimension of the hidden layer.
-        ``layer_depth`` (``int``): The depth of layers.
+        ``decoder_hid_dim`` (``int``): The dimension of the hidden layer in the decoder.
+        ``drop_rate`` (``float``): The dropout rate. Default: ``0.5``.
+        ``layer_depth`` (``int``): The depth of layers. Default: ``3``.
     """

-    def __init__(self, u_dim: int, v_dim: int, hid_dim: int, layer_depth: int = 3,) -> None:
+    def __init__(
+        self, u_dim: int, v_dim: int, hid_dim: int, decoder_hid_dim: int, drop_rate: float = 0.5, layer_depth: int = 3,
+    ) -> None:

         super().__init__()
         self.layer_depth = layer_depth
@@ -168,10 +175,10 @@ def __init__(self, u_dim: int, v_dim: int, hid_dim: int, layer_depth: int = 3,)
         for _idx in range(layer_depth):
             if _idx % 2 == 0:
                 self.layers.append(nn.Linear(v_dim, hid_dim))
-                self.decoders.append(nn.Linear(hid_dim, u_dim))
+                self.decoders.append(Decoder(hid_dim, decoder_hid_dim, u_dim, drop_rate=drop_rate))
             else:
                 self.layers.append(nn.Linear(u_dim, hid_dim))
-                self.decoders.append(nn.Linear(hid_dim, v_dim))
+                self.decoders.append(Decoder(hid_dim, decoder_hid_dim, v_dim, drop_rate=drop_rate))

     def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
         r"""The forward function.
@@ -185,10 +192,10 @@ def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[tor
         for _idx in range(self.layer_depth):
             if _idx % 2 == 0:
                 _tmp = self.layers[_idx](last_X_v)
-                last_X_u = g.v2u(_tmp, aggr="sum")
+                last_X_u = self.decoders[_idx](torch.tanh(g.v2u(_tmp, aggr="sum")))
             else:
                 _tmp = self.layers[_idx](last_X_u)
-                last_X_v = g.u2v(_tmp, aggr="sum")
+                last_X_v = self.decoders[_idx](torch.tanh(g.u2v(_tmp, aggr="sum")))
         return last_X_u

     def train_one_layer(
@@ -208,12 +215,12 @@ def train_one_layer(

         optimizer = optim.Adam([*netG.parameters(), *netD.parameters()], lr=lr, weight_decay=weight_decay)

-        X_true, X_other = X_true.to(device), X_other.to(device)
+        X_true, X_other = X_true.detach().to(device), X_other.detach().to(device)

         netG.train(), netD.train()
         for _ in range(max_epoch):
             X_real = X_true
-            X_fake = netD(mp_func(netG(X_other)))
+            X_fake = netD(torch.tanh(mp_func(netG(X_other))))

             optimizer.zero_grad()
             loss = F.mse_loss(X_fake, X_real)
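Unlike the adversarial variant, this trainer fits the generator and decoder jointly against a plain reconstruction loss, and the generated features now pass through ``torch.tanh`` before decoding, matching ``forward``. A toy sketch of the objective, with message passing omitted and shapes arbitrary:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.optim as optim

    # Generator netG and decoder netD optimized together: one optimizer,
    # MSE between reconstructed and real features, no adversary.
    netG, netD = nn.Linear(8, 16), nn.Linear(16, 16)
    optimizer = optim.Adam([*netG.parameters(), *netD.parameters()], lr=1e-3)

    X_true, X_other = torch.randn(32, 16), torch.randn(32, 8)
    for _ in range(100):
        X_fake = netD(torch.tanh(netG(X_other)))  # mp_func omitted in this toy
        optimizer.zero_grad()
        F.mse_loss(X_fake, X_true).backward()
        optimizer.step()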
@@ -241,42 +248,49 @@ def train_with_cascaded(
             ``max_epoch`` (``int``): The maximum number of epochs.
             ``device`` (``str``): The device to use. Default: ``"cpu"``.
         """
-        last_X_u, last_X_v = X_u, X_v
+        self = self.to(device)
+        last_X_u, last_X_v = X_u.to(device), X_v.to(device)
         for _idx in range(self.layer_depth):
             if _idx % 2 == 0:
                 self.train_one_layer(
                     last_X_u,
                     last_X_v,
                     lambda x: g.v2u(x, aggr="sum"),
                     self.layers[_idx],
+                    self.decoders[_idx],
                     lr,
                     weight_decay,
                     max_epoch,
                     device,
                 )
-                last_X_u = g.v2u(self.layers[_idx](last_X_v), aggr="sum")
+                with torch.no_grad():
+                    self.decoders[_idx].eval()
+                    last_X_u = self.decoders[_idx](torch.tanh(g.v2u(self.layers[_idx](last_X_v), aggr="sum")))
             else:
                 self.train_one_layer(
                     last_X_v,
                     last_X_u,
                     lambda x: g.u2v(x, aggr="sum"),
                     self.layers[_idx],
+                    self.decoders[_idx],
                     lr,
                     weight_decay,
                     max_epoch,
                     device,
                 )
-                last_X_v = g.u2v(self.layers[_idx](last_X_u), aggr="sum")
+                with torch.no_grad():
+                    self.decoders[_idx].eval()
+                    last_X_v = self.decoders[_idx](torch.tanh(g.u2v(self.layers[_idx](last_X_u), aggr="sum")))
         return last_X_u


 class Decoder(nn.Module):
     def __init__(self, in_channels: int, hid_channels: int, out_channels: int, drop_rate: float = 0.5):
-        super(Decoder, self).__init__()
+        super().__init__()
         self.layers = nn.Sequential(
             nn.Linear(in_channels, hid_channels),
             nn.ReLU(),
-            nn.Dropout(p=drop_rate, inplace=True),
+            nn.Dropout(p=drop_rate),
             nn.Linear(hid_channels, out_channels),
             nn.Tanh(),
         )
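Two small fixes also land in ``Decoder``: the Python 3 style ``super().__init__()`` call, and ``nn.Dropout`` without ``inplace=True``. The in-place version can overwrite the preceding ReLU's output, which autograd may still need during the backward pass, so removing it is the safer default here.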

docs/source/api/datapipe.rst

Lines changed: 2 additions & 0 deletions
@@ -15,6 +15,8 @@ Transforms

 .. autofunction:: dhg.datapipe.norm_ft

+.. autofunction:: dhg.datapipe.min_max_scaler
+
 .. autofunction:: dhg.datapipe.to_tensor

 .. autofunction:: dhg.datapipe.to_bool_tensor
