diff --git a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py index fb8ce2f84..fb204078f 100644 --- a/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py +++ b/applications/contrastive_phenotyping/evaluation/cosine_dissimilarity_dataset.py @@ -1,270 +1,62 @@ # %% from pathlib import Path -from typing import Optional import matplotlib.pyplot as plt -import seaborn as sns -from sklearn.preprocessing import StandardScaler -from numpy.typing import NDArray - -from viscy.representation.embedding_writer import read_embedding_dataset -from viscy.representation.evaluation.clustering import ( - compare_time_offset, - pairwise_distance_matrix, - rank_nearest_neighbors, - select_block, -) -import numpy as np from tqdm import tqdm -import pandas as pd - -from scipy.stats import gaussian_kde -from scipy.optimize import minimize_scalar +from viscy.representation.evaluation.distance import ( + compute_embedding_distances, + analyze_and_plot_distances, +) plt.style.use("../evaluation/figure.mplstyle") - -def compute_piece_wise_dissimilarity( - features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray -): - """ - Computing the smoothness and dynamic range - - Get the off diagonal per block and compute the mode - - The blocks are not square, so we need to get the off diagonal elements - - Get the 1 and 99 percentile of the off diagonal per block - """ - piece_wise_dissimilarity_per_track = [] - piece_wise_rank_difference_per_track = [] - for name, subdata in features_df.groupby(["fov_name", "track_id"]): - if len(subdata) > 1: - indices = subdata.index.values - single_track_dissimilarity = select_block(cross_dist, indices) - single_track_rank_fraction = select_block(rank_fractions, indices) - piece_wise_dissimilarity = compare_time_offset( - single_track_dissimilarity, time_offset=1 - ) - piece_wise_rank_difference = compare_time_offset( - single_track_rank_fraction, time_offset=1 - ) - piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) - piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) - return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track - - -def plot_histogram( - data, title, xlabel, ylabel, color="blue", alpha=0.5, stat="frequency" -): - plt.figure() - plt.title(title) - sns.histplot(data, bins=30, kde=True, color=color, alpha=alpha, stat=stat) - plt.xlabel(xlabel) - plt.ylabel(ylabel) - plt.tight_layout() - plt.show() - - -def find_distribution_peak(data: np.ndarray) -> float: - """ - Find the peak (mode) of a distribution using kernel density estimation. - - Args: - data: Array of values to find the peak for - - Returns: - float: The x-value where the peak occurs - """ - kde = gaussian_kde(data) - # Find the peak (maximum) of the KDE - result = minimize_scalar( - lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" - ) - return result.x - - -def analyze_embedding_smoothness( - prediction_path: Path, - verbose: bool = False, - output_path: Optional[str] = None, - loss_name: Optional[str] = None, - overwrite: bool = False, -) -> dict: - """ - Analyze the smoothness and dynamic range of embeddings. - - Args: - prediction_path: Path to the embedding dataset - verbose: If True, generates additional plots - output_path: Path to save the final plot (optional) - loss_name: Name of the loss function used (optional) - overwrite: If True, overwrites existing files. 
If False, raises error if file exists (default: False) - - Returns: - dict: Dictionary containing metrics including: - - dissimilarity_mean: Mean of adjacent frame dissimilarity - - dissimilarity_std: Standard deviation of adjacent frame dissimilarity - - dissimilarity_median: Median of adjacent frame dissimilarity - - dissimilarity_peak: Peak of adjacent frame distribution - - dissimilarity_p99: 99th percentile of adjacent frame dissimilarity - - dissimilarity_p1: 1st percentile of adjacent frame dissimilarity - - dissimilarity_distribution: Full distribution of adjacent frame dissimilarities - - random_mean: Mean of random sampling dissimilarity - - random_std: Standard deviation of random sampling dissimilarity - - random_median: Median of random sampling dissimilarity - - random_peak: Peak of random sampling distribution - - random_distribution: Full distribution of random sampling dissimilarities - - dynamic_range: Difference between random and adjacent peaks - """ - # Read the dataset - embeddings = read_embedding_dataset(prediction_path) - features = embeddings["features"] - - scaled_features = StandardScaler().fit_transform(features.values) - # Compute the cosine dissimilarity - cross_dist = pairwise_distance_matrix(scaled_features, metric="cosine") - rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True) - - # Compute piece-wise dissimilarity and rank difference - features_df = features["sample"].to_dataframe().reset_index(drop=True) - piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track = ( - compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions) - ) - - all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track) - - p99_piece_wise_dissimilarity = np.array( - [np.percentile(track, 99) for track in piece_wise_dissimilarity_per_track] - ) - p1_percentile_piece_wise_dissimilarity = np.array( - [np.percentile(track, 1) for track in piece_wise_dissimilarity_per_track] - ) - - # Random sampling values in the dissimilarity matrix with same size as adjacent frame measurements - n_samples = len(all_dissimilarity) - random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2)) - sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]] - - # Compute the peaks of both distributions using KDE - adjacent_peak = float(find_distribution_peak(all_dissimilarity)) - random_peak = float(find_distribution_peak(sampled_values)) - dynamic_range = float(random_peak - adjacent_peak) - - metrics = { - "dissimilarity_mean": float(np.mean(all_dissimilarity)), - "dissimilarity_std": float(np.std(all_dissimilarity)), - "dissimilarity_median": float(np.median(all_dissimilarity)), - "dissimilarity_peak": adjacent_peak, - "dissimilarity_p99": p99_piece_wise_dissimilarity, - "dissimilarity_p1": p1_percentile_piece_wise_dissimilarity, - "dissimilarity_distribution": all_dissimilarity, - "random_mean": float(np.mean(sampled_values)), - "random_std": float(np.std(sampled_values)), - "random_median": float(np.median(sampled_values)), - "random_peak": random_peak, - "random_distribution": sampled_values, - "dynamic_range": dynamic_range, - } - - if verbose: - # Plot cross distance matrix - plt.figure() - plt.imshow(cross_dist) - plt.show() - - # Plot histograms - plot_histogram( - piece_wise_dissimilarity_per_track, - "Adjacent Frame Dissimilarity per Track", - "Cosine Dissimilarity", - "Frequency", - ) - - # Plot the comparison histogram and save if output_path is provided - fig = plt.figure() - sns.histplot( - 
metrics["dissimilarity_distribution"], - bins=30, - kde=True, - color="cyan", - alpha=0.5, - stat="density", - ) - sns.histplot( - metrics["random_distribution"], - bins=30, - kde=True, - color="red", - alpha=0.5, - stat="density", - ) - plt.xlabel("Cosine Dissimilarity") - plt.ylabel("Density") - # Add vertical lines for the peaks - plt.axvline( - x=metrics["dissimilarity_peak"], color="cyan", linestyle="--", alpha=0.8 - ) - plt.axvline(x=metrics["random_peak"], color="red", linestyle="--", alpha=0.8) - plt.tight_layout() - plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) - - if output_path and loss_name: - output_file = Path( - f"{output_path}/cosine_dissimilarity_smoothness_{prediction_path.stem}_{loss_name}.pdf" - ) - if output_file.exists() and not overwrite: - raise FileExistsError( - f"File {output_file} already exists and overwrite=False" - ) - fig.savefig( - output_file, - dpi=600, - ) - plt.show() - - return metrics - - -# Example usage: if __name__ == "__main__": - # plotting - VERBOSE = True - - PATH_TO_GDRIVE_FIGUE = "./" + # Define models as a dictionary with meaningful keys + prediction_paths = { + "ntxent_sensor_phase": Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" + ), + "triplet_sensor_phase": Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" + ), + } - prediction_path_1 = Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev6_NTXent_sensorPhase_infection/2chan_160patch_98ckpt_rev6_2.zarr" - ) - prediction_path_2 = Path( - "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/rev5_sensorPhase_infection/2chan_160patch_97ckpt_rev5_2.zarr" + # output_folder to save the distributions as .csv + output_folder = Path( + "/hpc/projects/comp.micro/infected_cell_imaging/Single_cell_phenotyping/ContrastiveLearning/trainng_logs/SEC61/cosine_dissimilarity_distributions" ) - - # Create a list of models to evaluate - models = [ - (prediction_path_1, "ntxent"), - (prediction_path_2, "triplet"), - ] + output_folder.mkdir(parents=True, exist_ok=True) # Evaluate each model - for prediction_path, loss_name in tqdm(models, desc="Evaluating models"): - print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {loss_name})") - print("-" * 80) + for model_name, prediction_path in tqdm( + prediction_paths.items(), desc="Evaluating models" + ): + print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {model_name})") + + # Compute and save distributions + distributions_df = compute_embedding_distances( + prediction_path=prediction_path, + output_folder=output_folder, + distance_metric="cosine", + verbose=True, + ) - metrics = analyze_embedding_smoothness( - prediction_path, - verbose=VERBOSE, - output_path=PATH_TO_GDRIVE_FIGUE, - loss_name=loss_name, + # Analyze distributions and create plots + metrics = analyze_and_plot_distances( + distributions_df, + output_file_path=output_folder / f"{model_name}_distance_plot.pdf", overwrite=True, ) - # Print adjacent frame dissimilarity statistics - print("\nAdjacent Frame Dissimilarity Statistics:") + # Print statistics + print("\nAdjacent Frame Distance Statistics:") print(f"{'Mean:':<15} {metrics['dissimilarity_mean']:.3f}") 
print(f"{'Std:':<15} {metrics['dissimilarity_std']:.3f}") print(f"{'Median:':<15} {metrics['dissimilarity_median']:.3f}") print(f"{'Peak:':<15} {metrics['dissimilarity_peak']:.3f}") - print(f"{'P1:':<15} {np.mean(metrics['dissimilarity_p1']):.3f}") - print(f"{'P99:':<15} {np.mean(metrics['dissimilarity_p99']):.3f}") + print(f"{'P1:':<15} {metrics['dissimilarity_p1']:.3f}") + print(f"{'P99:':<15} {metrics['dissimilarity_p99']:.3f}") # Print random sampling statistics print("\nRandom Sampling Statistics:") @@ -280,8 +72,8 @@ def analyze_embedding_smoothness( # Print distribution sizes print("\nDistribution Sizes:") print( - f"{'Adjacent Frame:':<15} {len(metrics['dissimilarity_distribution']):,d} samples" + f"{'Adjacent Frame:':<15} {len(distributions_df['adjacent_frame']):,d} samples" ) - print(f"{'Random:':<15} {len(metrics['random_distribution']):,d} samples") + print(f"{'Random:':<15} {len(distributions_df['random_sampling']):,d} samples") # %% diff --git a/applications/contrastive_phenotyping/evaluation/cosine_similarity.py b/applications/contrastive_phenotyping/evaluation/cosine_similarity_demo.py similarity index 100% rename from applications/contrastive_phenotyping/evaluation/cosine_similarity.py rename to applications/contrastive_phenotyping/evaluation/cosine_similarity_demo.py diff --git a/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py b/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py new file mode 100644 index 000000000..3ed67dcc2 --- /dev/null +++ b/applications/contrastive_phenotyping/evaluation/euclidean_distance_dataset.py @@ -0,0 +1,80 @@ +# %% +from pathlib import Path +import sys +sys.path.append("/hpc/mydata/soorya.pradeep/scratch/viscy_infection_phenotyping/VisCy") +import matplotlib.pyplot as plt +from tqdm import tqdm + +from viscy.representation.evaluation.distance import ( + compute_embedding_distances, + analyze_and_plot_distances, +) + +# plt.style.use("../evaluation/figure.mplstyle") + +if __name__ == "__main__": + # Define models as a dictionary with meaningful keys + prediction_paths = { + "ntxent_classical": Path( + "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_classical.zarr" + ), + "triplet_classical": Path( + "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/log_alfi_triplet_time_intervals/prediction/ALFI_classical.zarr" + ), + } + + # output_folder to save the distributions as .csv + output_folder = Path( + "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/metrics" + ) + output_folder.mkdir(parents=True, exist_ok=True) + + # Evaluate each model + for model_name, prediction_path in tqdm( + prediction_paths.items(), desc="Evaluating models" + ): + print(f"\nAnalyzing model: {prediction_path.stem} (Loss: {model_name})") + + # Compute and save distributions + distributions_df = compute_embedding_distances( + prediction_path=prediction_path, + output_path=output_folder / f"{model_name}_distance_.csv", + distance_metric="euclidean", + verbose=True, + ) + + # Analyze distributions and create plots + metrics = analyze_and_plot_distances( + distributions_df, + output_file_path=output_folder / f"{model_name}_distance_plot.pdf", + overwrite=True, + ) + + # Print statistics + print("\nAdjacent Frame Distance Statistics:") + print(f"{'Mean:':<15} {metrics['dissimilarity_mean']:.3f}") + print(f"{'Std:':<15} {metrics['dissimilarity_std']:.3f}") + print(f"{'Median:':<15} {metrics['dissimilarity_median']:.3f}") + 
print(f"{'Peak:':<15} {metrics['dissimilarity_peak']:.3f}") + print(f"{'P1:':<15} {metrics['dissimilarity_p1']:.3f}") + print(f"{'P99:':<15} {metrics['dissimilarity_p99']:.3f}") + + # Print random sampling statistics + print("\nRandom Sampling Statistics:") + print(f"{'Mean:':<15} {metrics['random_mean']:.3f}") + print(f"{'Std:':<15} {metrics['random_std']:.3f}") + print(f"{'Median:':<15} {metrics['random_median']:.3f}") + print(f"{'Peak:':<15} {metrics['random_peak']:.3f}") + + # Print dynamic range + print("\nComparison Metrics:") + print(f"{'Dynamic Range:':<15} {metrics['dynamic_range']:.3f}") + + # Print distribution sizes + print("\nDistribution Sizes:") + print( + f"{'Adjacent Frame:':<15} {len(distributions_df['adjacent_frame']):,d} samples" + ) + print(f"{'Random:':<15} {len(distributions_df['random_sampling']):,d} samples") + +# %% diff --git a/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py b/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py new file mode 100644 index 000000000..030673a64 --- /dev/null +++ b/applications/contrastive_phenotyping/figures/ALFI_accuracy_metrics.py @@ -0,0 +1,138 @@ + +# %% compute accuracy of model from ALFI data using cell division state classification + +from pathlib import Path +import matplotlib.pyplot as plt +import pandas as pd +import numpy as np +from sklearn.linear_model import LogisticRegression + +from viscy.representation.embedding_writer import read_embedding_dataset + +# %% +accuracies = [] + +features_paths = { + '7 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_7mins.zarr', + '14 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_14mins.zarr', + '28 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr', + '56 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_56mins.zarr', + '91 min interval': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_91mins.zarr', + 'classical': '/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_classical.zarr', +} + +for interval_name, path in features_paths.items(): + features_path = Path(path) + embedding_dataset = read_embedding_dataset(features_path) + embedding_dataset + features = embedding_dataset["features"] + + # load the cell cycle state annotation + + def load_annotation(da, path, name, categories: dict | None = None): + annotation = pd.read_csv(path) + # annotation_columns = annotation.columns.tolist() + # print(annotation_columns) + annotation["fov_name"] = "/" + annotation["fov ID"] + annotation = annotation.set_index(["fov_name", "id"]) + mi = pd.MultiIndex.from_arrays( + [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] + ) + selected = annotation.reindex(mi)[name] + if categories: + selected = selected.astype("category").cat.rename_categories(categories) + return selected + + ann_root = Path( + "/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets" + ) + + division = load_annotation( + embedding_dataset, + ann_root / "test_annotations.csv", + "division", + {0: "interphase", 1: "mitosis"}, + ) + + # train a linear classifier on half the data + + division_npy = division.cat.codes.values + division_npy_filtered = 
division_npy[division_npy != -1] + + feature_npy = features.values + feature_npy_filtered = feature_npy[division_npy != -1] + + # add time and well info into dataframe + time_npy = features["t"].values + time_npy_filtered = time_npy[division_npy != -1] + + + fov_name_list = features["fov_name"].values + fov_name_list_filtered = fov_name_list[division_npy != -1] + + data = pd.DataFrame( + { + "division": division_npy_filtered, + "time": time_npy_filtered, + "fov_name": fov_name_list_filtered, + } + ) + # Add all 768 features to the dataframe + feature_columns = pd.DataFrame(feature_npy_filtered, columns=[f"feature_{i+1}" for i in range(768)]) + data = pd.concat([data, feature_columns], axis=1) + + # dataframe for training set, fov names starts with "/B/4/6" or "/B/4/7" or "/A/3/" + data_train_val = data[ + data["fov_name"].str.contains("/0/0/0") + | data["fov_name"].str.contains("/0/1/0") + | data["fov_name"].str.contains("/0/2/0") + ] + + data_test = data[ + data["fov_name"].str.contains("/0/3/0") + | data["fov_name"].str.contains("/0/4/0") + ] + + x_train = data_train_val.drop( + columns=[ + "division", + "fov_name", + "time", + ] + ) + y_train = data_train_val["division"] + + # train a logistic regression model + clf = LogisticRegression(random_state=0).fit(x_train, y_train) + + # test the trained classifer on the other half of the data + + x_test = data_test.drop( + columns=[ + "division", + "fov_name", + "time", + ] + ) + y_test = data_test["division"] + + # predict the infection state for the testing set + y_pred = clf.predict(x_test) + + # compute the accuracy of the classifier + + accuracy = np.mean(y_pred == y_test) + # save the accuracy for final ploting + print(f"Accuracy of model trained on {interval_name} data: {accuracy}") + accuracies.append(accuracy) + +# %% plot the accuracy of the model trained on different time intervals + +plt.figure(figsize=(8, 6)) +plt.bar(features_paths.keys(), accuracies) +plt.xticks(rotation=45, ha='right', fontsize=12) +plt.ylabel("Accuracy", fontsize=14) +plt.xlabel("Time interval", fontsize=14) +plt.ylim(0.9, 1) +plt.show() +# %% diff --git a/applications/contrastive_phenotyping/figures/ALFI_cell_division.py b/applications/contrastive_phenotyping/figures/ALFI_cell_division.py new file mode 100644 index 000000000..eaf941d93 --- /dev/null +++ b/applications/contrastive_phenotyping/figures/ALFI_cell_division.py @@ -0,0 +1,252 @@ + +# %% Figure on ALFI cell division model showing +# (a) Euclidean distance over a cell division event and +# (b) difference between trajectory of cell in time-aware and classical method over division event + +from pathlib import Path +from collections import defaultdict +import seaborn as sns +import matplotlib.pyplot as plt +from matplotlib.patches import FancyArrowPatch +import numpy as np +from sklearn.metrics.pairwise import cosine_similarity +import pandas as pd + +from viscy.representation.embedding_writer import read_embedding_dataset + +# %% Task A: plot the Eucledian distance for a dividing cell + +# Paths to datasets +feature_paths = { + "7 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_7mins.zarr", + "14 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_14mins.zarr", + "28 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr", + "56 min interval": 
"/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_56mins.zarr", + "91 min interval": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_91mins.zarr", + "Classical": "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_classical.zarr", +} + +track_well = '/0/2/0' +parent_id = 3 # 11 +daughter1_track = 4 # 12 +daughter2_track = 5 # 13 + +# %% plot the eucledian distance over time lag for a parent cell for different time intervals + +def compute_displacement_track(fov_name, track_id, current_time, distance_metric="euclidean_squared", max_delta_t=10): + + fov_names = embedding_dataset["fov_name"].values + track_ids = embedding_dataset["track_id"].values + timepoints = embedding_dataset["t"].values + embeddings = embedding_dataset["features"].values + + # find index where fov_name, track_id and current_time match + i = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == current_time) + )[0][0] + current_embedding = embeddings[i].reshape(1, -1) + + # Check if max_delta_t is provided, otherwise use the maximum timepoint + if max_delta_t is None: + max_delta_t = timepoints.max() + + displacement_per_delta_t = defaultdict(list) + + # Compute displacements for each delta t + for delta_t in range(1, max_delta_t + 1): + future_time = current_time + delta_t + matching_indices = np.where( + (fov_names == fov_name) + & (track_ids == track_id) + & (timepoints == future_time) + )[0] + + if len(matching_indices) == 1: + if distance_metric == "euclidean_squared": + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) + displacement = np.sum((current_embedding - future_embedding) ** 2) + elif distance_metric == "cosine": + future_embedding = embeddings[matching_indices[0]].reshape(1, -1) + displacement = cosine_similarity( + current_embedding, future_embedding + ) + displacement_per_delta_t[delta_t].append(displacement) + + return displacement_per_delta_t + +# %% plot the eucledian distance for a parent cell + +plt.figure(figsize=(10, 6)) +for label, path in feature_paths.items(): + embedding_dataset = read_embedding_dataset(path) + displacement_per_delta_t = compute_displacement_track(track_well, parent_id, 1) + delta_ts = sorted(displacement_per_delta_t.keys()) + displacements = [np.mean(displacement_per_delta_t[delta_t]) for delta_t in delta_ts] + plt.plot(delta_ts, displacements, label=label) + +plt.xlabel("Time Interval (delta t)") +plt.ylabel("Displacement (Euclidean Distance)") +plt.title("Displacement vs Time Interval for Parent Cell") +plt.legend() +plt.show() + +# %% Task B: plot the phate map and overlay the dividing cell trajectory + +# for time-aware model uncomment the next three lines +# embedding_dataset = read_embedding_dataset( +# "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_28mins.zarr" +# ) + +# for classical model uncomment the next three line +embedding_dataset = read_embedding_dataset( + "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/logs_alfi_ntxent_time_intervals/predictions/ALFI_classical.zarr" +) + +PHATE1 = embedding_dataset["PHATE1"].values +PHATE2 = embedding_dataset["PHATE2"].values + +# %% plot PHATE map based on the embedding dataset time points + +sns.scatterplot( + x=embedding_dataset["PHATE1"], y=embedding_dataset["PHATE2"], hue=embedding_dataset["t"], s=7, alpha=0.8 +) + +# %% color using human annotation for cell 
cycle state + +def load_annotation(da, path, name, categories: dict | None = None): + annotation = pd.read_csv(path) + # annotation_columns = annotation.columns.tolist() + # print(annotation_columns) + annotation["fov_name"] = "/" + annotation["fov ID"] + annotation = annotation.set_index(["fov_name", "id"]) + mi = pd.MultiIndex.from_arrays( + [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] + ) + selected = annotation.reindex(mi)[name] + if categories: + selected = selected.astype("category").cat.rename_categories(categories) + return selected + + +# %% load the cell cycle state annotation + +ann_root = Path( + "/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets" +) + +division = load_annotation( + embedding_dataset, + ann_root / "test_annotations.csv", + "division", + {0: "interphase", 1: "mitosis"}, +) + +# %% find a parent that divides to two daughter cells for ploting trajectory + +cell_parent = embedding_dataset.where(embedding_dataset["fov_name"] == track_well, drop=True).where( + embedding_dataset["track_id"] == parent_id, drop=True +) +cell_parent = cell_parent["PHATE1"].values, cell_parent["PHATE2"].values +cell_parent = pd.DataFrame(np.column_stack(cell_parent), columns=["PHATE1", "PHATE2"]) + +cell_daughter1 = embedding_dataset.where(embedding_dataset["fov_name"] == track_well, drop=True).where( + embedding_dataset["track_id"] == daughter1_track, drop=True +) +cell_daughter1 = cell_daughter1["PHATE1"].values, cell_daughter1["PHATE2"].values +cell_daughter1 = pd.DataFrame(np.column_stack(cell_daughter1), columns=["PHATE1", "PHATE2"]) + +cell_daughter2 = embedding_dataset.where(embedding_dataset["fov_name"] == track_well, drop=True).where( + embedding_dataset["track_id"] == daughter2_track, drop=True +) +cell_daughter2 = cell_daughter2["PHATE1"].values, cell_daughter2["PHATE2"].values +cell_daughter2 = pd.DataFrame(np.column_stack(cell_daughter2), columns=["PHATE1", "PHATE2"]) + +# %% Plot: display one arrow at end of trajectory of cell overlayed on PHATE + +sns.scatterplot( + x=embedding_dataset["PHATE1"], + y=embedding_dataset["PHATE2"], + hue=division, + palette={"interphase": "steelblue", "mitosis": "orangered", -1: "green"}, + s=7, + alpha=0.5, +) + +# sns.lineplot(x=cell_parent["PHATE1"], y=cell_parent["PHATE2"], color="black", linewidth=2) +# sns.lineplot( +# x=cell_daughter1["PHATE1"], y=cell_daughter1["PHATE2"], color="blue", linewidth=2 +# ) +# sns.lineplot( +# x=cell_daughter2["PHATE1"], y=cell_daughter2["PHATE2"], color="red", linewidth=2 +# ) + +parent_arrow = FancyArrowPatch( + (cell_parent["PHATE1"].values[28], cell_parent["PHATE2"].values[28]), + (cell_parent["PHATE1"].values[35], cell_parent["PHATE2"].values[35]), + color="black", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(parent_arrow) +parent_arrow = FancyArrowPatch( + (cell_parent["PHATE1"].values[35], cell_parent["PHATE2"].values[35]), + (cell_parent["PHATE1"].values[38], cell_parent["PHATE2"].values[38]), + color="black", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(parent_arrow) +daughter1_arrow = FancyArrowPatch( + (cell_daughter1["PHATE1"].values[0], cell_daughter1["PHATE2"].values[0]), + (cell_daughter1["PHATE1"].values[1], cell_daughter1["PHATE2"].values[1]), + color="blue", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + 
shrinkB=0, +) +plt.gca().add_patch(daughter1_arrow) +daughter1_arrow = FancyArrowPatch( + (cell_daughter1["PHATE1"].values[1], cell_daughter1["PHATE2"].values[1]), + (cell_daughter1["PHATE1"].values[10], cell_daughter1["PHATE2"].values[10]), + color="blue", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(daughter1_arrow) +daughter2_arrow = FancyArrowPatch( + (cell_daughter2["PHATE1"].values[0], cell_daughter2["PHATE2"].values[0]), + (cell_daughter2["PHATE1"].values[1], cell_daughter2["PHATE2"].values[1]), + color="red", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(daughter2_arrow) +daughter2_arrow = FancyArrowPatch( + (cell_daughter2["PHATE1"].values[1], cell_daughter2["PHATE2"].values[1]), + (cell_daughter2["PHATE1"].values[10], cell_daughter2["PHATE2"].values[10]), + color="red", + arrowstyle="->", + mutation_scale=20, # reduce the size of arrowhead by half + lw=2, + shrinkA=0, + shrinkB=0, +) +plt.gca().add_patch(daughter2_arrow) + +# %% diff --git a/applications/contrastive_phenotyping/figures/interactive_plot_wDisplay.py b/applications/contrastive_phenotyping/figures/interactive_plot_wDisplay.py new file mode 100644 index 000000000..43441918c --- /dev/null +++ b/applications/contrastive_phenotyping/figures/interactive_plot_wDisplay.py @@ -0,0 +1,138 @@ + +# This is a simple example of an interactive plot using Dash. +from pathlib import Path +import dash +from dash import dcc, html +import plotly.express as px +import pandas as pd +import numpy as np +import base64 +from io import BytesIO +from PIL import Image +from sklearn.decomposition import PCA +from sklearn.preprocessing import StandardScaler +import dash.dependencies as dd + +from viscy.representation.embedding_writer import read_embedding_dataset +from viscy.representation.evaluation import dataset_of_tracks + +# Initialize Dash app +app = dash.Dash(__name__) + +# Sample DataFrame for demonstration +features_path = Path( + "/hpc/projects/intracellular_dashboard/viral-sensor/infection_classification/models/time_sampling_strategies/time_interval/predict/jun_time_interval_1_epoch_178.zarr" +) +embedding_dataset = read_embedding_dataset(features_path) +features = embedding_dataset["features"] +scaled_features = StandardScaler().fit_transform(features.values) +pca = PCA(n_components=3) +embedding = pca.fit_transform(scaled_features) +features = ( + features.assign_coords(PCA1=("sample", embedding[:, 0])) + .assign_coords(PCA2=("sample", embedding[:, 1])) + .assign_coords(PCA3=("sample", embedding[:, 2])) + .set_index(sample=["PCA1", "PCA2", "PCA3"], append=True) +) + +df = pd.DataFrame({k: v for k, v in features.coords.items() if k != "features"}) + +# Image paths for each track and time + +data_path = Path( + "/hpc/projects/organelle_phenotyping/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/registered_chunked.zarr" +) +tracks_path = Path( + "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_06_13_SEC61_TOMM20_ZIKV_DENGUE_1/4.2-tracking/track.zarr" +) + +# Create scatter plot with hover data (track_id, t, fov_name) +fig = px.scatter( + df, + x="PCA1", + y="PCA2", + color="PCA1", + hover_name="fov_name", + hover_data=["id", "t", "track_id"], # Include track_id and t for image lookup +) + +# Layout of the app +app.layout = html.Div([ + dcc.Graph( + id="scatter-plot", + figure=fig, + ), + html.Div([ + html.Img(id="hover-image", src="", style={"width": 
"150px", "height": "150px"}) + ]) +]) + +# Helper function to convert numpy array to base64 image +def numpy_to_base64(img_array): + # Clip, normalize, and scale to the range [0, 255] + img_array = np.clip(img_array, img_array.min(), img_array.max()) # Clip values to the expected range + img_array = (img_array - img_array.min()) / (img_array.max() - img_array.min()) # Normalize to [0, 1] + img_array = (img_array * 255).astype(np.uint8) # Scale to [0, 255] and convert to uint8 + + img = Image.fromarray(img_array) + buffered = BytesIO() + img.save(buffered, format="PNG") + return "data:image/png;base64," + base64.b64encode(buffered.getvalue()).decode("utf-8") + + +# Callback to update the image when a point is hovered over +@app.callback( + dd.Output("hover-image", "src"), + [dd.Input("scatter-plot", "hoverData")] +) +def update_image(hoverData): + if hoverData is None: + return "" # Return empty if no hover + + # Extract the necessary information from hoverData + fov_name = hoverData['points'][0]['hovertext'] # fov_name is in hovertext + track_id = hoverData['points'][0]['customdata'][2] # track_id from hover_data + t = hoverData['points'][0]['customdata'][1] # t from hover_data + + print(f"Hovering over: fov_name={fov_name}, track_id={track_id}, t={t}") + + # Lookup the image path based on fov_name, track_id, and t + # image_key = (fov_name, track_id, t) + + # Get the image URL if it exists + # return image_paths.get(image_key, "") # Return empty string if no match + source_channel = ["Phase3D"] + z_range = (33,34) + predict_dataset = dataset_of_tracks( + data_path, + tracks_path, + [fov_name], + [track_id], + z_range=z_range, + source_channel=source_channel, + ) + # image_patch = np.stack([p["anchor"][0, 7].numpy() for p in predict_dataset]) + + # Check if the dataset was retrieved successfully + if not predict_dataset: + print(f"No dataset found for fov_name={fov_name}, track_id={track_id}, t={t}") + return "" # Return empty if no dataset is found + + # Extract the image patch (assuming it's a numpy array) + try: + image_patch = np.stack([p["anchor"][0].numpy() for p in predict_dataset]) + image_patch = image_patch[0,0] + print(f"Image patch shape: {image_patch.shape}") + except Exception as e: + print(f"Error extracting image patch: {e}") + return "" + + # Check if the image is valid (this step is just a safety check) + if image_patch.ndim != 2: + print(f"Invalid image data: image_patch is not 2D.") + return "" + + return numpy_to_base64(image_patch) + +if __name__ == '__main__': + app.run_server(debug=True) \ No newline at end of file diff --git a/applications/contrastive_phenotyping/figures/plot_phatemap_ALFI.py b/applications/contrastive_phenotyping/figures/plot_phatemap_ALFI.py new file mode 100644 index 000000000..9d2d78f58 --- /dev/null +++ b/applications/contrastive_phenotyping/figures/plot_phatemap_ALFI.py @@ -0,0 +1,93 @@ + +# %% + +from pathlib import Path +import matplotlib.pyplot as plt +import seaborn as sns +import pandas as pd + +from viscy.representation.embedding_writer import read_embedding_dataset + +# %% + +features_path = Path( + "/hpc/projects/organelle_phenotyping/ALFI_ntxent_loss/log_alfi_triplet_time_intervals/prediction/ALFI_91mins.zarr" +) +# data_path = Path( +# "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/2-assemble/2024_11_07_A549_SEC61_ZIKV_DENV.zarr" +# ) +# tracks_path = Path( +# 
"/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/1-preprocess/label-free/3-track/2024_11_07_A549_SEC61_ZIKV_DENV_2_cropped.zarr" +# ) + +# %% + +embedding_dataset = read_embedding_dataset(features_path) +embedding_dataset + +PHATE1 = embedding_dataset["PHATE1"].values +PHATE2 = embedding_dataset["PHATE2"].values + +# %% plot PHATE map based on the embedding dataset time points + +sns.scatterplot( + x=embedding_dataset["PHATE1"], y=embedding_dataset["PHATE2"], hue=embedding_dataset["t"], s=7, alpha=0.8 +) + +# %% color using human annotation for cell cycle state + +def load_annotation(da, path, name, categories: dict | None = None): + annotation = pd.read_csv(path) + # annotation_columns = annotation.columns.tolist() + # print(annotation_columns) + annotation["fov_name"] = "/" + annotation["fov ID"] + annotation = annotation.set_index(["fov_name", "id"]) + mi = pd.MultiIndex.from_arrays( + [da["fov_name"].values, da["id"].values], names=["fov_name", "id"] + ) + selected = annotation.reindex(mi)[name] + if categories: + selected = selected.astype("category").cat.rename_categories(categories) + return selected + + +# %% load the cell cycle state annotation + +ann_root = Path( + "/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets" +) + +division = load_annotation( + embedding_dataset, + ann_root / "test_annotations.csv", + "division", + {0: "interphase", 1: "mitosis"}, +) + +# %% plot PHATE map based on the cell cycle annotation + +sns.scatterplot( + x=embedding_dataset["PHATE1"], y=embedding_dataset["PHATE2"], hue=division, s=7, alpha=0.8 +) + +# %% plot intercative plot to hover over the points on scatter plot and see the fov_name and track_id + +import plotly.express as px + +fig = px.scatter( + embedding_dataset.to_dataframe(), + x="PHATE1", + y="PHATE2", + color=division, + hover_data=["fov_name", "id"], +) + +# %% +# find row index in 'division' where the value is -1 +division[division == -1].index +# find the track_id and 't' value of cell instance where 'fov_name' is '/0/0/0' and 'id' is 1000941 +embedding_dataset.where(embedding_dataset["fov_name"] == "/0/0/0", drop=True).where( + embedding_dataset["id"] == 1000942, drop=True +) + +# %% diff --git a/applications/pseudotime_analysis/pca_analysis.py b/applications/pseudotime_analysis/pca_analysis.py new file mode 100644 index 000000000..3aed085ce --- /dev/null +++ b/applications/pseudotime_analysis/pca_analysis.py @@ -0,0 +1,564 @@ +# %% +import numpy as np +import pandas as pd +from sklearn.preprocessing import StandardScaler +from sklearn.decomposition import PCA +import matplotlib.pyplot as plt +import seaborn as sns +from viscy.representation.embedding_writer import read_embedding_dataset +from scipy.spatial.distance import pdist, squareform + +# Set global random seed for reproducibility +RANDOM_SEED = 42 +np.random.seed(RANDOM_SEED) + + +def analyze_pc_loadings(pca, feature_names=None, top_n=5): + """Analyze which features contribute most to each PC.""" + if feature_names is None: + feature_names = [f"Feature_{i}" for i in range(pca.components_[0].shape[0])] + + pc_loadings = [] + for i, pc in enumerate(pca.components_): + # Get the absolute loadings + abs_loadings = np.abs(pc) + # Get indices of top contributing features + top_indices = np.argsort(abs_loadings)[-top_n:][::-1] + + # Store the results + pc_dict = { + "PC": i + 1, + "Variance_Explained": pca.explained_variance_ratio_[i], + "Top_Features": [feature_names[idx] for idx in top_indices], + "Top_Loadings": 
[pc[idx] for idx in top_indices], + } + pc_loadings.append(pc_dict) + + return pd.DataFrame(pc_loadings) + + +def analyze_track_clustering( + pca_result, + track_ids, + time_points, + labels, + phenotype_of_interest, + seed_timepoint, + time_window, +): + """Analyze how tracks cluster in PC space within the time window.""" + # Get points within time window + time_mask = (time_points >= seed_timepoint - time_window) & ( + time_points <= seed_timepoint + time_window + ) + window_points = pca_result[time_mask] + window_tracks = track_ids[time_mask] + window_labels = labels[time_mask] + + # Calculate mean position for each track + track_means = {} + phenotype_tracks = [] + + for track_id in np.unique(window_tracks): + track_mask = (window_tracks == track_id) & ( + window_labels == phenotype_of_interest + ) + if np.any(track_mask): + track_means[track_id] = np.mean(window_points[track_mask], axis=0) + phenotype_tracks.append(track_id) + + if len(phenotype_tracks) < 2: + return None + + # Calculate pairwise distances between track means + track_positions = np.array([track_means[tid] for tid in phenotype_tracks]) + distances = pdist(track_positions) + mean_distance = np.mean(distances) + std_distance = np.std(distances) + + # Calculate spread within each track + track_spreads = {} + for track_id in phenotype_tracks: + track_mask = (window_tracks == track_id) & ( + window_labels == phenotype_of_interest + ) + if np.sum(track_mask) > 1: + track_points = window_points[track_mask] + spread = np.mean(pdist(track_points)) + track_spreads[track_id] = spread + + mean_spread = np.mean(list(track_spreads.values())) if track_spreads else 0 + + return { + "n_tracks": len(phenotype_tracks), + "mean_inter_track_distance": mean_distance, + "std_inter_track_distance": std_distance, + "mean_intra_track_spread": mean_spread, + "clustering_ratio": mean_distance / mean_spread if mean_spread > 0 else np.inf, + } + + +def analyze_pc_distributions( + pca_result, + labels, + phenotype_of_interest, + time_points=None, + seed_timepoint=None, + time_window=None, +): + """Analyze the distributions of each PC for phenotype vs background.""" + n_components = pca_result.shape[1] + results = [] + + for i in range(n_components): + # Get phenotype and background points + if ( + time_points is not None + and seed_timepoint is not None + and time_window is not None + ): + time_mask = (time_points >= seed_timepoint - time_window) & ( + time_points <= seed_timepoint + time_window + ) + pc_values_phenotype = pca_result[ + time_mask & (labels == phenotype_of_interest), i + ] + pc_values_background = pca_result[ + time_mask & (labels != phenotype_of_interest), i + ] + else: + pc_values_phenotype = pca_result[labels == phenotype_of_interest, i] + pc_values_background = pca_result[labels != phenotype_of_interest, i] + + # Calculate basic statistics + stats = { + "PC": i + 1, + "phenotype_mean": np.mean(pc_values_phenotype), + "background_mean": np.mean(pc_values_background), + "phenotype_std": np.std(pc_values_phenotype), + "background_std": np.std(pc_values_background), + "separation": abs( + np.mean(pc_values_phenotype) - np.mean(pc_values_background) + ) + / (np.std(pc_values_phenotype) + np.std(pc_values_background)), + } + + # Check for multimodality using a simple peak detection + hist, bins = np.histogram(pc_values_phenotype, bins="auto") + peaks = len( + [ + i + for i in range(1, len(hist) - 1) + if hist[i] > hist[i - 1] and hist[i] > hist[i + 1] + ] + ) + stats["n_peaks"] = peaks + + results.append(stats) + + return 
pd.DataFrame(results) + + +def analyze_embeddings_with_pca( + embedding_path, + annotation_path=None, + phenotype_of_interest=None, + n_random_tracks=10, + n_components=8, + seed_timepoint=None, + time_window=10, + fov_patterns=None, +): + """Analyze embeddings using PCA, either for specific phenotypes or random tracks. + + Args: + embedding_path: Path to embedding zarr file + annotation_path: Optional path to annotation CSV file. If None, uses random tracks + phenotype_of_interest: Which phenotype to analyze (only used if annotation_path is provided) + n_random_tracks: Number of random tracks to select (only used if annotation_path is None) + n_components: Number of PCA components + seed_timepoint: Center of time window. If None, uses all timepoints + time_window: Size of time window (+/-). Only used if seed_timepoint is not None + fov_patterns: List of patterns to filter FOVs (e.g. ['/C/2/*', '/B/3/*']). + Optional even when using annotation_path - can be used to restrict + analysis to specific FOVs while still using phenotype information. + """ + if annotation_path is None: + print(f"\nUsing random tracks (global seed: {RANDOM_SEED})") + + if seed_timepoint is None: + print("\nUsing all timepoints") + else: + print(f"\nUsing time window: {seed_timepoint}±{time_window}") + + # Load embeddings + embedding_dataset = read_embedding_dataset(embedding_path) + features = embedding_dataset["features"] + track_ids = embedding_dataset["track_id"].values + fovs = embedding_dataset["fov_name"].values + time_points = embedding_dataset["t"].values + + # Filter FOVs if patterns are provided + if fov_patterns is not None: + print(f"\nFiltering FOVs with patterns: {fov_patterns}") + fov_mask = np.zeros_like(fovs, dtype=bool) + for pattern in fov_patterns: + fov_mask |= np.char.find(fovs.astype(str), pattern) >= 0 + + # Update all arrays with the FOV mask + features = features[fov_mask] + track_ids = track_ids[fov_mask] + fovs = fovs[fov_mask] + time_points = time_points[fov_mask] + + print(f"Found {len(np.unique(fovs))} FOVs matching patterns") + + # Get tracks of interest + if annotation_path is not None: + # Load annotations and get phenotype tracks + annotations_df = pd.read_csv(annotation_path) + annotation_map = { + (str(row["FOV"]), int(row["Track_id"])): row["Observed phenotype"] + for _, row in annotations_df.iterrows() + } + labels = np.array( + [ + annotation_map.get((str(fov), int(track_id)), -1) + for fov, track_id in zip(fovs, track_ids) + ] + ) + selection_mask = labels == phenotype_of_interest + tracks_of_interest = np.unique(track_ids[selection_mask]) + other_mask = ~selection_mask + mode = f"phenotype {phenotype_of_interest}" + else: + # Select random tracks from different FOVs when possible + # Create a mapping of FOV to tracks + fov_track_map = {} + for fov, track_id in zip(fovs, track_ids): + if fov not in fov_track_map: + fov_track_map[fov] = [] + if track_id not in fov_track_map[fov]: # Avoid duplicates + fov_track_map[fov].append(track_id) + + # Get list of all FOVs + available_fovs = list(fov_track_map.keys()) + tracks_of_interest = [] + + # First, try to get one track from each FOV + np.random.shuffle(available_fovs) # Randomize FOV order + for fov in available_fovs: + if len(tracks_of_interest) < n_random_tracks: + # Randomly select a track from this FOV + track = np.random.choice(fov_track_map[fov]) + tracks_of_interest.append(track) + else: + break + + # If we still need more tracks, randomly select from remaining tracks + if len(tracks_of_interest) < n_random_tracks: + # Get 
all remaining tracks that aren't already selected + remaining_tracks = [ + track + for track in np.unique(track_ids) + if track not in tracks_of_interest + ] + # Select additional tracks + additional_tracks = np.random.choice( + remaining_tracks, + size=min( + n_random_tracks - len(tracks_of_interest), len(remaining_tracks) + ), + replace=False, + ) + tracks_of_interest.extend(additional_tracks) + + tracks_of_interest = np.array(tracks_of_interest) + selection_mask = np.isin(track_ids, tracks_of_interest) + other_mask = ~selection_mask + labels = np.where(selection_mask, 1, 0) + mode = "random tracks" + + # Print selected tracks with their FOVs + print("\nSelected tracks:") + for track in tracks_of_interest: + track_fovs = np.unique(fovs[track_ids == track]) + print(f"Track {track}: FOV {track_fovs[0]}") + + # Scale the features + scaler = StandardScaler() + scaled_features = scaler.fit_transform(features.values) + + # Perform PCA + pca = PCA(n_components=n_components) + pca_result = pca.fit_transform(scaled_features) + + # Calculate explained variance + explained_variance_ratio = pca.explained_variance_ratio_ + cumulative_variance_ratio = np.cumsum(explained_variance_ratio) + + # Create track-specific colors + track_colors = plt.cm.tab10(np.linspace(0, 1, len(tracks_of_interest))) + track_color_map = dict(zip(tracks_of_interest, track_colors)) + + # Create plots + fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5)) + + # Scree plot + ax1.plot(range(1, n_components + 1), explained_variance_ratio, "bo-") + ax1.plot(range(1, n_components + 1), cumulative_variance_ratio, "ro-") + ax1.set_xlabel("Principal Component") + ax1.set_ylabel("Explained Variance Ratio") + ax1.set_title("Scree Plot") + ax1.legend(["Individual", "Cumulative"]) + + # First two components plot + # Plot other tracks/cells in gray + ax2.scatter( + pca_result[other_mask, 0], + pca_result[other_mask, 1], + alpha=0.1, + color="gray", + label="Other cells", + s=10, + ) + + # Plot tracks of interest with decreasing opacity + for track_id in tracks_of_interest: + track_mask = track_ids == track_id + track_points = pca_result[track_mask] + track_times = time_points[track_mask] + + # Sort points by time + sort_idx = np.argsort(track_times) + track_points = track_points[sort_idx] + track_times = track_times[sort_idx] + + # Apply time window if specified + if seed_timepoint is not None: + time_mask = (track_times >= seed_timepoint - time_window) & ( + track_times <= seed_timepoint + time_window + ) + else: + time_mask = np.ones_like(track_times, dtype=bool) # Use all points + + if np.any(time_mask): # Only plot if there are points in the window + window_points = track_points[time_mask] + window_times = track_times[time_mask] + + # Normalize times within window for opacity + norm_times = (window_times - window_times.min()) / ( + window_times.max() - window_times.min() + 1e-10 + ) + alphas = 0.2 + 0.8 * norm_times # Scale to [0.2, 1.0] + + # Plot points with opacity based on normalized time + for idx in range(len(window_points)): + ax2.scatter( + window_points[idx, 0], + window_points[idx, 1], + color=track_color_map[track_id], + alpha=alphas[idx], + s=50, + label=( + f"Track {track_id}" if idx == len(window_points) - 1 else None + ), + ) + + ax2.set_xlabel("First Principal Component") + ax2.set_ylabel("Second Principal Component") + title = f"First Two Principal Components - {mode}" + if seed_timepoint is not None: + title += f"\nTime window: {seed_timepoint}±{time_window}" + ax2.set_title(title) + ax2.legend(bbox_to_anchor=(1.05, 
1), loc="upper left") + + plt.tight_layout() + plt.show() + + # Pairwise component plots + fig, axes = plt.subplots(n_components, n_components, figsize=(20, 20)) + + for i in range(n_components): + for j in range(n_components): + if i != j: + # Plot other points first + axes[i, j].scatter( + pca_result[other_mask, j], + pca_result[other_mask, i], + alpha=0.1, + color="gray", + s=5, + ) + + # Plot each track with decreasing opacity + for track_id in tracks_of_interest: + track_mask = track_ids == track_id + track_points_j = pca_result[track_mask, j] + track_points_i = pca_result[track_mask, i] + track_times = time_points[track_mask] + + # Sort points by time + sort_idx = np.argsort(track_times) + track_points_j = track_points_j[sort_idx] + track_points_i = track_points_i[sort_idx] + track_times = track_times[sort_idx] + + # Select points within the time window + time_mask = (track_times >= seed_timepoint - time_window) & ( + track_times <= seed_timepoint + time_window + ) + if np.any(time_mask): # Only plot if there are points in the window + window_points_j = track_points_j[time_mask] + window_points_i = track_points_i[time_mask] + window_times = track_times[time_mask] + + # Normalize times within window for opacity + norm_times = (window_times - window_times.min()) / ( + window_times.max() - window_times.min() + 1e-10 + ) + alphas = 0.2 + 0.8 * norm_times # Scale to [0.2, 1.0] + + # Plot points with opacity based on normalized time + for idx in range(len(window_points_j)): + axes[i, j].scatter( + window_points_j[idx], + window_points_i[idx], + color=track_color_map[track_id], + alpha=alphas[idx], + s=30, + ) + + axes[i, j].set_xlabel(f"PC{j+1}") + axes[i, j].set_ylabel(f"PC{i+1}") + else: + # On diagonal, show distribution + sns.histplot( + pca_result[other_mask, i], ax=axes[i, i], color="gray", alpha=0.3 + ) + for track_id in tracks_of_interest: + track_mask = track_ids == track_id + # For histograms, use all points in the time window + time_mask = ( + time_points[track_mask] >= seed_timepoint - time_window + ) & (time_points[track_mask] <= seed_timepoint + time_window) + if np.any(time_mask): + sns.histplot( + pca_result[track_mask][time_mask, i], + ax=axes[i, i], + color=track_color_map[track_id], + alpha=0.5, + ) + axes[i, i].set_xlabel(f"PC{i+1}") + + plt.tight_layout() + plt.show() + + # Print variance explained + print("\nExplained variance ratio by component:") + for i, var in enumerate(explained_variance_ratio): + print(f"PC{i+1}: {var:.3f} ({cumulative_variance_ratio[i]:.3f} cumulative)") + + # Add analysis of PC loadings + pc_analysis = analyze_pc_loadings(pca) + print("\nPC Loading Analysis:") + print(pc_analysis.to_string(index=False)) + + # Add analysis of track clustering + cluster_analysis = analyze_track_clustering( + pca_result, + track_ids, + time_points, + labels, + 1 if annotation_path is None else phenotype_of_interest, + seed_timepoint, + time_window, + ) + + if cluster_analysis: + print("\nTrack Clustering Analysis:") + print(f"Number of tracks in window: {cluster_analysis['n_tracks']}") + print( + f"Mean distance between tracks: {cluster_analysis['mean_inter_track_distance']:.3f}" + ) + print( + f"Mean spread within tracks: {cluster_analysis['mean_intra_track_spread']:.3f}" + ) + print( + f"Clustering ratio (inter/intra): {cluster_analysis['clustering_ratio']:.3f}" + ) + print("(Lower clustering ratio suggests tighter clustering)") + + # Add distribution analysis + dist_analysis = analyze_pc_distributions( + pca_result, + labels, + 1 if annotation_path is None 
else phenotype_of_interest, + time_points if seed_timepoint is not None else None, + seed_timepoint, + time_window, + ) + print("\nPC Distribution Analysis:") + print( + "(Separation score > 1 suggests good separation between selected tracks and background)" + ) + print(dist_analysis.to_string(index=False)) + + return ( + pca, + pca_result, + explained_variance_ratio, + labels, + tracks_of_interest, + pc_analysis, + cluster_analysis, + dist_analysis, + ) + + +# %% +if __name__ == "__main__": + embedding_path = "/hpc/projects/intracellular_dashboard/organelle_dynamics/2024_11_07_A549_SEC61_ZIKV_DENV/3-phenotyping/predictions/timeAware_2chan__ntxent_192patch_70ckpt_rev7_GT.zarr" + annotation_path = "/home/eduardo.hirata/repos/viscy/applications/pseudotime_analysis/phenotype_observations.csv" + # %% + # Using phenotype annotations with specific FOVs + print("\nAnalyzing phenotype 1 in specific FOVs:") + ( + pca, + pca_result, + variance_ratio, + labels, + tracks, + pc_analysis, + cluster_analysis, + dist_analysis, + ) = analyze_embeddings_with_pca( + embedding_path, + annotation_path=annotation_path, + phenotype_of_interest=1, + seed_timepoint=55, + time_window=10, + fov_patterns=["/C/2/", "/B/3/", "/B/2/"], # Specify FOV patterns + ) + + # Using random tracks from specific FOVs + print("\nAnalyzing random tracks from specific FOVs:") + ( + pca, + pca_result, + variance_ratio, + labels, + tracks, + pc_analysis, + cluster_analysis, + dist_analysis, + ) = analyze_embeddings_with_pca( + embedding_path, + annotation_path=None, # This triggers random track selection + n_random_tracks=10, + seed_timepoint=55, + time_window=30, + fov_patterns=["/C/2/", "/B/3/", "/B/2/"], # Specify FOV patterns + ) + +# %% diff --git a/applications/timearrow_phenotyping/test_tarrow.py b/applications/timearrow_phenotyping/test_tarrow.py new file mode 100644 index 000000000..f4d2530a7 --- /dev/null +++ b/applications/timearrow_phenotyping/test_tarrow.py @@ -0,0 +1,57 @@ +# %% Imports +import torch +import torchview +from viscy.data.tarrow import TarrowDataModule +from viscy.representation.timearrow import TarrowModule + + +# %% Load minimal config +config = { + 'data': { + 'init_args': { + 'ome_zarr_path': '/hpc/projects/organelle_phenotyping/ALFI_models_data/datasets/zarr_datasets/float_phase_ome_zarr_output_valtrain.zarr', # Replace with actual path + 'channel_name': 'DIC', + 'patch_size': [256, 256], + 'batch_size': 32, + 'num_workers': 4, + 'train_split': 0.8 + } + }, + 'model': { + 'init_args': { + 'backbone': 'unet', + 'projection_head': 'minimal_batchnorm', + 'classification_head': 'minimal', + } + } +} + +# # Optionally load config from file +# config_path = "/hpc/projects/organelle_phenotyping/models/ALFI/tarrow_test/tarrow.yml" +# with open(config_path) as f: +# config = yaml.safe_load(f) + +# %% Initialize data and model +data_module = TarrowDataModule(**config['data']['init_args']) +model = TarrowModule(**config['model']['init_args']) +# %% Construct a batch of data from the data module +data_module.setup('fit') +batch = next(iter(data_module.train_dataloader())) +images, labels = batch +print(model) +# %% Print model graph. 
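+# Note (descriptive, assuming torchview's documented behaviour): draw_graph traces a forward pass on
+# `images` to build a ComputationGraph; with save_graph=False nothing is written to disk, and the
+# graph is rendered afterwards via `model_graph.visual_graph` (a graphviz object).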
+try: + # Try constructing the graph + model_graph = torchview.draw_graph( + model, + input_data=images, + save_graph=False, # Don't save, just display + expand_nested=True, + device='cpu' # specify CPU device + ) +except Exception as e: + print(f"Error generating model graph: {e}") + +model_graph.visual_graph # Display the graph + +# %% diff --git a/tests/representation/test_gradcam.py b/tests/representation/test_gradcam.py new file mode 100644 index 000000000..ec919a187 --- /dev/null +++ b/tests/representation/test_gradcam.py @@ -0,0 +1,247 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from lightning.pytorch import LightningModule, Trainer +from torch.utils.data import DataLoader +from lightning.pytorch.loggers import TensorBoardLogger +import matplotlib.pyplot as plt +import numpy as np + +from viscy.callbacks.gradcam import GradCAMCallback + + +class ResNetClassifier(LightningModule): + def __init__(self, num_classes=10): + super().__init__() + # Load pretrained ResNet18 + self.model = torchvision.models.resnet18(pretrained=True) + + # Replace final layer for CIFAR-10 + self.model.fc = nn.Linear(512, num_classes) + + # Save the target layer for GradCAM + self.target_layer = self.model.layer4[-1] + + # Ensure gradients are enabled for the target layer + for param in self.target_layer.parameters(): + param.requires_grad = True + + self.gradients = None + self.activations = None + + def forward(self, x): + return self.model(x) + + def training_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + self.log("train_loss", loss) + return loss + + def validation_step(self, batch, batch_idx): + x, y = batch + logits = self(x) + loss = F.cross_entropy(logits, y) + self.log("val_loss", loss) + return loss + + def configure_optimizers(self): + return torch.optim.Adam(self.parameters(), lr=0.001) + + # GradCAM methods + def activations_hook(self, grad): + self.gradients = grad + + def get_activations(self, x): + return self.activations + + def gradcam(self, x): + # Store original training mode and switch to eval mode + was_training = self.training + self.eval() # Use eval mode for inference + + try: + # Register hooks + h = self.target_layer.register_forward_hook( + lambda module, input, output: setattr(self, "activations", output) + ) + h_bp = self.target_layer.register_backward_hook( + lambda module, grad_in, grad_out: self.activations_hook(grad_out[0]) + ) + + # Forward pass + x = x.unsqueeze(0).to(self.device) # Add batch dimension + + # Enable gradients for the entire computation + with torch.enable_grad(): + x = x.requires_grad_(True) + output = self(x) + + # Get predicted class + pred = output.argmax(dim=1) + + # Create one hot vector for backward pass + one_hot = torch.zeros_like(output, device=self.device) + one_hot[0][pred] = 1 + + # Clear gradients + self.zero_grad(set_to_none=False) + + # Backward pass + output.backward(gradient=one_hot) + + # Generate GradCAM + gradients = self.gradients + activations = self.activations + + # Ensure we have valid gradients + if gradients is None: + raise RuntimeError("No gradients available for GradCAM computation") + + weights = torch.mean(gradients, dim=(2, 3)) + cam = torch.sum(weights[:, :, None, None] * activations, dim=1) + cam = F.relu(cam) + cam = ( + F.interpolate( + cam.unsqueeze(0), + size=x.shape[2:], + mode="bilinear", + align_corners=False, + )[0, 0] + .cpu() + .detach() + .numpy() + ) + + return cam 
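+            # cam is a 2-D numpy array resized to the input's spatial shape; the finally block below
+            # removes both the forward and backward hooks and restores the module's original train/eval mode.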
+ + finally: + # Clean up + h.remove() + h_bp.remove() + # Restore original training mode + self.train(mode=was_training) + + +def main(): + # Data transforms + transform = transforms.Compose( + [ + transforms.Resize(224), # ResNet expects 224x224 images + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + + # Load CIFAR-10 dataset + train_dataset = torchvision.datasets.CIFAR10( + root="./data", train=True, download=True, transform=transform + ) + + val_dataset = torchvision.datasets.CIFAR10( + root="./data", train=False, download=True, transform=transform + ) + + # Create small visualization dataset + vis_dataset = torch.utils.data.Subset(val_dataset, indices=range(10)) + + # Create data loaders + train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=32) + vis_loader = DataLoader(vis_dataset, batch_size=32) # Added visualization loader + + # Initialize model + model = ResNetClassifier() + + # Initialize callbacks + gradcam_callback = GradCAMCallback( + visual_datasets=[vis_dataset], + every_n_epochs=1, # Generate visualizations every epoch + max_samples=5, # Visualize 5 samples + max_height=224, # Match ResNet input size + ) + + # Initialize trainer with specific logger + trainer = Trainer( + max_epochs=5, + callbacks=[gradcam_callback], + accelerator="auto", + devices=1, + logger=TensorBoardLogger( + save_dir="/home/eduardo.hirata/repos/viscy/tests/representation/lightning_logs", # specify your desired log directory + name="gradcam_cifar", # experiment name + ), + ) + + # Train model + trainer.fit(model, train_loader, val_loader) + + # Test GradCAM visualization + test_gradcam_visualization(model, vis_loader) + + +def test_gradcam_visualization(model, dataloader): + """Test GradCAM visualization. 
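+
+    Runs GradCAM on the first sample from the dataloader and saves a figure with
+    the original image, the activation map, and an overlay to ./gradcam_cifar.png.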
+ + Parameters + ---------- + model : LightningModule + The trained model + dataloader : DataLoader + DataLoader containing samples to visualize + """ + model.eval() + # Get a sample from validation set + batch = next(iter(dataloader)) + images, labels = batch + + # Generate GradCAM for first sample + sample_img = images[0] # Shape: (C, H, W) + cam = model.gradcam(sample_img) + + # Plot the results + fig, axes = plt.subplots(1, 3, figsize=(30, 10)) + + # Original image + img = images[0].squeeze().cpu().numpy() + if img.ndim == 3: # Handle RGB images + axes[0].imshow(np.transpose(img, (1, 2, 0))) + else: # Handle grayscale images + axes[0].imshow(img, cmap="gray") + axes[0].set_title("Original Image") + plt.colorbar(axes[0].images[0], ax=axes[0]) + + # GradCAM visualization + im = axes[1].imshow(cam, cmap="magma") + axes[1].set_title("GradCAM") + plt.colorbar(im, ax=axes[1]) + + # Overlay GradCAM on original image + img = images[0].squeeze().cpu().numpy() + if img.ndim == 3: # Handle RGB images + img = np.transpose(img, (1, 2, 0)) + img = (img - img.min()) / (img.max() - img.min()) # Normalize to [0,1] + cam_norm = (cam - cam.min()) / (cam.max() - cam.min()) # Normalize to [0,1] + + # Create RGB overlay + if img.ndim == 2: # Convert grayscale to RGB + img_rgb = np.stack([img] * 3, axis=-1) + else: # Already RGB + img_rgb = img + cam_rgb = plt.cm.magma(cam_norm)[..., :3] # Convert to RGB using magma colormap + overlay = 0.7 * img_rgb + 0.3 * cam_rgb + + axes[2].imshow(overlay) + axes[2].set_title("GradCAM Overlay") + + plt.suptitle(f"GradCAM Visualization (Predicted: {labels[0].item()})", y=1.05) + plt.savefig("./gradcam_cifar.png") + plt.close() + # plt.show() + + +if __name__ == "__main__": + main() diff --git a/viscy/callbacks/gradcam.py b/viscy/callbacks/gradcam.py new file mode 100644 index 000000000..1a2c04dbb --- /dev/null +++ b/viscy/callbacks/gradcam.py @@ -0,0 +1,158 @@ +import logging + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torchvision +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import Callback +from skimage.exposure import rescale_intensity + +logger = logging.getLogger(__name__) + + +class GradCAMCallback(Callback): + """Callback for computing and logging GradCAM visualizations. + + Parameters + ---------- + every_n_epochs : int, default=10 + Generate visualizations every n epochs + max_samples : int, default=5 + Maximum number of samples to visualize per dataset + mode : str, default="overlay" + Visualization mode: "separate" for individual images and activations, + or "overlay" for activation map overlaid on input image + """ + + def __init__( + self, + every_n_epochs: int = 10, + max_samples: int = 5, + mode: str = "overlay", + ): + super().__init__() + self.every_n_epochs = every_n_epochs + self.max_samples = max_samples + assert mode in ["separate", "overlay"], "Mode must be 'separate' or 'overlay'" + self.mode = mode + + def on_validation_epoch_end( + self, trainer: Trainer, pl_module: LightningModule + ) -> None: + """Generate and log GradCAM visualizations""" + if (trainer.current_epoch + 1) % self.every_n_epochs != 0: + return + + if not hasattr(trainer.datamodule, "visual_dataloader"): + logger.warning( + "DataModule does not have visual_dataloader method. Skipping GradCAM visualization." 
+ ) + return + + pl_module.eval() + device = pl_module.device + + # Get visual dataloader from the datamodule + visual_loader = trainer.datamodule.visual_dataloader() + + # Get a few samples + samples = [] + activations = [] + + for batch_idx, (x, _) in enumerate(visual_loader): + if batch_idx >= self.max_samples: + break + + try: + # Move tensor to same device as model + x = x.to(device) + + # Generate class activation map + activation_map = pl_module.gradcam(x) + + # Convert to RGB images for visualization + # Handle 5D tensor [B, T, C, H, W] -> take first batch and timepoint + x_img = x[0, 0].cpu().numpy() # Take first batch and timepoint + if x_img.ndim == 3: # Handle [C, H, W] case + x_img = x_img[0] # Take first channel to get [H, W] + x_img = rescale_intensity(x_img, in_range="image", out_range=(0, 1)) + + # Create activation map visualization + activation_norm = self._normalize_cam(torch.from_numpy(activation_map)) + activation_rgb = plt.cm.magma(activation_norm.numpy())[..., :3] + + if self.mode == "separate": + # Keep sample as grayscale + x_vis = ( + torch.from_numpy(x_img).unsqueeze(0).float() + ) # Add channel dim [1, H, W] + activation_vis = ( + torch.from_numpy(activation_rgb).permute(2, 0, 1).float() + ) # [3, H, W] + else: # overlay mode + # Convert input to RGB + x_rgb = np.stack([x_img] * 3, axis=-1) # [H, W, 3] + # Create overlay + overlay = self._create_overlay(x_rgb, activation_rgb) + x_vis = ( + torch.from_numpy(x_rgb).permute(2, 0, 1).float() + ) # [3, H, W] + activation_vis = ( + torch.from_numpy(overlay).permute(2, 0, 1).float() + ) # [3, H, W] + + samples.append(x_vis.cpu()) # Ensure on CPU for visualization + activations.append( + activation_vis.cpu() + ) # Ensure on CPU for visualization + + except Exception as e: + logger.error(f"Error processing sample {batch_idx}: {str(e)}") + continue + + if samples: # Only proceed if we have samples + try: + # Stack images for grid visualization + samples_grid = torchvision.utils.make_grid( + samples, nrow=len(samples), normalize=True, value_range=(0, 1) + ) + activations_grid = torchvision.utils.make_grid( + activations, + nrow=len(activations), + normalize=True, + value_range=(0, 1), + ) + + # Log to tensorboard + trainer.logger.experiment.add_image( + f"gradcam/samples", + samples_grid, + trainer.current_epoch, + ) + trainer.logger.experiment.add_image( + f"gradcam/{'overlays' if self.mode == 'overlay' else 'activations'}", + activations_grid, + trainer.current_epoch, + ) + except Exception as e: + logger.error(f"Error creating visualization grid: {str(e)}") + + @staticmethod + def _tensor_to_img(tensor: torch.Tensor) -> torch.Tensor: + """Convert tensor to normalized image tensor""" + img = tensor.cpu().numpy() + img = (img - img.min()) / (img.max() - img.min() + 1e-7) + return img + + @staticmethod + def _create_overlay( + img: torch.Tensor, cam: torch.Tensor, alpha: float = 0.5 + ) -> torch.Tensor: + """Create overlay of image and CAM""" + return (1 - alpha) * img + alpha * cam + + @staticmethod + def _normalize_cam(cam: torch.Tensor) -> torch.Tensor: + """Normalize CAM to [0,1]""" + return (cam - cam.min()) / (cam.max() - cam.min() + 1e-8) diff --git a/viscy/data/hcs.py b/viscy/data/hcs.py index 4a7e9dfb5..ac522969d 100644 --- a/viscy/data/hcs.py +++ b/viscy/data/hcs.py @@ -25,6 +25,7 @@ from torch.utils.data import DataLoader, Dataset from viscy.data.typing import ChannelMap, DictTransform, HCSStackIndex, NormMeta, Sample +from viscy.utils.engine_state import set_fit_global_state _logger = 
logging.getLogger("lightning.pytorch") @@ -434,12 +435,6 @@ def setup(self, stage: Literal["fit", "validate", "test", "predict"]): else: raise NotImplementedError(f"{stage} stage") - def _set_fit_global_state(self, num_positions: int) -> torch.Tensor: - # disable metadata tracking in MONAI for performance - set_track_meta(False) - # shuffle positions, randomness is handled globally - return torch.randperm(num_positions) - def _setup_fit(self, dataset_settings: dict): """Set up the training and validation datasets.""" train_transform, val_transform = self._fit_transform() @@ -449,7 +444,7 @@ def _setup_fit(self, dataset_settings: dict): # shuffle positions, randomness is handled globally positions = [pos for _, pos in plate.positions()] - shuffled_indices = self._set_fit_global_state(len(positions)) + shuffled_indices = set_fit_global_state(len(positions)) positions = list(positions[i] for i in shuffled_indices) num_train_fovs = int(len(positions) * self.split_ratio) # training set needs to sample more Z range for augmentation diff --git a/viscy/data/tarrow.py b/viscy/data/tarrow.py new file mode 100644 index 000000000..d89771301 --- /dev/null +++ b/viscy/data/tarrow.py @@ -0,0 +1,312 @@ +from pathlib import Path +from typing import Callable, Sequence + +import numpy as np +import torch.nn as nn +from iohub.ngff import Position, open_ome_zarr +from lightning.pytorch import LightningDataModule +from tarrow.data.tarrow_dataset import TarrowDataset +from torch.utils.data import ConcatDataset, DataLoader + +# FIXME: This module is not available in the viscy package,so shuffle the list of datasets manually. +# from viscy.utils.engine_state import set_fit_global_state +import random + + +class TarrowDataModule(LightningDataModule): + """Lightning DataModule for TimeArrowNet training. + + Parameters + ---------- + ome_zarr_path : str or Path + Path to OME-Zarr file + channel_name : str + Name of the channel to load + train_split : float, default=0.8 + Fraction of data to use for training (0.0 to 1.0) + patch_size : tuple[int, int], default=(128, 128) + Patch size for TarrowDataset + visual_patch_size : tuple[int, int] | None, default=None + Patch size for visualization dataset + visual_batch_size : int | None, default=None + Batch size for visualization dataloader + batch_size : int, default=16 + Batch size for dataloaders + num_workers : int, default=8 + Number of workers for dataloaders + prefetch_factor : int, optional + Prefetch factor for dataloaders + include_fov_names : list[str], default=[] + List of FOV names to include. 
If empty, use all FOVs + train_samples_per_epoch : int, default=100000 + Number of training samples per epoch + val_samples_per_epoch : int, default=10000 + Number of validation samples per epoch + resolution : int, default=0 + Resolution level to load from OME-Zarr + normalization : function, optional (default=None) + Normalization function to apply to images + z_slice : int, default=0 + Z-slice to load + pin_memory : bool, default=True + Whether to pin memory + persistent_workers : bool, default=True + Whether to keep the workers alive between epochs + augmentations : list[nn.Module], default=[] + List of Kornia augmentation transforms to apply during training + **kwargs : dict + Additional arguments passed to TarrowDataset + """ + + def __init__( + self, + ome_zarr_path: str | Path, + channel_name: str, + train_split: float = 0.8, + batch_size: int = 16, + num_workers: int = 8, + patch_size: tuple[int, int] = (128, 128), + visual_patch_size: tuple[int, int] | None = None, + visual_batch_size: int | None = None, + prefetch_factor: int | None = None, + include_fov_names: list[str] = [], + train_samples_per_epoch: int = 100000, + val_samples_per_epoch: int = 10000, + resolution: int = 0, + z_slice: int = 0, + normalization: Callable[[np.ndarray], np.ndarray] | None = None, + pin_memory: bool = True, + persistent_workers: bool = True, + augmentations: Sequence[nn.Module] = [], + **kwargs, + ): + super().__init__() + self.ome_zarr_path = ome_zarr_path + self.channel_name = channel_name + self.train_split = train_split + self.batch_size = batch_size + self.num_workers = num_workers + self.prefetch_factor = prefetch_factor + self.patch_size = patch_size + self.visual_patch_size = visual_patch_size or tuple(4 * x for x in patch_size) + self.visual_batch_size = visual_batch_size or min(4, batch_size) + self.include_fov_names = include_fov_names + self.train_samples_per_epoch = train_samples_per_epoch + self.val_samples_per_epoch = val_samples_per_epoch + self.resolution = resolution + self.z_slice = z_slice + self.kwargs = kwargs + self.normalization = normalization + self.pin_memory = pin_memory + self.persistent_workers = persistent_workers + self.augmentations = augmentations + + self._filter_positions() + self._channel_idx = self._get_channel_index() + + def _get_channel_index(self) -> int: + """Get the index of the specified channel from the plate metadata.""" + with open_ome_zarr(self.ome_zarr_path, mode="r") as plate: + _, first_pos = next(plate.positions()) + return first_pos.channel_names.index(self.channel_name) + + def _create_augmentation_pipeline(self) -> nn.Sequential | None: + """Create the augmentation pipeline for training. + + Returns + ------- + nn.Sequential | None + Sequential container of Kornia augmentations or None if no augmentations + """ + if not self.augmentations: + return None + + return nn.Sequential(*self.augmentations) + + def _load_images(self, position: Position, channel_idx: int) -> list[np.ndarray]: + """Load all images from positions into memory. + + Parameters + ---------- + position : Position + Position to load + channel_idx : int + Index of channel to load + + Returns + ------- + list[np.ndarray] + List of 2D numpy arrays + """ + imgs = [] + img_arr = position[str(self.resolution)] + # Load all timepoints for this position + for t in range(len(img_arr)): + imgs.append(img_arr[t, channel_idx, self.z_slice]) + return imgs + + def setup(self, stage: str): + """Set up the data module for a specific stage. 
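+
+        Loads every timepoint of each selected FOV into memory, wraps it in
+        TarrowDataset objects, then shuffles the FOV-level datasets and splits
+        them into training, validation, and visualization sets.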
+ + Parameters + ---------- + stage : str + Stage to set up for ("fit", "test", or "predict") + + Raises + ------ + NotImplementedError + If stage is not "fit" + """ + if stage == "fit": + list_dataset = [] + list_visual_dataset = [] + + # Create augmentation pipeline + augmenter = self._create_augmentation_pipeline() + + for pos in self.positions: + pos_imgs = self._load_images(pos, self._channel_idx) + list_dataset.append( + TarrowDataset( + imgs=pos_imgs, + normalize=self.normalization, + size=self.patch_size, + augmenter=augmenter, # Pass augmenter to dataset + **self.kwargs, + ) + ) + # Create visualization dataset with larger patches + list_visual_dataset.append( + TarrowDataset( + imgs=pos_imgs, + normalize=self.normalization, + size=self.visual_patch_size, + **self.kwargs, + ) + ) + + # Calculate split point + split_idx = int(len(self.positions) * self.train_split) + + # Shuffle the list of datasets + + #FIXME: This module is not available in the viscy package,so shuffle the list of datasets manually. + # shuffled_indices = set_fit_global_state(len(list_dataset)) + shuffled_indices = list(range(len(list_dataset))) + random.shuffle(shuffled_indices) + list_dataset = [list_dataset[i] for i in shuffled_indices] + list_visual_dataset = [ + list_visual_dataset[i] for i in shuffled_indices + ] # Use same shuffling + + # Create training dataset with first train_split% of images + self.train_dataset = ConcatDataset(list_dataset[:split_idx]) + self.val_dataset = ConcatDataset(list_dataset[split_idx:]) + + # Take up to n_visual samples from validation set + # NOTE fixed to take the first n_visual samples from validation set + self.visual_batch_size = max( + len(list_visual_dataset[split_idx:]), self.visual_batch_size + ) + self.visual_dataset = ConcatDataset( + list_visual_dataset[split_idx : split_idx + self.visual_batch_size] + ) + + elif stage == "test": + raise NotImplementedError(f"Invalid stage: {stage}") + elif stage == "predict": + raise NotImplementedError(f"Invalid stage: {stage}") + else: + raise NotImplementedError(f"Invalid stage: {stage}") + + def _filter_positions(self): + """Filter positions based on include_fov_names.""" + # Get the positions to load + plate = open_ome_zarr(self.ome_zarr_path, mode="r") + if self.include_fov_names: + positions = [] + for fov_str, pos in plate.positions(): + normalized_include_fovs = [ + f.lstrip("/") for f in self.include_fov_names + ] + if fov_str in normalized_include_fovs: + positions.append(pos) + else: + positions = [pos for _, pos in plate.positions()] + + self.positions = positions + + def train_dataloader(self): + """Create the training dataloader. + + Returns + ------- + torch.utils.data.DataLoader + DataLoader for training data + """ + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=True if self.num_workers > 0 else False, + prefetch_factor=self.prefetch_factor if self.num_workers else None, + pin_memory=self.pin_memory, + shuffle=True, + ) + + def val_dataloader(self): + """Create the validation dataloader. + + Returns + ------- + torch.utils.data.DataLoader + DataLoader for validation data + """ + return DataLoader( + self.val_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + persistent_workers=True if self.num_workers > 0 else False, + prefetch_factor=self.prefetch_factor if self.num_workers else None, + pin_memory=self.pin_memory, + shuffle=False, + ) + + def visual_dataloader(self): + """Create the visualization dataloader. 
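+
+        Yields patches of `visual_patch_size` from the first FOVs of the
+        validation split; GradCAMCallback consumes this loader via
+        `trainer.datamodule.visual_dataloader()`.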
+
+        Returns
+        -------
+        torch.utils.data.DataLoader
+            DataLoader for visualization data
+        """
+        return DataLoader(
+            self.visual_dataset,
+            batch_size=self.visual_batch_size,
+            num_workers=self.num_workers,
+            persistent_workers=True if self.num_workers > 0 else False,
+            prefetch_factor=self.prefetch_factor if self.num_workers else None,
+            pin_memory=self.pin_memory,
+            shuffle=False,
+        )
+
+    def test_dataloader(self):
+        """Create the test dataloader.
+
+        Raises
+        ------
+        NotImplementedError
+            Test stage is not implemented yet
+        """
+        # No test dataset is created in setup(), so fail loudly instead of
+        # returning a DataLoader over a non-existent attribute.
+        raise NotImplementedError("Test stage is not implemented yet")
diff --git a/viscy/representation/evaluation/distance.py b/viscy/representation/evaluation/distance.py
index a920eb072..858823785 100644
--- a/viscy/representation/evaluation/distance.py
+++ b/viscy/representation/evaluation/distance.py
@@ -1,12 +1,74 @@
 from collections import defaultdict
+from pathlib import Path
 from typing import Literal

+import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
+import seaborn as sns
+from numpy.typing import NDArray
+from scipy.optimize import minimize_scalar
+from scipy.stats import gaussian_kde
 from sklearn.metrics.pairwise import cosine_similarity
+from sklearn.preprocessing import StandardScaler
+from tqdm import tqdm
+from viscy.representation.embedding_writer import read_embedding_dataset
+from viscy.representation.evaluation.clustering import (
+    compare_time_offset,
+    pairwise_distance_matrix,
+    rank_nearest_neighbors,
+    select_block,
+)


-def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id):
-    """Extract embeddings and calculate cosine similarities for a specific cell"""
+
+def calculate_distance_cell(
+    embedding_dataset,
+    fov_name,
+    track_id,
+    metric: Literal["cosine", "euclidean", "normalized_euclidean"] = "cosine",
+):
+    """
+    Calculate distances between a cell's first timepoint embedding and all its subsequent embeddings.
+
+    This function extracts embeddings for a specific cell (identified by fov_name and track_id)
+    and calculates the distance between its first timepoint embedding and all subsequent timepoints
+    using the specified distance metric.
+
+    Parameters
+    ----------
+    embedding_dataset : xarray.Dataset
+        Dataset containing the embeddings and metadata. Must have dimensions for 'features',
+        'fov_name', 'track_id', and 't' (time).
+    fov_name : str
+        Field of view name to identify the specific imaging area.
+    track_id : int
+        Track ID of the cell to analyze.
+    metric : {'cosine', 'euclidean', 'normalized_euclidean'}, default='cosine'
+        Distance metric to use for calculations:
+        - 'cosine': Cosine similarity between embeddings
+        - 'euclidean': Standard Euclidean distance
+        - 'normalized_euclidean': Euclidean distance between L2-normalized embeddings
+
+    Returns
+    -------
+    time_points : numpy.ndarray
+        Array of time points corresponding to the calculated distances.
+    distances : list
+        List of distances between the first timepoint embedding and each subsequent
+        timepoint embedding, calculated using the specified metric.
+
+    Notes
+    -----
+    For 'normalized_euclidean', embeddings are L2-normalized before distance calculation.
+    Cosine similarity results in values between -1 and 1, where 1 indicates identical
+    direction, 0 indicates orthogonality, and -1 indicates opposite directions.
+    Euclidean distances are always non-negative.
+ + Examples + -------- + >>> times, distances = calculate_distance_cell(dataset, "FOV1", 1, metric="cosine") + >>> times, distances = calculate_distance_cell(dataset, "FOV1", 1, metric="euclidean") + """ filtered_data = embedding_dataset.where( (embedding_dataset["fov_name"] == fov_name) & (embedding_dataset["track_id"] == track_id), @@ -14,11 +76,18 @@ def calculate_cosine_similarity_cell(embedding_dataset, fov_name, track_id): ) features = filtered_data["features"].values # (sample, features) time_points = filtered_data["t"].values # (sample,) + + if metric == "normalized_euclidean": + features = features / np.linalg.norm(features, axis=1, keepdims=True) + first_time_point_embedding = features[0].reshape(1, -1) - cosine_similarities = cosine_similarity( - first_time_point_embedding, features - ).flatten() - return time_points, cosine_similarities.tolist() + + if metric == "cosine": + distances = cosine_similarity(first_time_point_embedding, features).flatten() + else: # both euclidean and normalized_euclidean use norm + distances = np.linalg.norm(first_time_point_embedding - features, axis=1) + + return time_points, distances.tolist() def compute_displacement( @@ -130,18 +199,18 @@ def compute_displacement_statistics( return mean_displacement_per_tau, std_displacement_per_tau -def compute_dynamic_range(mean_displacement_per_tau): +def compute_dynamic_range(mean_displacement_per_delta_t): """ Compute the dynamic range as the difference between the maximum and minimum mean displacement per τ. Parameters: - mean_displacement_per_tau: dict with τ as key and mean displacement as value + mean_displacement_per_delta_t: dict with τ as key and mean displacement as value Returns: float: dynamic range (max displacement - min displacement) """ - displacements = list(mean_displacement_per_tau.values()) + displacements = list(mean_displacement_per_delta_t.values()) return max(displacements) - min(displacements) @@ -193,17 +262,205 @@ def compute_rms_per_track(embedding_dataset): return rms_values -def calculate_normalized_euclidean_distance_cell(embedding_dataset, fov_name, track_id): - filtered_data = embedding_dataset.where( - (embedding_dataset["fov_name"] == fov_name) - & (embedding_dataset["track_id"] == track_id), - drop=True, +def find_distribution_peak(data: np.ndarray) -> float: + """ + Find the peak (mode) of a distribution using kernel density estimation. 
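+
+    The peak is located by minimizing the negative kernel density estimate with a
+    bounded scalar search over the range of the data.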
+ + Args: + data: Array of values to find the peak for + + Returns: + float: The x-value where the peak occurs + """ + kde = gaussian_kde(data) + # Find the peak (maximum) of the KDE + result = minimize_scalar( + lambda x: -kde(x), bounds=(np.min(data), np.max(data)), method="bounded" ) - features = filtered_data["features"].values # (sample, features) - time_points = filtered_data["t"].values # (sample,) - normalized_features = features / np.linalg.norm(features, axis=1, keepdims=True) - first_time_point_embedding = normalized_features[0].reshape(1, -1) - euclidean_distances = np.linalg.norm( - first_time_point_embedding - normalized_features, axis=1 + return result.x + + +def compute_piece_wise_dissimilarity( + features_df: pd.DataFrame, cross_dist: NDArray, rank_fractions: NDArray +): + """ + Computing the smoothness and dynamic range + - Get the off diagonal per block and compute the mode + - The blocks are not square, so we need to get the off diagonal elements + - Get the 1 and 99 percentile of the off diagonal per block + """ + piece_wise_dissimilarity_per_track = [] + piece_wise_rank_difference_per_track = [] + for name, subdata in features_df.groupby(["fov_name", "track_id"]): + if len(subdata) > 1: + indices = subdata.index.values + single_track_dissimilarity = select_block(cross_dist, indices) + single_track_rank_fraction = select_block(rank_fractions, indices) + piece_wise_dissimilarity = compare_time_offset( + single_track_dissimilarity, time_offset=1 + ) + piece_wise_rank_difference = compare_time_offset( + single_track_rank_fraction, time_offset=1 + ) + piece_wise_dissimilarity_per_track.append(piece_wise_dissimilarity) + piece_wise_rank_difference_per_track.append(piece_wise_rank_difference) + return piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track + + +def compute_embedding_distances( + prediction_path: Path, + output_path: Path, + distance_metric: Literal["cosine", "euclidean", "normalized_euclidean"] = "cosine", + verbose: bool = False, +) -> pd.DataFrame: + """ + Compute and save pairwise distances between embeddings. 
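+
+    Embeddings are standardized unless the plain Euclidean metric is used; distances
+    between consecutive timepoints of each track are then compared against randomly
+    sampled pairs from the full pairwise distance matrix.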
+
+    Parameters
+    ----------
+    prediction_path : Path
+        Path to the embedding dataset
+    output_path : Path
+        Path to the CSV file where the distance distributions are saved
+    distance_metric : {'cosine', 'euclidean', 'normalized_euclidean'}, default='cosine'
+        Distance metric to use for computing distances between embeddings
+    verbose : bool, optional
+        If True, plots the distance matrix visualization
+
+    Returns
+    -------
+    pd.DataFrame
+        DataFrame containing the adjacent frame and random sampling distances
+    """
+    # Read the dataset
+    embeddings = read_embedding_dataset(prediction_path)
+    features = embeddings["features"]
+
+    # Standardize features unless plain Euclidean distance is requested;
+    # keep the xarray DataArray intact for the track metadata below.
+    if distance_metric != "euclidean":
+        scaled_features = StandardScaler().fit_transform(features.values)
+    else:
+        scaled_features = features.values
+
+    # Compute the distance matrix
+    cross_dist = pairwise_distance_matrix(scaled_features, metric=distance_metric)
+
+    # Normalize by sqrt of embedding dimension if using euclidean distance
+    if distance_metric == "euclidean":
+        cross_dist /= np.sqrt(scaled_features.shape[1])
+
+    if verbose:
+        # Plot the distance matrix
+        plt.figure(figsize=(10, 10))
+        plt.imshow(cross_dist, cmap="viridis")
+        plt.colorbar(label=f"{distance_metric.capitalize()} Distance")
+        plt.title(f"{distance_metric.capitalize()} Distance Matrix")
+        plt.tight_layout()
+        plt.show()
+
+    rank_fractions = rank_nearest_neighbors(cross_dist, normalize=True)
+
+    # Compute piece-wise dissimilarity and rank difference along each track
+    features_df = features["sample"].to_dataframe().reset_index(drop=True)
+    piece_wise_dissimilarity_per_track, piece_wise_rank_difference_per_track = (
+        compute_piece_wise_dissimilarity(features_df, cross_dist, rank_fractions)
+    )
+
+    all_dissimilarity = np.concatenate(piece_wise_dissimilarity_per_track)
+
+    # Randomly sample the same number of values from the distance matrix
+    n_samples = len(all_dissimilarity)
+    random_indices = np.random.randint(0, len(cross_dist), size=(n_samples, 2))
+    sampled_values = cross_dist[random_indices[:, 0], random_indices[:, 1]]
+
+    # Create and save DataFrame
+    distributions_df = pd.DataFrame(
+        {
+            "adjacent_frame": pd.Series(all_dissimilarity),
+            "random_sampling": pd.Series(sampled_values),
+        }
+    )
+    distributions_df.to_csv(output_path, index=False)
+
+    return distributions_df
+
+
+def analyze_and_plot_distances(
+    distributions_df: pd.DataFrame,
+    output_file_path: Path,
+    overwrite: bool = False,
+) -> dict:
+    """
+    Analyze distance distributions and create visualization plots.
+
+    Parameters
+    ----------
+    distributions_df : pd.DataFrame
+        DataFrame containing 'adjacent_frame' and 'random_sampling' columns
+    output_file_path : Path
+        Path to save the plot, ideally with a .pdf extension.
Uses `plt.savefig()` + overwrite : bool, default=False + If True, overwrites existing files + + Returns + ------- + dict + Dictionary containing computed metrics including means, standard deviations, + medians, peaks, and dynamic range of the distributions + """ + # Compute statistics + adjacent_dist = distributions_df["adjacent_frame"].values + random_dist = distributions_df["random_sampling"].values + + # Compute peaks + adjacent_peak = float(find_distribution_peak(adjacent_dist)) + random_peak = float(find_distribution_peak(random_dist)) + dynamic_range = float(random_peak - adjacent_peak) + + metrics = { + "dissimilarity_mean": float(np.mean(adjacent_dist)), + "dissimilarity_std": float(np.std(adjacent_dist)), + "dissimilarity_median": float(np.median(adjacent_dist)), + "dissimilarity_peak": adjacent_peak, + "dissimilarity_p99": float(np.percentile(adjacent_dist, 99)), + "dissimilarity_p1": float(np.percentile(adjacent_dist, 1)), + "random_mean": float(np.mean(random_dist)), + "random_std": float(np.std(random_dist)), + "random_median": float(np.median(random_dist)), + "random_peak": random_peak, + "dynamic_range": dynamic_range, + } + + # Create plot + fig = plt.figure() + sns.histplot( + data=distributions_df, + x="adjacent_frame", + bins=30, + kde=True, + color="cyan", + alpha=0.5, + stat="density", + ) + sns.histplot( + data=distributions_df, + x="random_sampling", + bins=30, + kde=True, + color="red", + alpha=0.5, + stat="density", ) - return time_points, euclidean_distances.tolist() + plt.xlabel("Cosine Dissimilarity") + plt.ylabel("Density") + plt.axvline(x=adjacent_peak, color="cyan", linestyle="--", alpha=0.8) + plt.axvline(x=random_peak, color="red", linestyle="--", alpha=0.8) + plt.tight_layout() + plt.legend(["Adjacent Frame", "Random Sample", "Adjacent Peak", "Random Peak"]) + if output_file_path.exists() and not overwrite: + raise FileExistsError( + f"File {output_file_path} already exists and overwrite=False" + ) + fig.savefig(output_file_path, dpi=600) + plt.show() + + return metrics diff --git a/viscy/representation/timearrow.py b/viscy/representation/timearrow.py new file mode 100644 index 000000000..6df8a1ccf --- /dev/null +++ b/viscy/representation/timearrow.py @@ -0,0 +1,386 @@ +import logging +from typing import Literal, Sequence + +import numpy as np +import torch +import torch.nn as nn +from lightning.pytorch import LightningModule +from tarrow.models import TimeArrowNet +from tarrow.models.losses import DecorrelationLoss +from torch.optim import Adam +from torch.optim.lr_scheduler import CyclicLR, ReduceLROnPlateau + +from viscy.utils.log_images import render_images + +logger = logging.getLogger(__name__) + + +class TarrowModule(LightningModule): + """Lightning Module wrapper for TimeArrowNet. 
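+
+    Trains a time-arrow prediction model: a dense backbone with a projection head
+    and a classification head, optimized with cross-entropy plus a weighted
+    decorrelation loss on the projected features.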
+ + Parameters + ---------- + backbone : str, default="unet" + Dense network architecture + projection_head : str, default="minimal_batchnorm" + Dense projection head architecture + classification_head : str, default="minimal" + Classification head architecture + n_frames : int, default=2 + Number of input frames + n_features : int, default=16 + Number of output features from the backbone + n_input_channels : int, default=1 + Number of input channels + symmetric : bool, default=False + If True, use permutation-equivariant classification head + learning_rate : float, default=1e-4 + Learning rate for optimizer + weight_decay : float, default=1e-6 + Weight decay for optimizer + lambda_decorrelation : float, default=0.01 + Prefactor of decorrelation loss + lr_scheduler : str, default="cyclic" + Learning rate scheduler ('plateau' or 'cyclic') + lr_patience : int, default=50 + Patience for learning rate scheduler + log_batches_per_epoch : int, default=8 + Number of batches to log samples from during training + log_samples_per_batch : int, default=1 + Number of samples to log from each batch + """ + + def __init__( + self, + backbone="unet", + projection_head="minimal_batchnorm", + classification_head="minimal", + n_frames=2, + n_features=16, + n_input_channels=1, + symmetric=False, + learning_rate=1e-4, + weight_decay=1e-6, + lambda_decorrelation=0.01, + lr_scheduler="cyclic", + lr_patience=50, + log_batches_per_epoch=8, + log_samples_per_batch=1, + **kwargs, + ): + super().__init__() + self.save_hyperparameters() + + self.log_batches_per_epoch = log_batches_per_epoch + self.log_samples_per_batch = log_samples_per_batch + self.training_step_outputs = [] + self.validation_step_outputs = [] + + self.model = TimeArrowNet( + backbone=backbone, + projection_head=projection_head, + classification_head=classification_head, + n_frames=n_frames, + n_features=n_features, + n_input_channels=n_input_channels, + symmetric=symmetric, + ) + + self.criterion = nn.CrossEntropyLoss(reduction="none") + self.criterion_decorr = DecorrelationLoss() + + def _log_samples(self, key: str, imgs: Sequence[Sequence[np.ndarray]]): + """Log sample images to TensorBoard. + + Parameters + ---------- + key : str + Key for logging + imgs : Sequence[Sequence[np.ndarray]] + List of image pairs to log + """ + grid = render_images(imgs, cmaps=["gray"] * 2) # Only 2 timepoints + self.logger.experiment.add_image( + key, grid, self.current_epoch, dataformats="HWC" + ) + + def _log_step_samples(self, batch_idx, images, stage: Literal["train", "val"]): + """Log samples from a batch. + + Parameters + ---------- + batch_idx : int + Index of current batch + images : torch.Tensor + Batch of images with shape (B, T, C, H, W) + stage : str + Either "train" or "val" + """ + if batch_idx < self.log_batches_per_epoch: + # Get first n samples from batch + n = min(self.log_samples_per_batch, images.shape[0]) + samples = images[:n].detach().cpu().numpy() + + # Split into pairs of timepoints + pairs = [(sample[0], sample[1]) for sample in samples] + + output_list = ( + self.training_step_outputs + if stage == "train" + else self.validation_step_outputs + ) + output_list.extend(pairs) + + def forward(self, x): + """Forward pass through the model. 
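+
+        Runs TimeArrowNet with mode="both" so that the classification logits and
+        the feature-space projection are returned together.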
+ + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (batch_size, n_frames, channels, height, width) + + Returns + ------- + tuple + Tuple of (output, projection) where: + - output is the classification logits + - projection is the feature space projection + """ + return self.model(x, mode="both") + + def _shared_step(self, batch, batch_idx, step="train"): + """Shared step for training and validation. + + Parameters + ---------- + batch : tuple + Tuple of (images, labels) + batch_idx : int + Index of the current batch + step : str, default="train" + Current step type ("train" or "val") + + Returns + ------- + torch.Tensor + Combined loss (classification + decorrelation) + """ + x, y = batch + + # Log sample images + self._log_step_samples(batch_idx, x, step) + + out, pro = self(x) + + if out.ndim > 2: + y = torch.broadcast_to( + y.unsqueeze(1).unsqueeze(1), (y.shape[0],) + out.shape[-2:] + ) + loss = self.criterion(out, y) + loss = torch.mean(loss, tuple(range(1, loss.ndim))) + y = y[:, 0, 0] + u_avg = torch.mean(out, tuple(range(2, out.ndim))) + else: + u_avg = out + loss = self.criterion(out, y) + + pred = torch.argmax(u_avg.detach(), 1) + loss = torch.mean(loss) + + # decorrelation loss + pro_batched = pro.flatten(0, 1) + loss_decorr = self.criterion_decorr(pro_batched) + loss_all = loss + self.hparams.lambda_decorrelation * loss_decorr + + acc = torch.mean((pred == y).float()) + + # Main classification loss + self.log(f"loss/{step}_loss", loss, prog_bar=True) + # Decorrelation loss for feature space + self.log(f"loss/{step}_loss_decorr", loss_decorr, prog_bar=True) + # Classification accuracy + self.log(f"metric/{step}_accuracy", acc, prog_bar=True) + # Ratio of positive predictions (class 1) - useful to detect class imbalance + self.log(f"metric/{step}_pred1_ratio", pred.sum().float() / len(pred)) + + return loss_all + + def training_step(self, batch, batch_idx): + return self._shared_step(batch, batch_idx, "train") + + def validation_step(self, batch, batch_idx): + return self._shared_step(batch, batch_idx, "val") + + def configure_optimizers(self): + optimizer = Adam( + self.parameters(), + lr=self.hparams.learning_rate, + weight_decay=self.hparams.weight_decay, + ) + + if self.hparams.lr_scheduler == "plateau": + scheduler = ReduceLROnPlateau( + optimizer, + factor=0.2, + patience=self.hparams.lr_patience, + verbose=True, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "monitor": "loss/val_loss", + "interval": "epoch", + }, + } + elif self.hparams.lr_scheduler == "cyclic": + # Get dataloader length accounting for DDP + dataloader = self.trainer.datamodule.train_dataloader() + steps_per_epoch = len(dataloader) + + # Account for gradient accumulation and multiple GPUs + if self.trainer.accumulate_grad_batches: + steps_per_epoch = ( + steps_per_epoch // self.trainer.accumulate_grad_batches + ) + + total_steps = steps_per_epoch * self.trainer.max_epochs + + scheduler = CyclicLR( + optimizer, + base_lr=self.hparams.learning_rate, + max_lr=self.hparams.learning_rate * 10, + cycle_momentum=False, + step_size_up=total_steps // 2, # Half the total steps for one cycle + scale_mode="cycle", + scale_fn=lambda x: 0.9**x, + ) + return { + "optimizer": optimizer, + "lr_scheduler": { + "scheduler": scheduler, + "interval": "step", + }, + } + + def on_train_epoch_end(self): + """Log collected training samples at end of epoch.""" + if self.training_step_outputs: + self._log_samples("train_samples", self.training_step_outputs) + 
self.training_step_outputs = [] + + def on_validation_epoch_end(self): + """Log collected validation samples at end of epoch.""" + if self.validation_step_outputs: + self._log_samples("val_samples", self.validation_step_outputs) + self.validation_step_outputs = [] + + def gradcam(self, x, **kwargs): + """Generate GradCAM visualization for the projection layer. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape (T, C, H, W) + **kwargs : dict + Additional arguments passed to model's gradcam method + + Returns + ------- + numpy.ndarray + GradCAM visualization + """ + # Store training mode and switch to eval + was_training = self.training + self.eval() + + # Store gradients and activations + self.gradients = None + self.activations = None + + # Register hooks + def save_gradients(grad): + self.gradients = grad + + def save_activations(module, input, output): + self.activations = output + + # Get the target layer (last conv layer of backbone) + target_layer = None + for module in self.model.backbone.modules(): + if isinstance(module, nn.Conv2d): + target_layer = module + + if target_layer is None: + raise RuntimeError("Could not find suitable layer for GradCAM") + + # Register hooks + h = target_layer.register_forward_hook(save_activations) + h_bp = target_layer.register_backward_hook( + lambda m, grad_in, grad_out: save_gradients(grad_out[0]) + ) + + try: + # Add batch dimension if needed + if x.ndim == 4: + x = x.unsqueeze(0) + + x = x.to(self.device) + + # Enable gradients for computation + with torch.enable_grad(): + x = x.requires_grad_(True) + + # Forward pass through model + output = self.model(x, mode="both") + if isinstance(output, tuple): + output = output[0] # Get classification output + + # Get predicted class (or use class 0 for binary case) + if output.ndim > 2: + # Handle spatial outputs by averaging + output = torch.mean(output, tuple(range(2, output.ndim))) + pred = output.argmax(dim=1) + + # Create one hot vector for backward pass + one_hot = torch.zeros_like(output, device=self.device) + one_hot[0][pred] = 1 + + # Clear gradients + self.zero_grad(set_to_none=False) + + # Backward pass + output.backward(gradient=one_hot) + + # Ensure we have valid gradients and activations + if self.gradients is None or self.activations is None: + raise RuntimeError( + "No gradients or activations available for GradCAM computation" + ) + + # Calculate weights and generate CAM + weights = torch.mean(self.gradients, dim=(2, 3)) + cam = torch.sum(weights[:, :, None, None] * self.activations, dim=1) + cam = torch.relu(cam) + + # Interpolate CAM to input size + cam = ( + torch.nn.functional.interpolate( + cam.unsqueeze(0), + size=x.shape[-2:], + mode="bilinear", + align_corners=False, + )[0, 0] + .cpu() + .detach() + .numpy() + ) + + return cam + + finally: + # Clean up + h.remove() + h_bp.remove() + # Restore original training mode + self.train(mode=was_training)
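
Usage sketch (not part of the patch): a minimal, untested example of how the pieces added in this diff fit together. The zarr paths, channel name, and numeric settings are hypothetical placeholders; only the classes and functions introduced above (TarrowDataModule, TarrowModule, GradCAMCallback, compute_embedding_distances, analyze_and_plot_distances) are assumed, with the signatures shown in the diff.

    from pathlib import Path

    from lightning.pytorch import Trainer

    from viscy.callbacks.gradcam import GradCAMCallback
    from viscy.data.tarrow import TarrowDataModule
    from viscy.representation.evaluation.distance import (
        analyze_and_plot_distances,
        compute_embedding_distances,
    )
    from viscy.representation.timearrow import TarrowModule

    # Time-arrow pretraining with periodic GradCAM logging.
    # The callback pulls samples from the datamodule's visual_dataloader().
    dm = TarrowDataModule(
        ome_zarr_path="plate.zarr",  # hypothetical dataset path
        channel_name="Phase3D",      # hypothetical channel name
        patch_size=(128, 128),
        batch_size=16,
    )
    model = TarrowModule(backbone="unet", n_frames=2, n_input_channels=1)
    trainer = Trainer(
        max_epochs=100,
        callbacks=[GradCAMCallback(every_n_epochs=10, max_samples=5, mode="overlay")],
    )
    trainer.fit(model, datamodule=dm)

    # Smoothness and dynamic-range evaluation of a saved embedding dataset.
    distributions = compute_embedding_distances(
        prediction_path=Path("embeddings.zarr"),  # hypothetical embedding store
        output_path=Path("distance_distributions.csv"),
        distance_metric="cosine",
    )
    metrics = analyze_and_plot_distances(
        distributions, Path("distance_distributions.pdf"), overwrite=True
    )
    print(metrics["dynamic_range"])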