@@ -1,4 +1,6 @@
 from .experiment import Experiment, label_map_ext
+from sota_extractor2.models.structure.nbsvm import *
+from sklearn.metrics import confusion_matrix
 from .nbsvm import preds_for_cell_content, preds_for_cell_content_max, preds_for_cell_content_multi
 import dataclasses
 from dataclasses import dataclass
@@ -134,3 +136,135 @@ def evaluate(self, model, train_df, valid_df, test_df):
             true_y_ext = tdf["cell_type"].apply(lambda x: label_map_ext.get(x, 0))
             self._set_results(prefix, preds, true_y, true_y_ext)
             self._preds.append(probs)
+
+
+@dataclass
+class ULMFiTTableTypeExperiment(ULMFiTExperiment):
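+    """Table-type classification experiment: SOTA vs. ablation vs. irrelevant.
+
+    Flags control the label encoding (`sigmoid`, `distinguish_ablation`,
+    `irrelevant_as_class`) and which table-context features are kept as model
+    input (`caption`, `first_row`, `first_column`, `referencing_sections`).
+    """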
+    sigmoid: bool = True
+    distinguish_ablation: bool = True
+    irrelevant_as_class: bool = False
+    caption: bool = True
+    first_row: bool = False
+    first_column: bool = False
+    referencing_sections: bool = False
+    dedup_seqs: bool = False
+
+    def _save_model(self, path):
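+        # Intentionally a no-op: this experiment variant does not persist weights.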
+        pass
+
+    def _transform_df(self, df):
+        df = df.copy(True)
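+        # Sigmoid head: multi-label boolean targets (`sota`, `ablation`, and
+        # optionally an explicit `irrelevant` column). Softmax head: a single
+        # `class` column (0 = sota, 1 = ablation, 2 = irrelevant), collapsed to
+        # two classes when ablations are not distinguished.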
+        if self.sigmoid:
+            if self.irrelevant_as_class:
+                df["irrelevant"] = ~(df["sota"] | df["ablation"])
+            if not self.distinguish_ablation:
+                df["sota"] = df["sota"] | df["ablation"]
+                df = df.drop(columns=["ablation"])
+        else:
+            if self.distinguish_ablation:
+                df["class"] = 2
+                df.loc[df.ablation, "class"] = 1
+                df.loc[df.sota, "class"] = 0
+            else:
+                df["class"] = 1
+                df.loc[df.sota, "class"] = 0
+                df.loc[df.ablation, "class"] = 0
+
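+        # Ground-truth labels consumed by evaluate(): 0 = sota, 1 = ablation,
+        # 2 = irrelevant.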
+        df["label"] = 2
+        df.loc[df.ablation, "label"] = 1
+        df.loc[df.sota, "label"] = 0
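+        # Drop whichever table-context features this variant is configured not to use.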
+        drop_columns = []
+        if not self.caption:
+            drop_columns.append("caption")
+        if not self.first_column:
+            drop_columns.append("col0")
+        if not self.first_row:
+            drop_columns.append("row0")
+        if not self.referencing_sections:
+            drop_columns.append("sections")
+        df = df.drop(columns=drop_columns)
+        return df
+
+    def evaluate(self, model, train_df, valid_df, test_df):
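+        # `ordered=True` keeps fastai's predictions aligned with dataframe order
+        # (text loaders otherwise sort batches by sequence length).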
+        valid_probs = model.get_preds(ds_type=DatasetType.Valid, ordered=True)[0].cpu().numpy()
+        test_probs = model.get_preds(ds_type=DatasetType.Test, ordered=True)[0].cpu().numpy()
+        train_probs = model.get_preds(ds_type=DatasetType.Train, ordered=True)[0].cpu().numpy()
+        self._preds = []
+
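+        # The sigmoid setup has no explicit "irrelevant" output; append a constant
+        # threshold column so argmax falls through to that extra index whenever no
+        # class probability exceeds the threshold.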
+        def multipreds2preds(preds, threshold=0.5):
+            bs = preds.shape[0]
+            return np.concatenate([preds, np.ones((bs, 1)) * threshold], axis=-1).argmax(-1)
+
+        for prefix, tdf, probs in zip(["train", "valid", "test"],
+                                      [train_df, valid_df, test_df],
+                                      [train_probs, valid_probs, test_probs]):
+
+            if self.sigmoid and not self.irrelevant_as_class:
+                preds = multipreds2preds(probs)
+            else:
+                preds = np.argmax(probs, axis=1)
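+            # Two-class predictions come out as {0: sota, 1: irrelevant}; rescale
+            # so that "irrelevant" matches label value 2.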
+            if not self.distinguish_ablation:
+                preds *= 2
+
+            true_y = tdf["label"]
+            self._set_results(prefix, preds, true_y)
+            self._preds.append(probs)
+
+    def _set_results(self, prefix, preds, true_y, true_y_ext=None):
+        def metrics(preds, true_y):
+            y = true_y
+            p = preds
+
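+            # Collapse sota/ablation into a single "relevant" class (0) for the
+            # binary metrics; `irr` is the value marking the irrelevant class.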
+            if self.distinguish_ablation:
+                g = {0: 0, 1: 0, 2: 1}.get
+                bin_y = np.array([g(x) for x in y])
+                bin_p = np.array([g(x) for x in p])
+                irr = 2
+            else:
+                bin_y = y
+                bin_p = p
+                irr = 1
+
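+            # Irrelevant acts as the negative class: a TP requires an exact match
+            # on a relevant label, an FP is any wrong "relevant" prediction, and
+            # an FN is a relevant table predicted as irrelevant.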
+            acc = (p == y).mean()
+            tp = ((y != irr) & (p == y)).sum()
+            fp = ((p != irr) & (p != y)).sum()
+            fn = ((y != irr) & (p == irr)).sum()
+
+            bin_acc = (bin_p == bin_y).mean()
+            bin_tp = ((bin_y != 1) & (bin_p == bin_y)).sum()
+            bin_fp = ((bin_p != 1) & (bin_p != bin_y)).sum()
+            bin_fn = ((bin_y != 1) & (bin_p == 1)).sum()
+
+            prec = tp / (fp + tp)
+            reca = tp / (fn + tp)
+            bin_prec = bin_tp / (bin_fp + bin_tp)
+            bin_reca = bin_tp / (bin_fn + bin_tp)
+            return {
+                "precision": prec,
+                "accuracy": acc,
+                "recall": reca,
+                "TP": tp,
+                "FP": fp,
+                "bin_precision": bin_prec,
+                "bin_accuracy": bin_acc,
+                "bin_recall": bin_reca,
+                "bin_TP": bin_tp,
+                "bin_FP": bin_fp,
+            }
+
+        m = metrics(preds, true_y)
+        r = {}
+        r[f"{prefix}_accuracy"] = m["accuracy"]
+        r[f"{prefix}_precision"] = m["precision"]
+        r[f"{prefix}_recall"] = m["recall"]
+        r[f"{prefix}_bin_accuracy"] = m["bin_accuracy"]
+        r[f"{prefix}_bin_precision"] = m["bin_precision"]
+        r[f"{prefix}_bin_recall"] = m["bin_recall"]
+        r[f"{prefix}_cm"] = confusion_matrix(true_y, preds).tolist()
+        self.update_results(**r)
+
+    def get_cm_labels(self, cm):
+        if len(cm) == 3:
+            return ["SOTA", "ABLATION", "IRRELEVANT"]
+        else:
+            return ["SOTA", "IRRELEVANT"]