Skip to content

Commit a25054d

Browse files
authored
Fix GCMC Invalid Key Error (#533)
* Add missing default scores when item_idx is none * Added all items into test_dec_graph for all users * changed to create dec_graph for each user when scoring
1 parent 0955645 commit a25054d

File tree

2 files changed

+77
-7
lines changed

2 files changed

+77
-7
lines changed

cornac/models/gcmc/gcmc.py

Lines changed: 76 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,41 @@ def _generate_dec_graph(data_set):
133133
)
134134

135135

136+
def _generate_test_user_graph(user_idx, total_users, total_items):
137+
"""
138+
Generates decoding graph given a cornac data set
139+
140+
Parameters
141+
----------
142+
data_set : cornac.data.dataset.Dataset
143+
The data set as provided by cornac
144+
145+
Returns
146+
-------
147+
graph : dgl.heterograph
148+
Heterograph containing user-item edges and nodes
149+
"""
150+
u_list = np.array([user_idx for _ in range(total_items)])
151+
i_list = np.array([item_idx for item_idx in range(total_items)])
152+
153+
rating_pairs = (u_list, i_list)
154+
ones = np.ones_like(rating_pairs[0])
155+
user_item_ratings_coo = sp.coo_matrix(
156+
(ones, rating_pairs),
157+
shape=(total_users, total_items),
158+
dtype=np.float32,
159+
)
160+
161+
graph = dgl.bipartite_from_scipy(
162+
user_item_ratings_coo, utype="_U", etype="_E", vtype="_V"
163+
)
164+
165+
return dgl.heterograph(
166+
{("user", "rate", "item"): graph.edges()},
167+
num_nodes_dict={"user": total_users, "item": total_items},
168+
)
169+
170+
136171
class Model:
137172
def __init__(
138173
self,
@@ -479,7 +514,11 @@ def predict(self, test_set):
479514

480515
test_pred_ratings = test_pred_ratings.cpu().numpy()
481516

482-
(u_list, i_list, _) = test_set.uir_tuple
517+
uid_list = test_set.uir_tuple[0]
518+
uid_list = np.unique(uid_list)
519+
520+
u_list = np.array([user_idx for _ in range(test_set.total_items) for user_idx in uid_list])
521+
i_list = np.array([item_idx for item_idx in range(test_set.total_items) for _ in uid_list])
483522

484523
u_list = u_list.tolist()
485524
i_list = i_list.tolist()
@@ -489,3 +528,39 @@ def predict(self, test_set):
489528
for idx, rating in enumerate(test_pred_ratings)
490529
}
491530
return u_i_rating_dict
531+
532+
def predict_one(self, train_set, user_idx):
533+
"""
534+
Processes single user_idx from test set and returns numpy list of scores
535+
for all items.
536+
537+
Parameters
538+
----------
539+
train_set : cornac.data.dataset.Dataset
540+
The train set as provided by cornac
541+
542+
Returns
543+
-------
544+
test_pred_ratings : numpy.array
545+
Numpy array containing all ratings for the given user_idx.
546+
"""
547+
test_dec_graph = _generate_test_user_graph(user_idx, train_set.total_users, train_set.total_items)
548+
test_dec_graph = test_dec_graph.int().to(self.device)
549+
550+
self.net.eval()
551+
552+
with torch.no_grad():
553+
pred_ratings = self.net(self.train_enc_graph, test_dec_graph)
554+
555+
test_rating_values = train_set.uir_tuple[2]
556+
test_rating_values = np.unique(test_rating_values)
557+
558+
nd_positive_rating_values = torch.FloatTensor(test_rating_values).to(
559+
self.device
560+
)
561+
562+
test_pred_ratings = (
563+
torch.softmax(pred_ratings, dim=1) * nd_positive_rating_values.view(1, -1)
564+
).sum(dim=1)
565+
566+
return test_pred_ratings.cpu().numpy()

cornac/models/gcmc/recom_gcmc.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -213,11 +213,6 @@ def score(self, user_idx, item_idx=None):
213213
"""
214214
if item_idx is None:
215215
# Return scores of all items for a given user
216-
# - If item does not exist in test_set, we provide a default score
217-
# (as set in default_dict initialisation)
218-
return [
219-
self.u_i_rating_dict[f"{user_idx}-{idx}"]
220-
for idx in range(self.train_set.total_items)
221-
]
216+
return self.model.predict_one(self.train_set, user_idx)
222217
# Return score of known user/item
223218
return self.u_i_rating_dict.get(f"{user_idx}-{item_idx}", self.default_score())

0 commit comments

Comments
 (0)