Skip to content

Commit a470d5c

Browse files
darrylong, qtuantruong, tqtg
authored
Optimize LightGCN Model (#531)
* Generated model base from LightGCN * wip * wip example * add self-connection * refactor code * added sanity check * Changed train batch size in example to 1024 * Updated readme for example folder * Update Readme * update docs * Update block comment * WIP * Updated validation metric * Updated message handling * Added legacy lightgcn for comparison purposes * Changed to follow 'a_k = 1/(k+1)', k instead of i * Changed early stopping technique to follow NGCF * remove test_batchsize, early stop verbose to false * Changed parameters to align with paper and ngcf * refractor codes * update docstring * change param name to 'batch_size' * Fix paper reference --------- Co-authored-by: tqtg <tuantq.vnu@gmail.com> Co-authored-by: Quoc-Tuan Truong <tqtg@users.noreply.github.com>
1 parent c484988 commit a470d5c

File tree

3 files changed

+132
-148
lines changed

3 files changed

+132
-148
lines changed

cornac/models/lightgcn/lightgcn.py

Lines changed: 92 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,14 @@
11
import torch
22
import torch.nn as nn
3+
import torch.nn.functional as F
34
import dgl
45
import dgl.function as fn
56

67

8+
USER_KEY = "user"
9+
ITEM_KEY = "item"
10+
11+
712
def construct_graph(data_set):
813
"""
914
Generates graph given a cornac data set
@@ -14,89 +19,109 @@ def construct_graph(data_set):
1419
The data set as provided by cornac
1520
"""
1621
user_indices, item_indices, _ = data_set.uir_tuple
17-
user_nodes, item_nodes = (
18-
torch.from_numpy(user_indices),
19-
torch.from_numpy(
20-
item_indices + data_set.total_users
21-
), # increment item node idx by num users
22-
)
2322

24-
u = torch.cat([user_nodes, item_nodes], dim=0)
25-
v = torch.cat([item_nodes, user_nodes], dim=0)
23+
data_dict = {
24+
(USER_KEY, "user_item", ITEM_KEY): (user_indices, item_indices),
25+
(ITEM_KEY, "item_user", USER_KEY): (item_indices, user_indices),
26+
}
27+
num_dict = {USER_KEY: data_set.total_users, ITEM_KEY: data_set.total_items}
2628

27-
g = dgl.graph((u, v), num_nodes=(data_set.total_users + data_set.total_items))
28-
return g
29+
return dgl.heterograph(data_dict, num_nodes_dict=num_dict)
2930

3031

3132
class GCNLayer(nn.Module):
32-
def __init__(self):
33+
def __init__(self, norm_dict):
3334
super(GCNLayer, self).__init__()
3435

35-
def forward(self, graph, src_embedding, dst_embedding):
36-
with graph.local_scope():
37-
inner_product = torch.cat((src_embedding, dst_embedding), dim=0)
38-
39-
out_degs = graph.out_degrees().to(src_embedding.device).float().clamp(min=1)
40-
norm_out_degs = torch.pow(out_degs, -0.5).view(-1, 1) # D^-1/2
41-
42-
inner_product = inner_product * norm_out_degs
43-
44-
graph.ndata["h"] = inner_product
45-
graph.update_all(
46-
message_func=fn.copy_u("h", "m"), reduce_func=fn.sum("m", "h")
47-
)
48-
49-
res = graph.ndata["h"]
50-
51-
in_degs = graph.in_degrees().to(src_embedding.device).float().clamp(min=1)
52-
norm_in_degs = torch.pow(in_degs, -0.5).view(-1, 1) # D^-1/2
53-
54-
res = res * norm_in_degs
55-
return res
36+
# norm
37+
self.norm_dict = norm_dict
38+
39+
def forward(self, g, feat_dict):
40+
funcs = {} # message and reduce functions dict
41+
# for each type of edges, compute messages and reduce them all
42+
for srctype, etype, dsttype in g.canonical_etypes:
43+
src, dst = g.edges(etype=(srctype, etype, dsttype))
44+
norm = self.norm_dict[(srctype, etype, dsttype)]
45+
# TODO: CHECK HERE
46+
messages = norm * feat_dict[srctype][src] # compute messages
47+
g.edges[(srctype, etype, dsttype)].data[
48+
etype
49+
] = messages # store in edata
50+
funcs[(srctype, etype, dsttype)] = (
51+
fn.copy_e(etype, "m"),
52+
fn.sum("m", "h"),
53+
) # define message and reduce functions
54+
55+
g.multi_update_all(
56+
funcs, "sum"
57+
) # update all, reduce by first type-wisely then across different types
58+
feature_dict = {}
59+
for ntype in g.ntypes:
60+
h = F.normalize(g.nodes[ntype].data["h"], dim=1, p=2) # l2 normalize
61+
feature_dict[ntype] = h
62+
return feature_dict
5663

5764

5865
class Model(nn.Module):
59-
def __init__(self, user_size, item_size, hidden_size, num_layers=3, device=None):
66+
def __init__(self, g, in_size, num_layers, lambda_reg, device=None):
6067
super(Model, self).__init__()
61-
self.user_size = user_size
62-
self.item_size = item_size
63-
self.hidden_size = hidden_size
64-
self.embedding_weights = self._init_weights()
65-
self.layers = nn.ModuleList([GCNLayer() for _ in range(num_layers)])
68+
self.norm_dict = dict()
69+
self.lambda_reg = lambda_reg
6670
self.device = device
6771

68-
def forward(self, graph):
69-
user_embedding = self.embedding_weights["user_embedding"]
70-
item_embedding = self.embedding_weights["item_embedding"]
72+
for srctype, etype, dsttype in g.canonical_etypes:
73+
src, dst = g.edges(etype=(srctype, etype, dsttype))
74+
dst_degree = g.in_degrees(
75+
dst, etype=(srctype, etype, dsttype)
76+
).float() # obtain degrees
77+
src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
78+
norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1) # compute norm
79+
self.norm_dict[(srctype, etype, dsttype)] = norm
7180

