@@ -137,6 +137,7 @@ def train(
137137 )
138138 self .best_u = pd .DataFrame (best_u , index = drugs )
139139 self .best_v = pd .DataFrame (best_v , index = cell_lines )
140+ self .training_mean = np .nanmean (output ._response ) # Store training mean
140141
141142 def predict (
142143 self ,
@@ -154,12 +155,21 @@ def predict(
154155 :param drug_input: not needed for prediction in SRMF
155156 :returns: predicted response matrix
156157 """
157- best_u = self .best_u .loc [drug_ids ].values
158- best_v = self .best_v .loc [cell_line_ids ].values
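158+ # Drugs unseen in training get a latent row filled with the training mean; NOTE: the dot product below is then mean * sum(v), not the mean itself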
159+ best_u = np .full ((len (drug_ids ), self .k ), self .training_mean )
160+ for idx , drug in enumerate (drug_ids ):
161+ if drug in self .best_u .index :
162+ best_u [idx , :] = self .best_u .loc [drug ].values
163+
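164+ # Cell lines unseen in training get a latent row filled with the training mean; NOTE: the dot product below is then mean * sum(u), not the mean itself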
165+ best_v = np .full ((len (cell_line_ids ), self .k ), self .training_mean )
166+ for idx , cell in enumerate (cell_line_ids ):
167+ if cell in self .best_v .index :
168+ best_v [idx , :] = self .best_v .loc [cell ].values
169+
159170 # calculate the diagonal of the matrix product which is the prediction,
160171 # faster than np.dot(best_u, best_v.T).diagonal()
161172 diagonal_predictions = np .einsum ("ij,ji->i" , best_u , best_v .T )
162-
163173 return diagonal_predictions
164174
165175 def _cmf (self , w , int_mat , drug_mat , cell_mat ) -> tuple [np .ndarray , np .ndarray ]:
0 commit comments