Skip to content

Commit 9dfdf69

Browse files
committed
Use processed data instead of processed-main data when computing BCE loss weights
1 parent ff03d6f commit 9dfdf69

File tree

1 file changed

+9
-16
lines changed

1 file changed

+9
-16
lines changed

chebai/loss/bce_weighted.py

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,32 +50,25 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
5050
and self.data_extractor is not None
5151
and all(
5252
os.path.exists(
53-
os.path.join(self.data_extractor.processed_dir_main, file_name)
53+
os.path.join(self.data_extractor.processed_dir, file_name)
5454
)
55-
for file_name in self.data_extractor.processed_main_file_names
55+
for file_name in self.data_extractor.processed_file_names
5656
)
5757
and self.pos_weight is None
5858
):
5959
print(
6060
f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
6161
)
62-
complete_data = pd.concat(
62+
complete_labels = torch.concat(
6363
[
64-
pd.read_pickle(
65-
open(
66-
os.path.join(
67-
self.data_extractor.processed_dir_main,
68-
file_name,
69-
),
70-
"rb",
71-
)
72-
)
73-
for file_name in self.data_extractor.processed_main_file_names
64+
torch.stack([
65+
torch.Tensor(row["labels"]) for row in
66+
self.data_extractor.load_processed_data(filename=file_name)
67+
])
68+
for file_name in self.data_extractor.processed_file_names
7469
]
7570
)
76-
value_counts = []
77-
for c in complete_data.columns[3:]:
78-
value_counts.append(len([v for v in complete_data[c] if v]))
71+
value_counts = complete_labels.sum(dim=0)
7972
weights = [
8073
(1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
8174
]

0 commit comments

Comments
 (0)