72-
for i, layer in enumerate(self.layers, start=1):
73-
if i == 1:
74-
embeddings = layer(graph, user_embedding, item_embedding)
75-
else:
76-
embeddings = layer(
77-
graph, embeddings[: self.user_size], embeddings[self.user_size:]
78-
)
81+
self.layers = nn.ModuleList([GCNLayer(self.norm_dict) for _ in range(num_layers)])
7982

80-
user_embedding = user_embedding + embeddings[: self.user_size] * (
81-
1 / (i + 1)
82-
)
83-
item_embedding = item_embedding + embeddings[self.user_size:] * (
84-
1 / (i + 1)
85-
)
86-
87-
return user_embedding, item_embedding
83+
self.initializer = nn.init.xavier_uniform_
8884

89-
def _init_weights(self):
90-
initializer = nn.init.xavier_uniform_
91-
92-
weights_dict = nn.ParameterDict(
85+
# embeddings for different types of nodes
86+
self.feature_dict = nn.ParameterDict(
9387
{
94-
"user_embedding": nn.Parameter(
95-
initializer(torch.empty(self.user_size, self.hidden_size))
96-
),
97-
"item_embedding": nn.Parameter(
98-
initializer(torch.empty(self.item_size, self.hidden_size))
99-
),
88+
ntype: nn.Parameter(
89+
self.initializer(torch.empty(g.num_nodes(ntype), in_size))
90+
)
91+
for ntype in g.ntypes
10092
}
10193
)
102-
return weights_dict
94+
95+
def forward(self, g, users=None, pos_items=None, neg_items=None):
96+
h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
97+
# obtain features of each layer and concatenate them all
98+
user_embeds = h_dict[USER_KEY]
99+
item_embeds = h_dict[ITEM_KEY]
100+
101+
for k, layer in enumerate(self.layers):
102+
h_dict = layer(g, h_dict)
103+
user_embeds = user_embeds + (h_dict[USER_KEY] * 1 / (k + 1))
104+
item_embeds = item_embeds + (h_dict[ITEM_KEY] * 1 / (k + 1))
105+
106+
u_g_embeddings = user_embeds if users is None else user_embeds[users, :]
107+
pos_i_g_embeddings = item_embeds if pos_items is None else item_embeds[pos_items, :]
108+
neg_i_g_embeddings = item_embeds if neg_items is None else item_embeds[neg_items, :]
109+
110+
return u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
111+
112+
def loss_fn(self, users, pos_items, neg_items):
113+
pos_scores = (users * pos_items).sum(1)
114+
neg_scores = (users * neg_items).sum(1)
115+
116+
bpr_loss = F.softplus(neg_scores - pos_scores).mean()
117+
reg_loss = (
118+
(1 / 2)
119+
* (
120+
torch.norm(users) ** 2
121+
+ torch.norm(pos_items) ** 2
122+
+ torch.norm(neg_items) ** 2
123+
)
124+
/ len(users)
125+
)
126+
127+
return bpr_loss + self.lambda_reg * reg_loss, bpr_loss, reg_loss

cornac/models/lightgcn/recom_lightgcn.py

Lines changed: 37 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -28,21 +28,18 @@ class LightGCN(Recommender):
2828
name: string, default: 'LightGCN'
2929
The name of the recommender model.
3030
31+
emb_size: int, default: 64
32+
Size of the node embeddings.
33+
3134
num_epochs: int, default: 1000
32-
Maximum number of iterations or the number of epochs
35+
Maximum number of iterations or the number of epochs.
3336
3437
learning_rate: float, default: 0.001
3538
The learning rate that determines the step size at each iteration
3639
37-
train_batch_size: int, default: 1024
40+
batch_size: int, default: 1024
3841
Mini-batch size used for train set
3942
40-
test_batch_size: int, default: 100
41-
Mini-batch size used for test set
42-
43-
hidden_dim: int, default: 64
44-
The embedding size of the model
45-
4643
num_layers: int, default: 3
4744
Number of LightGCN Layers
4845
@@ -80,11 +77,10 @@ class LightGCN(Recommender):
8077
def __init__(
8178
self,
8279
name="LightGCN",
80+
emb_size=64,
8381
num_epochs=1000,
8482
learning_rate=0.001,
85-
train_batch_size=1024,
86-
test_batch_size=100,
87-
hidden_dim=64,
83+
batch_size=1024,
8884
num_layers=3,
8985
early_stopping=None,
9086
lambda_reg=1e-4,
@@ -93,13 +89,11 @@ def __init__(
9389
seed=2020,
9490
):
9591
super().__init__(name=name, trainable=trainable, verbose=verbose)
96-
92+
self.emb_size = emb_size
9793
self.num_epochs = num_epochs
9894
self.learning_rate = learning_rate
99-
self.hidden_dim = hidden_dim
95+
self.batch_size = batch_size
10096
self.num_layers = num_layers
101-
self.train_batch_size = train_batch_size
102-
self.test_batch_size = test_batch_size
10397
self.early_stopping = early_stopping
10498
self.lambda_reg = lambda_reg
10599
self.seed = seed
@@ -135,19 +129,15 @@ def fit(self, train_set, val_set=None):
135129
if torch.cuda.is_available():
136130
torch.cuda.manual_seed_all(self.seed)
137131

132+
graph = construct_graph(train_set).to(self.device)
138133
model = Model(
139-
train_set.total_users,
140-
train_set.total_items,
141-
self.hidden_dim,
134+
graph,
135+
self.emb_size,
142136
self.num_layers,
137+
self.lambda_reg,
143138
).to(self.device)
144139

145-
graph = construct_graph(train_set).to(self.device)
146-
147-
optimizer = torch.optim.Adam(
148-
model.parameters(), lr=self.learning_rate, weight_decay=self.lambda_reg
149-
)
150-
loss_fn = torch.nn.BCELoss(reduction="sum")
140+
optimizer = torch.optim.Adam(model.parameters(), lr=self.learning_rate)
151141

152142
# model training
153143
pbar = trange(
@@ -163,53 +153,43 @@ def fit(self, train_set, val_set=None):
163153
accum_loss = 0.0
164154
for batch_u, batch_i, batch_j in tqdm(
165155
train_set.uij_iter(
166-
batch_size=self.train_batch_size,
156+
batch_size=self.batch_size,
167157
shuffle=True,
168158
),
169159
desc="Epoch",
170-
total=train_set.num_batches(self.train_batch_size),
160+
total=train_set.num_batches(self.batch_size),
171161
leave=False,
172162
position=1,
173163
disable=not self.verbose,
174164
):
175-
user_embeddings, item_embeddings = model(graph)
176-
177-
batch_u = torch.from_numpy(batch_u).long().to(self.device)
178-
batch_i = torch.from_numpy(batch_i).long().to(self.device)
179-
batch_j = torch.from_numpy(batch_j).long().to(self.device)
180-
181-
user_embed = user_embeddings[batch_u]
182-
positive_item_embed = item_embeddings[batch_i]
183-
negative_item_embed = item_embeddings[batch_j]
184-
185-
ui_scores = (user_embed * positive_item_embed).sum(dim=1)
186-
uj_scores = (user_embed * negative_item_embed).sum(dim=1)
165+
u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model(
166+
graph, batch_u, batch_i, batch_j
167+
)
187168

188-
loss = loss_fn(
189-
torch.sigmoid(ui_scores - uj_scores), torch.ones_like(ui_scores)
169+
batch_loss, batch_bpr_loss, batch_reg_loss = model.loss_fn(
170+
u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
190171
)
191-
accum_loss += loss.cpu().item()
172+
accum_loss += batch_loss.cpu().item() * len(batch_u)
192173

193174
optimizer.zero_grad()
194-
loss.backward()
175+
batch_loss.backward()
195176
optimizer.step()
196177

197178
accum_loss /= len(train_set.uir_tuple[0]) # normalize over all observations
198179
pbar.set_postfix(loss=accum_loss)
199180

200181
# store user and item embedding matrices for prediction
201182
model.eval()
202-
self.U, self.V = model(graph)
183+
u_embs, i_embs, _ = model(graph)
184+
# we will use numpy for faster prediction in the score function, no need torch
185+
self.U = u_embs.cpu().detach().numpy()
186+
self.V = i_embs.cpu().detach().numpy()
203187

204188
if self.early_stopping is not None and self.early_stop(
205189
**self.early_stopping
206190
):
207191
break
208192

209-
# we will use numpy for faster prediction in the score function, no need torch
210-
self.U = self.U.cpu().detach().numpy()
211-
self.V = self.V.cpu().detach().numpy()
212-
213193
def monitor_value(self):
214194
"""Calculating monitored value used for early stopping on validation set (`val_set`).
215195
This function will be called by `early_stop()` function.
@@ -223,38 +203,17 @@ def monitor_value(self):
223203
if self.val_set is None:
224204
return None
225205

226-
import torch
206+
from ...metrics import Recall
207+
from ...eval_methods import ranking_eval
227208

228-
loss_fn = torch.nn.BCELoss(reduction="sum")
229-
accum_loss = 0.0
230-
pbar = tqdm(
231-
self.val_set.uij_iter(batch_size=self.test_batch_size),
232-
desc="Validation",
233-
total=self.val_set.num_batches(self.test_batch_size),
234-
leave=False,
235-
position=1,
236-
disable=not self.verbose,
237-
)
238-
for batch_u, batch_i, batch_j in pbar:
239-
batch_u = torch.from_numpy(batch_u).long().to(self.device)
240-
batch_i = torch.from_numpy(batch_i).long().to(self.device)
241-
batch_j = torch.from_numpy(batch_j).long().to(self.device)
242-
243-
user_embed = self.U[batch_u]
244-
positive_item_embed = self.V[batch_i]
245-
negative_item_embed = self.V[batch_j]
246-
247-
ui_scores = (user_embed * positive_item_embed).sum(dim=1)
248-
uj_scores = (user_embed * negative_item_embed).sum(dim=1)
249-
250-
loss = loss_fn(
251-
torch.sigmoid(ui_scores - uj_scores), torch.ones_like(ui_scores)
252-
)
253-
accum_loss += loss.cpu().item()
254-
pbar.set_postfix(val_loss=accum_loss)
255-
256-
accum_loss /= len(self.val_set.uir_tuple[0])
257-
return -accum_loss # higher is better -> smaller loss is better
209+
recall_20 = ranking_eval(
210+
model=self,
211+
metrics=[Recall(k=20)],
212+
train_set=self.train_set,
213+
test_set=self.val_set
214+
)[0][0]
215+
216+
return recall_20 # Section 4.1.2 in the paper, same strategy as NGCF.
258217

259218
def score(self, user_idx, item_idx=None):
260219
"""Predict the scores/ratings of a user for an item.

0 commit comments

Comments (0)