File tree Expand file tree Collapse file tree 1 file changed +9
-16
lines changed
Expand file tree Collapse file tree 1 file changed +9
-16
lines changed Original file line number Diff line number Diff line change @@ -50,32 +50,25 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
5050 and self .data_extractor is not None
5151 and all (
5252 os .path .exists (
53- os .path .join (self .data_extractor .processed_dir_main , file_name )
53+ os .path .join (self .data_extractor .processed_dir , file_name )
5454 )
55- for file_name in self .data_extractor .processed_main_file_names
55+ for file_name in self .data_extractor .processed_file_names
5656 )
5757 and self .pos_weight is None
5858 ):
5959 print (
6060 f"Computing loss-weights based on v{ self .data_extractor .chebi_version } dataset (beta={ self .beta } )"
6161 )
62- complete_data = pd .concat (
62+ complete_labels = torch .concat (
6363 [
64- pd .read_pickle (
65- open (
66- os .path .join (
67- self .data_extractor .processed_dir_main ,
68- file_name ,
69- ),
70- "rb" ,
71- )
72- )
73- for file_name in self .data_extractor .processed_main_file_names
64+ torch .stack ([
65+ torch .Tensor (row ["labels" ]) for row in
66+ self .data_extractor .load_processed_data (filename = file_name )
67+ ])
68+ for file_name in self .data_extractor .processed_file_names
7469 ]
7570 )
76- value_counts = []
77- for c in complete_data .columns [3 :]:
78- value_counts .append (len ([v for v in complete_data [c ] if v ]))
71+ value_counts = complete_labels .sum (dim = 0 )
7972 weights = [
8073 (1 - self .beta ) / (1 - pow (self .beta , value )) for value in value_counts
8174 ]
You can’t perform that action at this time.
0 commit comments