@@ -133,6 +133,41 @@ def _generate_dec_graph(data_set):
133133 )
134134
135135
136+ def _generate_test_user_graph (user_idx , total_users , total_items ):
137+ """
138+ Generates decoding graph given a cornac data set
139+
140+ Parameters
141+ ----------
142+ data_set : cornac.data.dataset.Dataset
143+ The data set as provided by cornac
144+
145+ Returns
146+ -------
147+ graph : dgl.heterograph
148+ Heterograph containing user-item edges and nodes
149+ """
150+ u_list = np .array ([user_idx for _ in range (total_items )])
151+ i_list = np .array ([item_idx for item_idx in range (total_items )])
152+
153+ rating_pairs = (u_list , i_list )
154+ ones = np .ones_like (rating_pairs [0 ])
155+ user_item_ratings_coo = sp .coo_matrix (
156+ (ones , rating_pairs ),
157+ shape = (total_users , total_items ),
158+ dtype = np .float32 ,
159+ )
160+
161+ graph = dgl .bipartite_from_scipy (
162+ user_item_ratings_coo , utype = "_U" , etype = "_E" , vtype = "_V"
163+ )
164+
165+ return dgl .heterograph (
166+ {("user" , "rate" , "item" ): graph .edges ()},
167+ num_nodes_dict = {"user" : total_users , "item" : total_items },
168+ )
169+
170+
136171class Model :
137172 def __init__ (
138173 self ,
@@ -479,7 +514,11 @@ def predict(self, test_set):
479514
480515 test_pred_ratings = test_pred_ratings .cpu ().numpy ()
481516
482- (u_list , i_list , _ ) = test_set .uir_tuple
517+ uid_list = test_set .uir_tuple [0 ]
518+ uid_list = np .unique (uid_list )
519+
520+ u_list = np .array ([user_idx for _ in range (test_set .total_items ) for user_idx in uid_list ])
521+ i_list = np .array ([item_idx for item_idx in range (test_set .total_items ) for _ in uid_list ])
483522
484523 u_list = u_list .tolist ()
485524 i_list = i_list .tolist ()
@@ -489,3 +528,39 @@ def predict(self, test_set):
489528 for idx , rating in enumerate (test_pred_ratings )
490529 }
491530 return u_i_rating_dict
531+
532+ def predict_one (self , train_set , user_idx ):
533+ """
534+ Processes single user_idx from test set and returns numpy list of scores
535+ for all items.
536+
537+ Parameters
538+ ----------
539+ train_set : cornac.data.dataset.Dataset
540+ The train set as provided by cornac
541+
542+ Returns
543+ -------
544+ test_pred_ratings : numpy.array
545+ Numpy array containing all ratings for the given user_idx.
546+ """
547+ test_dec_graph = _generate_test_user_graph (user_idx , train_set .total_users , train_set .total_items )
548+ test_dec_graph = test_dec_graph .int ().to (self .device )
549+
550+ self .net .eval ()
551+
552+ with torch .no_grad ():
553+ pred_ratings = self .net (self .train_enc_graph , test_dec_graph )
554+
555+ test_rating_values = train_set .uir_tuple [2 ]
556+ test_rating_values = np .unique (test_rating_values )
557+
558+ nd_positive_rating_values = torch .FloatTensor (test_rating_values ).to (
559+ self .device
560+ )
561+
562+ test_pred_ratings = (
563+ torch .softmax (pred_ratings , dim = 1 ) * nd_positive_rating_values .view (1 , - 1 )
564+ ).sum (dim = 1 )
565+
566+ return test_pred_ratings .cpu ().numpy ()
0 commit comments