|
2 | 2 | import os |
3 | 3 | import traceback |
4 | 4 | from datetime import datetime |
| 5 | +from pathlib import Path |
5 | 6 | from typing import List, LiteralString, Optional, Tuple |
6 | 7 |
|
7 | 8 | import pandas as pd |
@@ -155,9 +156,11 @@ def get_disjoint_groups(disjoint_files): |
155 | 156 | disjoint_files = os.path.join("data", "chebi-disjoints.owl") |
156 | 157 | disjoint_pairs, disjoint_groups = [], [] |
157 | 158 | for file in disjoint_files: |
158 | | - if file.split(".")[-1] == "csv": |
| 159 | + if isinstance(file, Path): |
| 160 | + file = str(file) |
| 161 | + if file.endswith(".csv"): |
159 | 162 | disjoint_pairs += pd.read_csv(file, header=None).values.tolist() |
160 | | - elif file.split(".")[-1] == "owl": |
| 163 | + elif file.endswith(".owl"): |
161 | 164 | with open(file, "r") as f: |
162 | 165 | plaintext = f.read() |
163 | 166 | segments = plaintext.split("<") |
@@ -217,10 +220,16 @@ def set_label_names(self, label_names): |
217 | 220 | self.label_successors = self.label_successors.unsqueeze(0) |
218 | 221 |
|
219 | 222 | def __call__(self, preds): |
| 223 | + if preds.shape[1] == 0: |
| 224 | + # no label columns (n_labels == 0), so there is nothing to process |
| 225 | + return preds |
| 226 | + # preds shape: (n_samples, n_labels) |
220 | 227 | preds_sum_orig = torch.sum(preds) |
221 | 228 | # step 1: apply implications: for each class, set prediction to max of itself and all successors |
222 | 229 | preds = preds.unsqueeze(1) |
223 | 230 | preds_masked_succ = torch.where(self.label_successors, preds, 0) |
| 231 | + # preds_masked_succ shape: (n_samples, n_labels, n_labels) |
| 232 | + |
224 | 233 | preds = preds_masked_succ.max(dim=2).values |
225 | 234 | if torch.sum(preds) != preds_sum_orig: |
226 | 235 | print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") |
|
0 commit comments