
[Phase 6] Evaluation and Explainability #7

@Sakeeb91


Phase 6: Evaluation and Explainability

Parent: #1
Depends on: #5, #6

Objectives

Implement comprehensive evaluation metrics and model explainability methods so that model behavior can be assessed for clinical relevance.

Tasks

  • Implement per-class precision, recall, F1 computation
  • Generate confusion matrix visualizations
  • Compute one-vs-rest ROC curves and AUC scores
  • Implement SHAP analysis for classical models
  • Implement Grad-CAM for CNN visualization
  • Visualize LSTM attention weights
  • Create comparison table across all models
  • Generate publication-quality figures
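A minimal sketch of the per-class precision, recall, and F1 computation. It assumes integer labels 0 to 4 that follow the `CLASS_NAMES` order from the Code Reference section and uses scikit-learn's `precision_recall_fscore_support`. The helper name `per_class_metrics` is hypothetical.

```python
import numpy as np
from sklearn.metrics import precision_recall_fscore_support

# Assumed label order: label i corresponds to CLASS_NAMES[i].
CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown']


def per_class_metrics(y_true, y_pred):
    """Hypothetical helper: {class_name: {'precision', 'recall', 'f1', 'support'}}."""
    labels = list(range(len(CLASS_NAMES)))
    # Passing `labels` keeps one entry per class even if a class never appears;
    # zero_division=0 reports 0.0 instead of warning when a class is never predicted.
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=labels, zero_division=0)
    return {
        name: {'precision': float(p), 'recall': float(r),
               'f1': float(f), 'support': int(s)}
        for name, p, r, f, s in zip(CLASS_NAMES, precision, recall, f1, support)
    }


if __name__ == '__main__':
    y_true = np.array([0, 0, 0, 1, 1, 2, 2, 3, 4, 4])
    y_pred = np.array([0, 0, 1, 1, 0, 2, 2, 0, 4, 4])
    for name, m in per_class_metrics(y_true, y_pred).items():
        print(f"{name:17s} P={m['precision']:.2f} R={m['recall']:.2f} "
              f"F1={m['f1']:.2f} n={m['support']}")
```

Reporting per-class values rather than only accuracy matters here because the Normal class dominates ECG datasets, so a high overall accuracy can hide poor recall on rare classes such as Fusion.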

Files to Create/Modify

File                    Action  Description
src/evaluate.py         Create  Evaluation module
src/explainability.py   Create  SHAP and Grad-CAM
outputs/                Create  Generated figures

Code Reference

from sklearn.metrics import (classification_report, confusion_matrix,
                             roc_curve, auc)
import matplotlib.pyplot as plt
import seaborn as sns
import shap
import numpy as np

# AAMI heartbeat classes (N, S, V, F, Q); label i maps to CLASS_NAMES[i]
CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown']

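def plot_confusion_matrix(y_true, y_pred, save_path=None):
    # Fix the label order so the matrix is always len(CLASS_NAMES) x len(CLASS_NAMES)
    # and lines up with the tick labels, even when a class is absent from the data
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASS_NAMES))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()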

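def plot_roc_curves(y_true, y_proba, save_path=None):
    # y_true: integer labels, shape (n_samples,)
    # y_proba: predicted class probabilities, shape (n_samples, n_classes)
    y_true = np.asarray(y_true)    # so (y_true == i) works for lists too
    y_proba = np.asarray(y_proba)
    plt.figure(figsize=(10, 8))
    for i, class_name in enumerate(CLASS_NAMES):
        # One-vs-rest: class i is the positive class, all others are negative
        y_binary = (y_true == i).astype(int)
        fpr, tpr, _ = roc_curve(y_binary, y_proba[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.3f})')

    plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves (One-vs-Rest)')
    plt.legend(loc='lower right')
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()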

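def shap_analysis(model, X, feature_names, save_path=None):
    # TreeExplainer only supports tree-based models (e.g. random forest, XGBoost)
    explainer = shap.TreeExplainer(model)
    X_subset = X[:1000]  # SHAP is slow; explain a subset for speed
    # For multi-class models, depending on the SHAP version, this is either a
    # list of per-class arrays or one array with a trailing class dimension
    shap_values = explainer.shap_values(X_subset)
    shap.summary_plot(shap_values, X_subset, feature_names=feature_names,
                      show=False)
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()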
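`shap_analysis` explains the first 1,000 rows. If the data are ordered by record or by class, that slice can miss rare classes such as Fusion entirely. A minimal sketch of a class-balanced subset, assuming integer labels `y` aligned with the rows of `X`; `stratified_subset` and `n_per_class` are hypothetical names, and the returned indices would be used as `X[idx]` in place of `X[:1000]`.

```python
import numpy as np


def stratified_subset(y, n_per_class=200, seed=0):
    """Hypothetical helper: row indices with up to `n_per_class` samples per class."""
    y = np.asarray(y)
    rng = np.random.default_rng(seed)
    picked = []
    for label in np.unique(y):
        class_idx = np.flatnonzero(y == label)
        # Rare classes contribute every sample they have.
        k = min(n_per_class, class_idx.size)
        picked.append(rng.choice(class_idx, size=k, replace=False))
    # Sort so the subset keeps the original row order.
    return np.sort(np.concatenate(picked))


if __name__ == '__main__':
    y = np.array([0] * 900 + [1] * 60 + [2] * 30 + [3] * 5 + [4] * 5)
    idx = stratified_subset(y, n_per_class=50)
    print(len(idx), np.bincount(y[idx]))
```

With the fixed `seed`, the same rows are explained on every run, which keeps SHAP figures reproducible between model versions.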

Grad-CAM Implementation

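import torch
import torch.nn.functional as F


class GradCAM:
    """Grad-CAM for a 1-D CNN whose target layer outputs (batch, channels, length)."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Capture activations on the forward pass; the gradient hook is attached
        # to the activation tensor itself (register_backward_hook is deprecated)
        target_layer.register_forward_hook(self.save_activation)

    def save_activation(self, module, input, output):
        self.activations = output.detach()
        if output.requires_grad:  # skipped when running under torch.no_grad()
            output.register_hook(self.save_gradient)

    def save_gradient(self, grad):
        self.gradients = grad.detach()

    def generate(self, input_tensor, target_class):
        self.model.eval()  # disable dropout so the map is deterministic
        self.model.zero_grad()
        output = self.model(input_tensor)  # expects a batch of size 1
        output[0, target_class].backward()

        # Channel weights: gradients averaged over the time axis
        weights = self.gradients.mean(dim=-1, keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1))  # (1, length)
        # Upsample to the input length so the map aligns with the ECG signal
        cam = F.interpolate(cam.unsqueeze(1), size=input_tensor.shape[-1],
                            mode='linear', align_corners=False).squeeze(1)
        cam = cam / (cam.max() + 1e-8)  # avoid division by zero
        return cam.squeeze().cpu().numpy()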
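The tasks also call for visualizing LSTM attention weights, which the reference code does not cover. A minimal matplotlib sketch, assuming the LSTM model can return one attention weight per time step for a single beat (how the model exposes those weights is an assumption here). `plot_attention` is a hypothetical helper that draws the beat over a heat strip shaded by the normalized weight at each time step.

```python
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so the sketch runs headless
import matplotlib.pyplot as plt
import numpy as np


def plot_attention(signal, attention, title='LSTM attention', save_path=None):
    """Hypothetical helper: overlay per-time-step attention weights on a 1-D beat."""
    signal = np.asarray(signal, dtype=float)
    attention = np.asarray(attention, dtype=float)
    if signal.shape != attention.shape:
        raise ValueError('expected one attention weight per time step')
    # Min-max scale to [0, 1] so shading is comparable across beats.
    span = attention.max() - attention.min()
    weights = (attention - attention.min()) / span if span > 0 else np.zeros_like(attention)

    fig, ax = plt.subplots(figsize=(10, 3))
    # Background heat strip: one column per time step, colored by weight.
    ax.imshow(weights[np.newaxis, :], aspect='auto', cmap='Reds', alpha=0.6,
              extent=(0, signal.size, signal.min(), signal.max()))
    ax.plot(np.arange(signal.size), signal, color='black', linewidth=1)
    ax.set_xlabel('Time step')
    ax.set_ylabel('Amplitude')
    ax.set_title(title)
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    return fig, weights


if __name__ == '__main__':
    t = np.linspace(0, 1, 200)
    beat = np.sin(2 * np.pi * 3 * t)                 # synthetic stand-in for an ECG beat
    attn = np.exp(-((t - 0.5) ** 2) / 0.01)         # synthetic attention peak
    fig, _ = plot_attention(beat, attn)
    plt.close(fig)
```

Min-max scaling is a presentation choice: it highlights where attention concentrates within a beat, but hides absolute differences between beats.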

Definition of Done

  • Classification report generated for all models
  • Confusion matrices saved as images
  • ROC curves with AUC values plotted
  • SHAP summary plot shows feature importance
  • Grad-CAM visualizations for sample predictions
  • Attention weights visualized for LSTM
  • Results table comparing all models
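For the results table, a minimal sketch that assumes each model's test labels and predicted class probabilities have already been collected into a dict. `compare_models`, `format_table`, and the chosen columns (accuracy, macro F1, macro one-vs-rest ROC-AUC) are illustrative, not fixed by this issue. Macro averages weight every class equally, so rare arrhythmia classes are not drowned out by the Normal class.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def compare_models(results):
    """Hypothetical helper. results: {name: (y_true, y_proba)}, y_proba shape (n, n_classes).

    Returns one row per model, sorted by macro F1 (best first).
    """
    rows = []
    for name, (y_true, y_proba) in results.items():
        y_true = np.asarray(y_true)
        y_proba = np.asarray(y_proba)
        y_pred = y_proba.argmax(axis=1)
        rows.append({
            'model': name,
            'accuracy': accuracy_score(y_true, y_pred),
            'macro_f1': f1_score(y_true, y_pred, average='macro'),
            # Needs every class present in y_true and probability rows summing to 1.
            'macro_auc': roc_auc_score(y_true, y_proba, multi_class='ovr', average='macro'),
        })
    return sorted(rows, key=lambda r: r['macro_f1'], reverse=True)


def format_table(rows):
    """Render the rows as a fixed-width plain-text table."""
    header = f"{'Model':<20}{'Accuracy':>10}{'Macro F1':>10}{'Macro AUC':>11}"
    lines = [header, '-' * len(header)]
    for r in rows:
        lines.append(f"{r['model']:<20}{r['accuracy']:>10.3f}"
                     f"{r['macro_f1']:>10.3f}{r['macro_auc']:>11.3f}")
    return '\n'.join(lines)


if __name__ == '__main__':
    y = np.array([0, 1, 2, 0, 1, 2])
    perfect = np.eye(3)[y]                           # one-hot "probabilities"
    noisy = np.array([[0.6, 0.3, 0.1], [0.5, 0.4, 0.1], [0.2, 0.3, 0.5],
                      [0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.4, 0.3]])
    print(format_table(compare_models({'noisy': (y, noisy), 'perfect': (y, perfect)})))
```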

Technical Notes

For junior developers:

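  • SHAP is slow on large datasets; compute it on a subset (the reference code uses the first 1,000 samples)
  • Grad-CAM relies on hooks to capture layer activations and gradients; understand PyTorch autograd before modifying it
  • Multi-class ROC-AUC uses a one-vs-rest strategy: each class is scored against all other classes combined
  • Healthcare stakeholders value explainability highly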
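To make the one-vs-rest note concrete: each class is binarized against all others, a binary AUC is computed per class, and the macro score is their unweighted mean. The short sketch below, using synthetic data and a hypothetical helper `ovr_auc_by_hand`, checks that averaging the per-class AUCs (the values `plot_roc_curves` shows in its legend) matches scikit-learn's `roc_auc_score(..., multi_class='ovr', average='macro')`.

```python
import numpy as np
from sklearn.metrics import roc_auc_score


def ovr_auc_by_hand(y_true, y_proba):
    """Hypothetical helper: macro one-vs-rest AUC via per-class binarization."""
    y_true = np.asarray(y_true)
    per_class = [roc_auc_score((y_true == k).astype(int), y_proba[:, k])
                 for k in range(y_proba.shape[1])]
    return float(np.mean(per_class)), per_class


if __name__ == '__main__':
    rng = np.random.default_rng(0)
    y_true = np.repeat(np.arange(5), 40)              # 5 classes, 40 samples each
    logits = rng.normal(size=(200, 5))
    logits[np.arange(200), y_true] += 1.5             # make the true class more likely
    y_proba = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)  # softmax
    manual, _ = ovr_auc_by_hand(y_true, y_proba)
    library = roc_auc_score(y_true, y_proba, multi_class='ovr', average='macro')
    print(f'manual={manual:.4f}  sklearn={library:.4f}')
```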
