-
I am currently learning about graph neural networks and am following Maxime Labonne's book, which uses LightGCN for book recommendation. In the example, LightGCN is trained on interactions between users and items:
# Build the graph connectivity from user ratings as an edge list (COO format):
# column j holds the mapped (user, item) index pair of rating row j in `df`.
# NOTE(review): LightGCN.forward stacks user embeddings before item embeddings,
# so if `item_mapping` yields 0-based item indices they presumably need a
# `+ num_users` offset (and reverse edges) before being fed to the model —
# confirm how edge_index is consumed downstream.
user_ids = torch.LongTensor(list(map(user_mapping.__getitem__, df['User-ID'])))
item_ids = torch.LongTensor(list(map(item_mapping.__getitem__, df['ISBN'])))
edge_index = torch.stack((user_ids, item_ids))
# use LGConv from pytorch geometric
from torch_geometric.nn import LGConv
class LightGCN(nn.Module):
    """LightGCN recommender: learnable user/item embeddings smoothed by LGConv.

    Users and items share one node index space. In ``forward`` the user
    embedding table is stacked on top of the item table, so node rows
    ``0 .. num_users-1`` are users and rows
    ``num_users .. num_users+num_items-1`` are items.

    Args:
        num_users: Number of user nodes.
        num_items: Number of item nodes.
        num_layers: Number of LGConv propagation layers (K).
        dim_h: Embedding dimension.
    """

    def __init__(self, num_users, num_items, num_layers=4, dim_h=64):
        super().__init__()
        self.num_users = num_users
        self.num_items = num_items
        self.num_layers = num_layers
        self.emb_users = nn.Embedding(
            num_embeddings=self.num_users, embedding_dim=dim_h
        )
        self.emb_items = nn.Embedding(
            num_embeddings=self.num_items, embedding_dim=dim_h
        )
        # LGConv() is constructed without size arguments. Per the PyG docs it is
        # a normalized neighbour aggregation with no trainable weights, so the
        # embeddings above are the model's only parameters.
        self.convs = nn.ModuleList(LGConv() for _ in range(num_layers))
        # Small-variance Gaussian init for the layer-0 embeddings.
        nn.init.normal_(self.emb_users.weight, std=0.01)
        nn.init.normal_(self.emb_items.weight, std=0.01)

    def forward(self, edge_index):
        """Propagate embeddings over the graph and combine all layers.

        Args:
            edge_index: Edge list (a 2-row index tensor) passed unchanged to
                every LGConv layer. Its indices address rows of the concatenated
                user+item table. Item nodes must therefore be numbered
                ``num_users + item_idx``.
                NOTE(review): with PyG's default message flow, messages go only
                from row 0 to row 1. Presumably both user->item and item->user
                edges are needed for users to receive item information. Confirm
                this against how edge_index is built.

        Returns:
            Tuple ``(emb_users_final, emb_users_0, emb_items_final, emb_items_0)``.
            These are the layer-combined user and item embeddings, each followed
            by the raw layer-0 embedding weights. The raw weights are
            presumably used for a regularization term; verify against the
            caller.
        """
        # Layer 0: one (num_users + num_items, dim_h) table, users first.
        emb = torch.cat([self.emb_users.weight, self.emb_items.weight])
        embs = [emb]
        for conv in self.convs:
            # One propagation step. According to the PyG LGConv documentation,
            # each node's new embedding is the sum of its neighbours' current
            # embeddings, weighted by 1 / sqrt(deg(i) * deg(j)).
            emb = conv(x=emb, edge_index=edge_index)
            embs.append(emb)
        emb_final = self._combine_layers(embs)
        emb_users_final, emb_items_final = torch.split(
            emb_final, [self.num_users, self.num_items]
        )
        return (
            emb_users_final,
            self.emb_users.weight,
            emb_items_final,
            self.emb_items.weight,
        )

    @staticmethod
    def _combine_layers(embs):
        """Average the K+1 per-layer embeddings with uniform weights 1/(K+1).

        This is LightGCN's layer combination, sum_k alpha_k * e^(k) with
        alpha_k = 1/(K+1). ``mean`` already divides by K+1, so no extra
        1/(K+1) factor must be applied on top of it.
        """
        return torch.stack(embs, dim=1).mean(dim=1)


# My question:
To elaborate:
edit: |
Beta Was this translation helpful? Give feedback.
Replies: 1 comment
-
|
Beta Was this translation helpful? Give feedback.
The LightGCN implementation (and the one you posted) also requires all node IDs to be unique across users and items. That is, user IDs should lie in [0, num_users) and item IDs in [num_users, num_users + num_items).