Skip to content

Commit 22a7b79

Browse files
committed
add safeguards to evaluate_model, fix typo
1 parent 9dfdf69 commit 22a7b79

File tree

2 files changed

+7
-6
lines changed

2 files changed

+7
-6
lines changed

chebai/result/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def print_metrics(
7878
print(f"Micro-Recall: {recall_micro(preds, labels):3f}")
7979
if markdown_output:
8080
print(
81-
f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy"
81+
f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy |"
8282
)
8383
print(f"| --- | --- | --- | --- | --- | --- | --- | --- |")
8484
print(

chebai/result/utils.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,12 @@ def evaluate_model(
156156
return test_preds, test_labels
157157
return test_preds, None
158158
elif len(preds_list) < 0:
159-
torch.save(
160-
_concat_tuple(preds_list),
161-
os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
162-
)
163-
if labels_list[0] is not None:
159+
if len(preds_list) > 0 and preds_list[0] is not None:
160+
torch.save(
161+
_concat_tuple(preds_list),
162+
os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"),
163+
)
164+
if len(labels_list) > 0 and labels_list[0] is not None:
164165
torch.save(
165166
_concat_tuple(labels_list),
166167
os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"),

0 commit comments

Comments
 (0)