1414from chebai .preprocessing .datasets .base import XYBaseDataModule
1515from chebai .preprocessing .datasets .chebi import _ChEBIDataExtractor
1616
17+ from extras .ev_model import create_weight_dict
18+
1719
1820def get_checkpoint_from_wandb (
1921 epoch : int ,
@@ -117,6 +119,7 @@ def evaluate_model(
117119 data_list = data_list [: data_module .data_limit ]
118120 preds_list = []
119121 labels_list = []
122+ weights_list = []
120123 if buffer_dir is not None :
121124 os .makedirs (buffer_dir , exist_ok = True )
122125 save_ind = 0
@@ -132,6 +135,8 @@ def evaluate_model(
132135 preds , labels = _run_batch (data_list [i : i + batch_size ], model , collate )
133136 preds_list .append (preds )
134137 labels_list .append (labels )
138+ for j in range (i ,i + batch_size ):
139+ weights_list .append (data_list [j ])
135140
136141 if buffer_dir is not None :
137142 if n_saved * batch_size >= save_batch_size :
@@ -170,6 +175,68 @@ def evaluate_model(
170175 )
171176
172177
178+ def evaluate_model_weights (
179+ model : ChebaiBaseNet ,
180+ data_module : XYBaseDataModule ,
181+ filename : Optional [str ] = None ,
182+ buffer_dir : Optional [str ] = None ,
183+ batch_size : int = 32 ,
184+ skip_existing_preds : bool = False ,
185+ kind : str = "test" ,
186+ ) -> Tuple [torch .Tensor , Optional [torch .Tensor ]]:
187+ """
188+ Runs the model on the test set of the data module or on the dataset found in the specified file.
189+ If buffer_dir is set, results will be saved in buffer_dir.
190+
191+ Note:
192+ No need to provide "filename" parameter for Chebi dataset, "kind" parameter should be provided.
193+
194+ Args:
195+ model: The model to evaluate.
196+ data_module: The data module containing the dataset.
197+ filename: Optional file name for the dataset.
198+ buffer_dir: Optional directory to save the results.
199+ batch_size: The batch size for evaluation.
200+ skip_existing_preds: Whether to skip evaluation if predictions already exist.
201+ kind: Kind of split of the data to be used for testing the model. Default is `test`.
202+
203+ Returns:
204+ Tensors with predictions and labels.
205+ """
206+ model .eval ()
207+ collate = data_module .reader .COLLATOR ()
208+
209+ if isinstance (data_module , _ChEBIDataExtractor ):
210+ # As the dynamic split change is implemented only for chebi-dataset as of now
211+ data_df = data_module .dynamic_split_dfs [kind ]
212+ data_list = data_df .to_dict (orient = "records" )
213+ else :
214+ data_list = data_module .load_processed_data ("test" , filename )
215+ data_list = data_list [: data_module .data_limit ]
216+ preds_list = []
217+ labels_list = []
218+ weights_list = []
219+ if buffer_dir is not None :
220+ os .makedirs (buffer_dir , exist_ok = True )
221+ save_ind = 0
222+ save_batch_size = 128
223+ n_saved = 1
224+
225+ print ("" )
226+ for i in tqdm .tqdm (range (0 , len (data_list ), batch_size )):
227+ if not (
228+ skip_existing_preds
229+ and os .path .isfile (os .path .join (buffer_dir , f"preds{ save_ind :03d} .pt" ))
230+ ):
231+ preds , labels = _run_batch (data_list [i : i + batch_size ], model , collate )
232+ preds_list .append (preds )
233+ labels_list .append (labels )
234+
235+ result = create_weight_dict (preds_list ,labels_list ,data_list )
236+ torch .save (result ,"./result.pt" )
237+
238+
239+
173240def load_results_from_buffer (
174241 buffer_dir : str , device : torch .device
175242) -> Tuple [Optional [torch .Tensor ], Optional [torch .Tensor ]]:
0 commit comments