Skip to content

Commit 5661b64

Browse files
authored
Merge pull request #78 from ChEB-AI/feature/weighted-bce-tokenized
base BCE weighted on tokenized data
2 parents ff03d6f + 0c2b99d commit 5661b64

File tree

3 files changed

+19
-21
lines changed

3 files changed

+19
-21
lines changed

chebai/loss/bce_weighted.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -50,32 +50,29 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
5050
and self.data_extractor is not None
5151
and all(
5252
os.path.exists(
53-
os.path.join(self.data_extractor.processed_dir_main, file_name)
53+
os.path.join(self.data_extractor.processed_dir, file_name)
5454
)
55-
for file_name in self.data_extractor.processed_main_file_names
55+
for file_name in self.data_extractor.processed_file_names
5656
)
5757
and self.pos_weight is None
5858
):
5959
print(
6060
f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
6161
)
62-
complete_data = pd.concat(
62+
complete_labels = torch.concat(
6363
[
64-
pd.read_pickle(
65-
open(
66-
os.path.join(
67-
self.data_extractor.processed_dir_main,
68-
file_name,
69-
),
70-
"rb",
71-
)
64+
torch.stack(
65+
[
66+
torch.Tensor(row["labels"])
67+
for row in self.data_extractor.load_processed_data(
68+
filename=file_name
69+
)
70+
]
7271
)
73-
for file_name in self.data_extractor.processed_main_file_names
72+
for file_name in self.data_extractor.processed_file_names
7473
]
7574
)
76-
value_counts = []
77-
for c in complete_data.columns[3:]:
78-
value_counts.append(len([v for v in complete_data[c] if v]))
75+
value_counts = complete_labels.sum(dim=0)
7976
weights = [
8077
(1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
8178
]

chebai/result/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def print_metrics(
7878
print(f"Micro-Recall: {recall_micro(preds, labels):3f}")
7979
if markdown_output:
8080
print(
81-
f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy"
81+
f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy |"
8282
)
8383
print(f"| --- | --- | --- | --- | --- | --- | --- | --- |")
8484
print(

chebai/result/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,12 @@ def evaluate_model(
156156
return test_preds, test_labels
157157
return test_preds, None
158158
elif len(preds_list) < 0:
159-
torch.save(
160-
_concat_tuple(preds_list),
161-
os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
162-
)
163-
if labels_list[0] is not None:
159+
if len(preds_list) > 0 and preds_list[0] is not None:
160+
torch.save(
161+
_concat_tuple(preds_list),
162+
os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
163+
)
164+
if len(labels_list) > 0 and labels_list[0] is not None:
164165
torch.save(
165166
_concat_tuple(labels_list),
166167
os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),

0 commit comments

Comments (0)