From 5167ddc2c8d6f3ebb1fbe5680b4d87bc59022568 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 9 Jul 2025 10:18:57 -0700 Subject: [PATCH 1/4] use roc auc instead of accuracy --- .../evaluation/knowledge_distillation.py | 19 ++++++++++++++----- .../knowledge_distillation_teacher.py | 19 +++++++++++++++---- 2 files changed, 29 insertions(+), 9 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py index b3a1a75b8..c74433c12 100644 --- a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py +++ b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py @@ -6,7 +6,12 @@ import matplotlib.pyplot as plt import pandas as pd import seaborn as sns -from sklearn.metrics import accuracy_score, classification_report, f1_score +from sklearn.metrics import ( + accuracy_score, + classification_report, + f1_score, + roc_auc_score, +) # %% # Mantis @@ -82,14 +87,18 @@ lambda x: float(f1_score(x["label"], x["prediction_binary"])) ) +roc_auc_by_t = prediction.groupby(id_vars).apply( + lambda x: float(roc_auc_score(x["label"], x["prediction"])) +) + metrics_df = pd.DataFrame( - data={"accuracy": accuracy_by_t.values, "F1": f1_by_t.values}, + data={"accuracy": accuracy_by_t.values, "F1": f1_by_t.values, "ROC AUC": roc_auc_by_t.values}, index=f1_by_t.index, ).reset_index() metrics_long = metrics_df.melt( id_vars=id_vars, - value_vars=["accuracy"], + value_vars=["ROC AUC"], var_name="metric", value_name="score", ) @@ -105,7 +114,7 @@ linewidth=1.5, linestyles="--", ) - g.set_axis_labels("HPI", "accuracy") + g.set_axis_labels("HPI", "ROC AUC") sns.move_legend(g, "upper left", bbox_to_anchor=(0.35, 1.1)) g.figure.set_size_inches(3.5, 1.5) g.set(xlim=(-1, 7), ylim=(0.6, 1.0)) @@ -115,7 +124,7 @@ # %% g.figure.savefig( Path.home() - / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_knowledge_distillation/figure_parts/accuracy_students.pdf", + / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_knowledge_distillation/figure_parts/roc_auc_students.pdf", dpi=300, ) diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py index 6afe391a7..01c9701ff 100644 --- a/applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py +++ b/applications/contrastive_phenotyping/evaluation/knowledge_distillation_teacher.py @@ -4,9 +4,13 @@ import matplotlib.pyplot as plt import pandas as pd import seaborn as sns -from iohub.ngff import open_ome_zarr from sklearn.linear_model import LogisticRegression -from sklearn.metrics import accuracy_score, classification_report, f1_score +from sklearn.metrics import ( + accuracy_score, + classification_report, + f1_score, + roc_auc_score, +) from viscy.representation.embedding_writer import read_embedding_dataset @@ -69,12 +73,14 @@ def load_features_and_annotations(embedding_path, annotation_path, filter_fn): model = model.fit(train_features, train_annotation) train_prediction = model.predict(train_features) val_prediction = model.predict(val_features) +val_prediction_score = model.predict_proba(val_features) print("Training\n", classification_report(train_annotation, train_prediction)) print("Validation\n", classification_report(val_annotation, val_prediction)) val_selected["label"] = val_selected["infection_state"].cat.codes val_selected["prediction_binary"] = val_prediction +val_selected["prediction"] = val_prediction_score[:, 1] # %% prediction = val_selected @@ -107,17 +113,22 @@ def load_features_and_annotations(embedding_path, annotation_path, filter_fn): lambda x: float(f1_score(x["label"], x["prediction_binary"])) ) +roc_auc_by_t = prediction.groupby(["stage"]).apply( + lambda x: float(roc_auc_score(x["label"], x["prediction"])) +) + metrics_df = pd.DataFrame( data={ "accuracy": accuracy_by_t.values, "F1": f1_by_t.values, + "ROC AUC": roc_auc_by_t.values, }, index=f1_by_t.index, ).reset_index() metrics_long = metrics_df.melt( id_vars=["stage"], - value_vars=["accuracy"], + value_vars=["ROC AUC"], var_name="metric", value_name="score", ) @@ -133,7 +144,7 @@ def load_features_and_annotations(embedding_path, annotation_path, filter_fn): legend=False, color="gray", ) - g.set_axis_labels("HPI", "accuracy") + g.set_axis_labels("HPI", "ROC AUC") g.figure.set_size_inches(3.5, 0.75) g.set(xlim=(-1, 7), ylim=(0.9, 1.0)) plt.show() From 039b05f742ccadd74a9313f6fe9cec00a10c0143 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 16 Jul 2025 14:12:51 -0700 Subject: [PATCH 2/4] do not require z in annotation --- viscy/data/cell_classification.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/viscy/data/cell_classification.py b/viscy/data/cell_classification.py index ac72a6601..757a4921c 100644 --- a/viscy/data/cell_classification.py +++ b/viscy/data/cell_classification.py @@ -10,7 +10,6 @@ from torch.utils.data import DataLoader, Dataset from viscy.data.hcs import _read_norm_meta -from viscy.data.triplet import INDEX_COLUMNS class ClassificationDataset(Dataset): @@ -42,6 +41,16 @@ def __init__( annotation["y"].between(*y_range, inclusive="neither") & annotation["x"].between(*x_range, inclusive="neither") ] + self._index_columns = [ + "fov_name", + "track_id", + "t", + "id", + "parent_track_id", + "parent_id", + "y", + "x", + ] def __len__(self): return len(self.annotation) @@ -68,7 +77,7 @@ def __getitem__( img = self.transform(img) label = torch.tensor(row["infection_state"]).float()[None] if self.return_indices: - return img, label, row[INDEX_COLUMNS].to_dict() + return img, label, row[self._index_columns].to_dict() else: return img, label From b276c7ed284c5bda9feb6ab027f4d2bc51bdbf1f Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 16 Jul 2025 16:13:50 -0700 Subject: [PATCH 3/4] add classical sampling results --- .../evaluation/knowledge_distillation.py | 28 +++++++++++++------ 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py index c74433c12..b8a35568e 100644 --- a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py +++ b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py @@ -27,15 +27,22 @@ prediction_from_scratch = pd.read_csv( "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/from-scratch-last-1126.csv" ) -prediction_from_scratch["pretraining"] = "ImageNet" +prediction_from_scratch["pretraining"] = "without DynaCLR" + +prediction_classical = pd.read_csv( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/fine-tune-classical-last-1126.csv" +) +prediction_classical["pretraining"] = "classical sampling" prediction_finetuned = pd.read_csv( "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/fine-tune-last-1126.csv" ) -pretrained_name = "DynaCLR" +pretrained_name = "cell and time aware sampling" prediction_finetuned["pretraining"] = pretrained_name -prediction = pd.concat([prediction_from_scratch, prediction_finetuned], axis=0) +prediction = pd.concat( + [prediction_from_scratch, prediction_classical, prediction_finetuned], axis=0 +) # %% prediction = prediction[prediction["fov_name"].isin(VAL_FOVS)] @@ -92,13 +99,17 @@ ) metrics_df = pd.DataFrame( - data={"accuracy": accuracy_by_t.values, "F1": f1_by_t.values, "ROC AUC": roc_auc_by_t.values}, + data={ + "accuracy": accuracy_by_t.values, + "F1": f1_by_t.values, + "ROC AUC": roc_auc_by_t.values, + }, index=f1_by_t.index, ).reset_index() metrics_long = metrics_df.melt( id_vars=id_vars, - value_vars=["ROC AUC"], + value_vars=["accuracy", "F1", "ROC AUC"], var_name="metric", value_name="score", ) @@ -113,11 +124,12 @@ kind="point", linewidth=1.5, linestyles="--", + col="metric", ) g.set_axis_labels("HPI", "ROC AUC") - sns.move_legend(g, "upper left", bbox_to_anchor=(0.35, 1.1)) - g.figure.set_size_inches(3.5, 1.5) - g.set(xlim=(-1, 7), ylim=(0.6, 1.0)) + # sns.move_legend(g, "upper left", bbox_to_anchor=(0.35, 1.1)) + # g.figure.set_size_inches(3.5, 1.5) + # g.set(xlim=(-1, 7), ylim=(0.6, 1.0)) plt.show() From f282132d7d5972cdf2e7f9e262342a9ba9399921 Mon Sep 17 00:00:00 2001 From: Ziwen Liu Date: Wed, 16 Jul 2025 16:36:31 -0700 Subject: [PATCH 4/4] update legend --- .../evaluation/knowledge_distillation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py index b8a35568e..b0208b769 100644 --- a/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py +++ b/applications/contrastive_phenotyping/evaluation/knowledge_distillation.py @@ -27,17 +27,17 @@ prediction_from_scratch = pd.read_csv( "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/from-scratch-last-1126.csv" ) -prediction_from_scratch["pretraining"] = "without DynaCLR" +prediction_from_scratch["pretraining"] = "ImageNet encoder" prediction_classical = pd.read_csv( "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/fine-tune-classical-last-1126.csv" ) -prediction_classical["pretraining"] = "classical sampling" +prediction_classical["pretraining"] = "DynaCLR (classical sampling)" prediction_finetuned = pd.read_csv( "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/bootstrap-labels/test/fine-tune-last-1126.csv" ) -pretrained_name = "cell and time aware sampling" +pretrained_name = "DynaCLR (cell- and time-aware sampling)" prediction_finetuned["pretraining"] = pretrained_name prediction = pd.concat(