Skip to content

Commit b9cfc3d

Browse files
committed
add BGNN_Adv and BGNN_MLP
1 parent 8b115bd commit b9cfc3d

File tree

5 files changed

+84
-40
lines changed

5 files changed

+84
-40
lines changed

dhg/models/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .graphs import GCN, GraphSAGE, GAT, GIN, NGCF, LightGCN, BGNN_Adv
1+
from .graphs import GCN, GraphSAGE, GAT, GIN, NGCF, LightGCN, BGNN_Adv, BGNN_MLP
22
from .hypergraphs import HGNN, HGNNP, HNHN, HyperGCN, DHCF, UniGCN, UniGAT, UniSAGE, UniGIN
33

44
__all__ = [
@@ -8,6 +8,8 @@
88
"GIN",
99
"NGCF",
1010
"LightGCN",
11+
"BGNN_Adv",
12+
"BGNN_MLP",
1113
"HGNN",
1214
"HGNNP",
1315
"HNHN",

dhg/models/graphs/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from .gin import GIN
55
from .ngcf import NGCF
66
from .lightgcn import LightGCN
7-
from .bgnn import BGNN_Adv
7+
from .bgnn import BGNN_Adv, BGNN_MLP

dhg/models/graphs/bgnn.py

Lines changed: 76 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -10,48 +10,89 @@ class BGNN_Adv(nn.Module):
1010
r"""The BGNN-Adv model proposed in `Cascade-BGNN: Toward Efficient Self-supervised Representation Learning on Large-scale Bipartite Graphs <https://arxiv.org/pdf/1906.11994.pdf>`_ paper (TNNLS 2020).
1111
1212
Args:
13-
``num_users`` (``int``): The Number of users.
14-
``num_items`` (``int``): The Number of items.
15-
``emb_dim`` (``int``): Embedding dimension.
16-
``num_layers`` (``int``): The Number of layers. Defaults to ``3``.
17-
``drop_rate`` (``float``): Dropout rate. Randomly dropout the connections in training stage with probability ``drop_rate``. Default: ``0.0``.
13+
``u_dim`` (``int``): The dimension of the vertex feature in set :math:`U`.
14+
``v_dim`` (``int``): The dimension of the vertex feature in set :math:`V`.
15+
``layer_depth`` (``int``): The depth of layers.
1816
"""
1917

20-
def __init__(
21-
self, num_users: int, num_items: int, emb_dim: int, num_layers: int = 3, drop_rate: float = 0.0
22-
) -> None:
18+
def __init__(self, u_dim: int, v_dim: int, layer_depth: int = 3,) -> None:
2319

2420
super().__init__()
25-
self.num_users, self.num_items = num_users, num_items
26-
self.num_layers = num_layers
27-
self.drop_rate = drop_rate
28-
self.u_embedding = nn.Embedding(num_users, emb_dim)
29-
self.i_embedding = nn.Embedding(num_items, emb_dim)
30-
self.reset_parameters()
31-
32-
def reset_parameters(self):
33-
r"""Initialize learnable parameters.
21+
self.layer_depth = layer_depth
22+
self.layers = nn.ModuleList()
23+
24+
for _idx in range(layer_depth):
25+
if _idx % 2 == 0:
26+
self.layers.append(nn.Linear(v_dim, u_dim))
27+
else:
28+
self.layers.append(nn.Linear(u_dim, v_dim))
29+
30+
def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
31+
r"""The forward function.
32+
33+
Args:
34+
``X_u`` (``torch.Tensor``): The feature matrix of vertices in set :math:`U`.
35+
``X_v`` (``torch.Tensor``): The feature matrix of vertices in set :math:`V`.
36+
``g`` (``BiGraph``): The bipartite graph.
3437
"""
35-
nn.init.normal_(self.u_embedding.weight, 0, 0.1)
36-
nn.init.normal_(self.i_embedding.weight, 0, 0.1)
38+
last_X_u, last_X_v = X_u, X_v
39+
for _idx in range(self.layer_depth):
40+
if _idx % 2 == 0:
41+
_tmp = self.layers[_idx](last_X_v)
42+
last_X_u = g.v2u(_tmp, aggr="sum")
43+
else:
44+
_tmp = self.layers[_idx](last_X_u)
45+
last_X_v = g.u2v(_tmp, aggr="sum")
46+
return last_X_u
47+
48+
def train_with_cascaded(self):
49+
pass
50+
51+
def train_with_end2end(self):
52+
pass
53+
54+
55+
class BGNN_MLP(nn.Module):
56+
r"""The BGNN-MLP model proposed in `Cascade-BGNN: Toward Efficient Self-supervised Representation Learning on Large-scale Bipartite Graphs <https://arxiv.org/pdf/1906.11994.pdf>`_ paper (TNNLS 2020).
3757
38-
def forward(self, ui_bigraph: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
58+
Args:
59+
``u_dim`` (``int``): The dimension of the vertex feature in set :math:`U`.
60+
``v_dim`` (``int``): The dimension of the vertex feature in set :math:`V`.
61+
``layer_depth`` (``int``): The depth of layers.
62+
"""
63+
64+
def __init__(self, u_dim: int, v_dim: int, layer_depth: int = 3,) -> None:
65+
66+
super().__init__()
67+
self.layer_depth = layer_depth
68+
self.layers = nn.ModuleList()
69+
70+
for _idx in range(layer_depth):
71+
if _idx % 2 == 0:
72+
self.layers.append(nn.Linear(v_dim, u_dim))
73+
else:
74+
self.layers.append(nn.Linear(u_dim, v_dim))
75+
76+
def forward(self, X_u: torch.Tensor, X_v: torch.Tensor, g: BiGraph) -> Tuple[torch.Tensor, torch.Tensor]:
3977
r"""The forward function.
4078
4179
Args:
42-
``ui_bigraph`` (``dhg.BiGraph``): The user-item bipartite graph.
80+
``X_u`` (``torch.Tensor``): The feature matrix of vertices in set :math:`U`.
81+
``X_v`` (``torch.Tensor``): The feature matrix of vertices in set :math:`V`.
82+
``g`` (``BiGraph``): The bipartite graph.
4383
"""
44-
drop_rate = self.drop_rate if self.training else 0.0
45-
u_embs = self.u_embedding.weight
46-
i_embs = self.i_embedding.weight
47-
all_embs = torch.cat([u_embs, i_embs], dim=0)
48-
49-
embs_list = [all_embs]
50-
for _ in range(self.num_layers):
51-
all_embs = ui_bigraph.smoothing_with_GCN(all_embs, drop_rate=drop_rate)
52-
embs_list.append(all_embs)
53-
embs = torch.stack(embs_list, dim=1)
54-
embs = torch.mean(embs, dim=1)
55-
56-
u_embs, i_embs = torch.split(embs, [self.num_users, self.num_items], dim=0)
57-
return u_embs, i_embs
84+
last_X_u, last_X_v = X_u, X_v
85+
for _idx in range(self.layer_depth):
86+
if _idx % 2 == 0:
87+
_tmp = self.layers[_idx](last_X_v)
88+
last_X_u = g.v2u(_tmp, aggr="sum")
89+
else:
90+
_tmp = self.layers[_idx](last_X_u)
91+
last_X_v = g.u2v(_tmp, aggr="sum")
92+
return last_X_u
93+
94+
def train_with_cascaded(self):
95+
pass
96+
97+
def train_with_end2end(self):
98+
pass

dhg/structure/graphs/bipartite_graph.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -714,12 +714,12 @@ def v2u(
714714
P = self.B
715715
# message passing
716716
if aggr == "mean":
717-
X = torch.sparse.mm(self.B, X)
717+
X = torch.sparse.mm(P, X)
718718
X = torch.sparse.mm(self.D_u_neg_1, X)
719719
elif aggr == "sum":
720-
X = torch.sparse.mm(self.B, X)
720+
X = torch.sparse.mm(P, X)
721721
elif aggr == "softmax_then_sum":
722-
P = torch.sparse.softmax(self.B, dim=1)
722+
P = torch.sparse.softmax(P, dim=1)
723723
X = torch.sparse.mm(P, X)
724724
else:
725725
pass

docs/source/api/models.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ Models on Bipartite Graph
2525
dhg.models.NGCF
2626
dhg.models.LightGCN
2727
dhg.models.BGNN_Adv
28+
dhg.models.BGNN_MLP
2829

2930

3031
Models on Hypergraph

0 commit comments

Comments
 (0)