@@ -119,7 +119,6 @@ def evaluate_model(
     data_list = data_list[: data_module.data_limit]
     preds_list = []
     labels_list = []
-    weights_list = []
     if buffer_dir is not None:
         os.makedirs(buffer_dir, exist_ok=True)
         save_ind = 0
@@ -135,8 +134,6 @@ def evaluate_model(
             preds, labels = _run_batch(data_list[i : i + batch_size], model, collate)
             preds_list.append(preds)
             labels_list.append(labels)
-            for j in range(i, i + batch_size):
-                weights_list.append(data_list[j])

         if buffer_dir is not None:
             if n_saved * batch_size >= save_batch_size:
@@ -174,69 +171,6 @@ def evaluate_model(
             os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
         )

-
-def evaluate_model_weights(
-    model: ChebaiBaseNet,
-    data_module: XYBaseDataModule,
-    filename: Optional[str] = None,
-    buffer_dir: Optional[str] = None,
-    batch_size: int = 32,
-    skip_existing_preds: bool = False,
-    kind: str = "test",
-):
-    """
-    Runs the model on the test set of the data module, or on the dataset found in the specified file.
-    If buffer_dir is set, results will be saved in buffer_dir.
-
-    Note:
-        For the ChEBI dataset there is no need to provide the "filename" parameter; the "kind" parameter should be provided instead.
-
-    Args:
-        model: The model to evaluate.
-        data_module: The data module containing the dataset.
-        filename: Optional file name for the dataset.
-        buffer_dir: Optional directory to save the results.
-        batch_size: The batch size for evaluation.
-        skip_existing_preds: Whether to skip evaluation if predictions already exist.
-        kind: Kind of split of the data to be used for testing the model. Default is `test`.
-
-    Returns:
-        Tensors with predictions and labels.
-    """
-    model.eval()
-    collate = data_module.reader.COLLATOR()
-
-    if isinstance(data_module, _ChEBIDataExtractor):
-        # The dynamic split is currently implemented only for the ChEBI dataset
-        data_df = data_module.dynamic_split_dfs[kind]
-        data_list = data_df.to_dict(orient="records")
-    else:
-        data_list = data_module.load_processed_data("test", filename)
-    data_list = data_list[: data_module.data_limit]
-    preds_list = []
-    labels_list = []
-    weights_list = []
-    if buffer_dir is not None:
-        os.makedirs(buffer_dir, exist_ok=True)
-        save_ind = 0
-        save_batch_size = 128
-        n_saved = 1
-
-    print("")
-    for i in tqdm.tqdm(range(0, len(data_list), batch_size)):
-        if not (
-            skip_existing_preds
-            and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"))
-        ):
-            preds, labels = _run_batch(data_list[i : i + batch_size], model, collate)
-            preds_list.append(preds)
-            labels_list.append(labels)
-
-    result = create_weight_dict(preds_list, labels_list, data_list)
-    torch.save(result, "./result.pt")
-
-
-
 def load_results_from_buffer(
     buffer_dir: str, device: torch.device
 ) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
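
For context, the surviving pair kept by this commit is typically used together: `evaluate_model` writes chunked prediction/label tensors into a buffer directory, and `load_results_from_buffer` reloads them as two concatenated tensors. The sketch below illustrates that flow; the import paths, the `Electra` checkpoint path, and the `ChEBIOver100` data module are assumptions based on the chebai repository layout, not part of this diff.

# Hedged usage sketch: evaluate into a buffer directory, then reload the
# concatenated prediction/label tensors for downstream metric computation.
import torch

from chebai.models.electra import Electra  # assumed module path
from chebai.preprocessing.datasets.chebi import ChEBIOver100  # assumed module path
from chebai.result.utils import evaluate_model, load_results_from_buffer  # assumed module path

model = Electra.load_from_checkpoint("path/to/checkpoint.ckpt")  # placeholder path
data_module = ChEBIOver100()
buffer_dir = "eval_buffer"

# Writes preds000.pt / labels000.pt chunks into buffer_dir. With a ChEBI data
# module, "kind" selects the dynamic split, so no filename is needed.
evaluate_model(
    model,
    data_module,
    buffer_dir=buffer_dir,
    batch_size=32,
    kind="test",
)

# Reload the buffered chunks as two tensors on the chosen device.
preds, labels = load_results_from_buffer(buffer_dir, device=torch.device("cpu"))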