3 files changed, +19 −21 lines

@@ -50,32 +50,29 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             and self.data_extractor is not None
             and all(
                 os.path.exists(
-                    os.path.join(self.data_extractor.processed_dir_main, file_name)
+                    os.path.join(self.data_extractor.processed_dir, file_name)
                 )
-                for file_name in self.data_extractor.processed_main_file_names
+                for file_name in self.data_extractor.processed_file_names
             )
             and self.pos_weight is None
         ):
             print(
                 f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
             )
-            complete_data = pd.concat(
+            complete_labels = torch.concat(
                 [
-                    pd.read_pickle(
-                        open(
-                            os.path.join(
-                                self.data_extractor.processed_dir_main,
-                                file_name,
-                            ),
-                            "rb",
-                        )
+                    torch.stack(
+                        [
+                            torch.Tensor(row["labels"])
+                            for row in self.data_extractor.load_processed_data(
+                                filename=file_name
+                            )
+                        ]
                     )
-                    for file_name in self.data_extractor.processed_main_file_names
+                    for file_name in self.data_extractor.processed_file_names
                 ]
            )
-            value_counts = []
-            for c in complete_data.columns[3:]:
-                value_counts.append(len([v for v in complete_data[c] if v]))
+            value_counts = complete_labels.sum(dim=0)
             weights = [
                 (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
             ]
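
For context, here is a minimal standalone sketch of the weighting this hunk computes, assuming each processed row carries a multi-hot "labels" list as shown in the diff. The function name compute_pos_weights and the argument label_rows are hypothetical, not part of the repository; the formula (1 - beta) / (1 - beta**n) mirrors the weights expression above and assumes every class has at least one positive example.

import torch

def compute_pos_weights(label_rows, beta: float = 0.99) -> torch.Tensor:
    # Stack per-sample multi-hot labels into a (num_samples, num_classes) tensor,
    # mirroring torch.stack([torch.Tensor(row["labels"]) ...]) in the diff.
    complete_labels = torch.stack(
        [torch.tensor(row, dtype=torch.float) for row in label_rows]
    )
    # Positive examples per class, replacing the old per-column pandas count.
    value_counts = complete_labels.sum(dim=0)
    # Per-class weight: (1 - beta) / (1 - beta**n); assumes n > 0 for every class.
    return (1 - beta) / (1 - torch.pow(beta, value_counts))

# Example: three classes with 2, 1 and 3 positive samples respectively.
print(compute_pos_weights([[1, 0, 1], [0, 0, 1], [1, 1, 1]]))
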
@@ -78,7 +78,7 @@ def print_metrics(
     print(f"Micro-Recall: {recall_micro(preds, labels):3f}")
     if markdown_output:
         print(
-            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy"
+            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy |"
         )
         print(f"| --- | --- | --- | --- | --- | --- | --- | --- |")
         print(
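
The change above simply terminates the header row with a pipe so it matches the separator row printed next. As a quick illustration of how the printed rows combine into a markdown table, here is a throwaway snippet; the metric values are placeholders, not output of print_metrics.

metrics = {
    "Macro-F1": 0.81, "Micro-F1": 0.85, "Macro-Precision": 0.79,
    "Micro-Precision": 0.84, "Macro-Recall": 0.83, "Micro-Recall": 0.86,
    "Balanced Accuracy": 0.80,
}
print("| Model | " + " | ".join(metrics) + " |")            # header row, pipe-terminated
print("|" + " --- |" * (len(metrics) + 1))                  # separator row
print("| demo-model | " + " | ".join(f"{v:3f}" for v in metrics.values()) + " |")
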
@@ -156,11 +156,12 @@ def evaluate_model(
             return test_preds, test_labels
         return test_preds, None
     elif len(preds_list) < 0:
-        torch.save(
-            _concat_tuple(preds_list),
-            os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
-        )
-        if labels_list[0] is not None:
+        if len(preds_list) > 0 and preds_list[0] is not None:
+            torch.save(
+                _concat_tuple(preds_list),
+                os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
+            )
+        if len(labels_list) > 0 and labels_list[0] is not None:
             torch.save(
                 _concat_tuple(labels_list),
                 os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
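
A hedged sketch of the guarded buffering pattern this hunk introduces: predictions and labels are only written to .pt chunks when the corresponding buffer actually holds tensors, so empty buffers or label-free batches no longer crash the save step. The names _concat_tuple, buffer_dir and save_ind mirror the diff, but the helper and the flush_buffers wrapper below are stand-ins, not the repository's implementation.

import os
import torch

def _concat_tuple(tensor_list):
    # Stand-in helper: concatenate a list of tensors along the batch dimension.
    return torch.cat(list(tensor_list), dim=0)

def flush_buffers(preds_list, labels_list, buffer_dir, save_ind: int) -> None:
    os.makedirs(buffer_dir, exist_ok=True)
    if len(preds_list) > 0 and preds_list[0] is not None:
        torch.save(_concat_tuple(preds_list), os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"))
    if len(labels_list) > 0 and labels_list[0] is not None:
        torch.save(_concat_tuple(labels_list), os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"))

# Example: two batches of predictions, no labels collected.
flush_buffers([torch.rand(4, 3), torch.rand(4, 3)], [], "buffer", save_ind=0)
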