Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,12 @@
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from sklearn.metrics import accuracy_score, classification_report, f1_score
from sklearn.metrics import (
accuracy_score,
classification_report,
f1_score,
roc_auc_score,
)

# %%
# Mantis
Expand All @@ -22,15 +27,22 @@
# Test-set predictions of the student models, one CSV per pretraining
# strategy. Each frame is tagged in a "pretraining" column so the rows can be
# told apart after the frames are stacked.
_TEST_PREDICTION_DIR = (
    "/hpc/projects/intracellular_dashboard/viral-sensor/"
    "infection_classification/models/bootstrap-labels/test"
)

prediction_from_scratch = pd.read_csv(
    f"{_TEST_PREDICTION_DIR}/from-scratch-last-1126.csv"
)
# NOTE(review): the file is named "from-scratch" but the rows are tagged
# "ImageNet encoder" -- confirm the tag matches how this model was initialised.
prediction_from_scratch["pretraining"] = "ImageNet encoder"

prediction_classical = pd.read_csv(
    f"{_TEST_PREDICTION_DIR}/fine-tune-classical-last-1126.csv"
)
prediction_classical["pretraining"] = "DynaCLR (classical sampling)"

prediction_finetuned = pd.read_csv(f"{_TEST_PREDICTION_DIR}/fine-tune-last-1126.csv")
# Kept as a module-level name because later cells may refer to it.
pretrained_name = "DynaCLR (cell- and time-aware sampling)"
prediction_finetuned["pretraining"] = pretrained_name

# Stack the three frames row-wise; the original per-file row indices are kept.
prediction = pd.concat(
    [prediction_from_scratch, prediction_classical, prediction_finetuned], axis=0
)

# %%
prediction = prediction[prediction["fov_name"].isin(VAL_FOVS)]
Expand Down Expand Up @@ -82,14 +94,22 @@
lambda x: float(f1_score(x["label"], x["prediction_binary"]))
)

def _roc_auc_or_nan(group):
    """Return the ROC AUC of one group, or NaN when it is undefined.

    ROC AUC needs both classes among the true labels. A group whose "label"
    column holds a single class is reported as NaN so that one degenerate
    group does not abort (or silently depend on the scikit-learn version of)
    the whole aggregation.
    """
    if group["label"].nunique() < 2:
        return float("nan")
    return float(roc_auc_score(group["label"], group["prediction"]))


# ROC AUC per group, scored on the "prediction" column rather than on the
# thresholded "prediction_binary" column used for F1.
roc_auc_by_t = prediction.groupby(id_vars).apply(_roc_auc_or_nan)

# Wide table with one row per group and one column per metric.
# `data` is passed exactly once; repeating the keyword is a SyntaxError.
# `.values` discards each series' index, so the columns are matched by
# position and re-indexed with the F1 series' group keys.
# NOTE(review): this assumes all three series come from the same groupby and
# therefore share the same group order -- confirm for `accuracy_by_t`.
metrics_df = pd.DataFrame(
    data={
        "accuracy": accuracy_by_t.values,
        "F1": f1_by_t.values,
        "ROC AUC": roc_auc_by_t.values,
    },
    index=f1_by_t.index,
).reset_index()

# Long format: one row per (group, metric) pair, with the metric name in
# "metric" and its value in "score". `value_vars` is passed exactly once
# (repeating the keyword is a SyntaxError) and lists every metric column of
# `metrics_df`.
metrics_long = metrics_df.melt(
    id_vars=id_vars,
    value_vars=["accuracy", "F1", "ROC AUC"],
    var_name="metric",
    value_name="score",
)
Expand All @@ -104,18 +124,19 @@
kind="point",
linewidth=1.5,
linestyles="--",
col="metric",
)
# Post-processing of the grid `g` before display.
# The axis labels are set twice in this cell; the later call is the one that
# ends up on the figure.
g.set_axis_labels("HPI", "accuracy")
sns.move_legend(g, "upper left", bbox_to_anchor=(0.35, 1.1))
g.figure.set_size_inches(3.5, 1.5)
g.set(xlim=(-1, 7), ylim=(0.6, 1.0))
# NOTE(review): the plotting call above passes col="metric" and the melted
# table lists accuracy, F1 and ROC AUC, so a shared y label of "ROC AUC"
# presumably mislabels the other facets -- confirm whether "score" is meant.
g.set_axis_labels("HPI", "ROC AUC")
# NOTE(review): the three commented-out lines below repeat the legend, size
# and axis-limit calls that are still active above -- confirm which layout is
# intended and drop the other copy.
# sns.move_legend(g, "upper left", bbox_to_anchor=(0.35, 1.1))
# g.figure.set_size_inches(3.5, 1.5)
# g.set(xlim=(-1, 7), ylim=(0.6, 1.0))
plt.show()


