27 changes: 12 additions & 15 deletions chebai/loss/bce_weighted.py
@@ -50,32 +50,29 @@ def set_pos_weight(self, input: torch.Tensor) -> None:
             and self.data_extractor is not None
             and all(
                 os.path.exists(
-                    os.path.join(self.data_extractor.processed_dir_main, file_name)
+                    os.path.join(self.data_extractor.processed_dir, file_name)
                 )
-                for file_name in self.data_extractor.processed_main_file_names
+                for file_name in self.data_extractor.processed_file_names
             )
             and self.pos_weight is None
         ):
             print(
                 f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})"
             )
-            complete_data = pd.concat(
+            complete_labels = torch.concat(
                 [
-                    pd.read_pickle(
-                        open(
-                            os.path.join(
-                                self.data_extractor.processed_dir_main,
-                                file_name,
-                            ),
-                            "rb",
-                        )
+                    torch.stack(
+                        [
+                            torch.Tensor(row["labels"])
+                            for row in self.data_extractor.load_processed_data(
+                                filename=file_name
+                            )
+                        ]
                     )
-                    for file_name in self.data_extractor.processed_main_file_names
+                    for file_name in self.data_extractor.processed_file_names
                 ]
             )
-            value_counts = []
-            for c in complete_data.columns[3:]:
-                value_counts.append(len([v for v in complete_data[c] if v]))
+            value_counts = complete_labels.sum(dim=0)
             weights = [
                 (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts
             ]
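
The new code stacks the per-sample label vectors into one tensor, takes per-class positive counts with a single sum over the sample dimension, and converts them into effective-number weights via (1 - beta) / (1 - beta^n_c). A minimal standalone sketch of that computation (the label matrix and beta value below are invented for illustration, not taken from the PR):

import torch

# Hypothetical stand-in data: a small multi-label matrix (samples x classes)
# and a beta close to 1, mirroring the class-balanced weighting in the diff above.
labels = torch.tensor(
    [[1.0, 0.0, 1.0],
     [1.0, 0.0, 0.0],
     [0.0, 0.0, 1.0],
     [1.0, 1.0, 0.0]]
)
beta = 0.99

# Per-class positive counts, the analogue of complete_labels.sum(dim=0).
value_counts = labels.sum(dim=0)  # tensor([3., 1., 2.])

# Effective-number weights: (1 - beta) / (1 - beta**n_c), so rarer classes
# receive larger weights.
weights = (1 - beta) / (1 - torch.pow(beta, value_counts))
print(weights)

Because the counts come straight from the stacked label tensor, the weighting no longer depends on the column layout of the pickled DataFrame (the old code assumed labels started at column 3).
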
2 changes: 1 addition & 1 deletion chebai/result/classification.py
@@ -78,7 +78,7 @@ def print_metrics(
     print(f"Micro-Recall: {recall_micro(preds, labels):3f}")
     if markdown_output:
         print(
-            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy"
+            f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy |"
         )
         print(f"| --- | --- | --- | --- | --- | --- | --- | --- |")
         print(
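
The change here closes the markdown table header with a trailing pipe so the header row has the same number of cells as the "| --- |" separator row below it; some markdown renderers otherwise drop or misalign the last column. A small illustrative sketch (column list shortened and invented, not the repo's code) that builds both rows from one list so they cannot drift apart:

columns = ["Model", "Macro-F1", "Micro-F1", "Balanced Accuracy"]  # example subset
header = "| " + " | ".join(columns) + " |"
separator = "|" + " --- |" * len(columns)
print(header)     # | Model | Macro-F1 | Micro-F1 | Balanced Accuracy |
print(separator)  # | --- | --- | --- | --- |
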
11 changes: 6 additions & 5 deletions chebai/result/utils.py
@@ -156,11 +156,12 @@ def evaluate_model(
                 return test_preds, test_labels
             return test_preds, None
-    elif len(preds_list) < 0:
-        torch.save(
-            _concat_tuple(preds_list),
-            os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
-        )
-    if labels_list[0] is not None:
+    if len(preds_list) > 0 and preds_list[0] is not None:
+        torch.save(
+            _concat_tuple(preds_list),
+            os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
+        )
+    if len(labels_list) > 0 and labels_list[0] is not None:
         torch.save(
             _concat_tuple(labels_list),
             os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),
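
The old branch `elif len(preds_list) < 0:` could never run (a list length is never negative), so predictions were silently never written to buffer_dir; the fix guards each torch.save with a non-empty, non-None check on the respective list. A rough sketch of how such buffered chunks could later be read back, assuming the same preds*.pt / labels*.pt naming (the loader below is illustrative, not code from the PR):

import glob
import os

import torch


def load_buffered_results(buffer_dir: str):
    """Concatenate all prediction/label chunks buffered during evaluation."""
    pred_files = sorted(glob.glob(os.path.join(buffer_dir, "preds*.pt")))
    label_files = sorted(glob.glob(os.path.join(buffer_dir, "labels*.pt")))
    preds = torch.cat([torch.load(f) for f in pred_files]) if pred_files else None
    labels = torch.cat([torch.load(f) for f in label_files]) if label_files else None
    return preds, labels
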