# %%
# Export the figure as a PDF under the user's home directory.
# Exactly one file-name segment is joined onto the home directory: chaining a
# second "….pdf" segment with `/` would nest the target underneath the first
# file name and point at a directory that does not exist.
g.figure.savefig(
    Path.home()
    / "gdrive/publications/dynaCLR/2025_dynaCLR_paper/fig_manuscript_svg/figure_knowledge_distillation/figure_parts/roc_auc_students.pdf",
    dpi=300,
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,13 @@
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
from iohub.ngff import open_ome_zarr
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, classification_report, f1_score
from sklearn.metrics import (
accuracy_score,
classification_report,
f1_score,
roc_auc_score,
)

from viscy.representation.embedding_writer import read_embedding_dataset

Expand Down Expand Up @@ -69,12 +73,14 @@ def load_features_and_annotations(embedding_path, annotation_path, filter_fn):
# Fit the classifier on the training features, then predict hard labels for
# both the training and the validation features.
model = model.fit(train_features, train_annotation)
train_prediction = model.predict(train_features)
val_prediction = model.predict(val_features)
# Per-class scores for the validation features, one column per class.
# NOTE(review): for scikit-learn classifiers the column order follows
# `model.classes_` -- confirm `model` is one (its construction is not shown).
val_prediction_score = model.predict_proba(val_features)

print("Training\n", classification_report(train_annotation, train_prediction))
print("Validation\n", classification_report(val_annotation, val_prediction))

# `.cat.codes` maps the categorical "infection_state" column to integer codes.
val_selected["label"] = val_selected["infection_state"].cat.codes
val_selected["prediction_binary"] = val_prediction
# Keep only the second score column as the continuous prediction.
# NOTE(review): assumes column 1 is the positive class and corresponds to
# category code 1 in "label" -- verify against `model.classes_`.
val_selected["prediction"] = val_prediction_score[:, 1]

# %%
prediction = val_selected
Expand Down Expand Up @@ -107,17 +113,22 @@ def load_features_and_annotations(embedding_path, annotation_path, filter_fn):
lambda x: float(f1_score(x["label"], x["prediction_binary"]))
)

def _roc_auc_or_nan(group):
    """Return the ROC AUC of one group, or NaN when it is undefined.

    ROC AUC needs both classes among the true labels. A group whose "label"
    column holds a single class is reported as NaN so that one degenerate
    group does not abort (or silently depend on the scikit-learn version of)
    the whole aggregation.
    """
    if group["label"].nunique() < 2:
        return float("nan")
    return float(roc_auc_score(group["label"], group["prediction"]))


# ROC AUC per "stage" group, scored on the "prediction" column rather than on
# the hard labels in "prediction_binary".
roc_auc_by_t = prediction.groupby(["stage"]).apply(_roc_auc_or_nan)

# Wide table of per-group metrics. The three series are reduced to plain
# arrays and re-indexed with the F1 series' group keys, so columns are matched
# by position; the group keys are then moved from the index into columns.
metrics_df = pd.DataFrame(
    index=f1_by_t.index,
    data={
        "accuracy": accuracy_by_t.values,
        "F1": f1_by_t.values,
        "ROC AUC": roc_auc_by_t.values,
    },
)
metrics_df = metrics_df.reset_index()

# Long format restricted to the ROC AUC column: one row per stage, with the
# metric name in "metric" and its value in "score". `value_vars` is passed
# exactly once; repeating the keyword is a SyntaxError.
metrics_long = metrics_df.melt(
    id_vars=["stage"],
    value_vars=["ROC AUC"],
    var_name="metric",
    value_name="score",
)
Expand All @@ -133,7 +144,7 @@ def load_features_and_annotations(embedding_path, annotation_path, filter_fn):
legend=False,
color="gray",
)
# Label the axes once (a second call would simply override the first), then
# apply the compact figure size and fixed axis limits before display.
g.set_axis_labels("HPI", "ROC AUC")
g.figure.set_size_inches(3.5, 0.75)
g.set(xlim=(-1, 7), ylim=(0.9, 1.0))
plt.show()
Expand Down
13 changes: 11 additions & 2 deletions viscy/data/cell_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from torch.utils.data import DataLoader, Dataset

from viscy.data.hcs import _read_norm_meta
from viscy.data.triplet import INDEX_COLUMNS


class ClassificationDataset(Dataset):
Expand Down Expand Up @@ -42,6 +41,16 @@ def __init__(
annotation["y"].between(*y_range, inclusive="neither")
& annotation["x"].between(*x_range, inclusive="neither")
]
self._index_columns = [
"fov_name",
"track_id",
"t",
"id",
"parent_track_id",
"parent_id",
"y",
"x",
]

def __len__(self):
    """Return the number of rows in ``self.annotation``."""
    return len(self.annotation)
Expand All @@ -68,7 +77,7 @@ def __getitem__(
img = self.transform(img)
label = torch.tensor(row["infection_state"]).float()[None]
if self.return_indices:
return img, label, row[INDEX_COLUMNS].to_dict()
return img, label, row[self._index_columns].to_dict()
else:
return img, label

Expand Down