[Phase 6] Evaluation and Explainability #7
Open
Phase 6: Evaluation and Explainability
Objectives
Implement comprehensive evaluation metrics and model-explainability methods so that results can be assessed for clinical relevance.
Tasks
- Compute per-class precision, recall and F1
- Generate confusion matrix visualizations
- Compute ROC curves and AUC scores (one-vs-rest)
- Implement SHAP analysis for classical models
- Implement Grad-CAM for CNN visualization
- Visualize LSTM attention weights
- Create comparison table across all models
- Generate publication-quality figures
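The task list has no reference code for the LSTM attention item. The sketch below is one possible way to plot it, assuming the LSTM model exposes one attention weight per timestep for a single beat (for example, the softmax weights of an attention layer over the LSTM outputs). The function name `plot_attention_weights` and its arguments are hypothetical, not an existing project API.

```python
import numpy as np
import matplotlib.pyplot as plt


def plot_attention_weights(signal, attention, title='LSTM attention', save_path=None):
    """Overlay per-timestep attention weights on a single ECG beat (hypothetical helper).

    Assumes `attention` holds one weight per LSTM timestep, aligned with `signal`.
    """
    signal = np.asarray(signal, dtype=float)
    attention = np.asarray(attention, dtype=float)
    if attention.shape != signal.shape:
        raise ValueError('expected one attention weight per signal sample')
    # Rescale to [0, 1] so the colour scale is comparable across beats
    span = attention.max() - attention.min()
    att_norm = (attention - attention.min()) / span if span > 0 else np.zeros_like(attention)

    fig, (ax_sig, ax_att) = plt.subplots(2, 1, figsize=(10, 5), sharex=True,
                                         gridspec_kw={'height_ratios': [3, 1]})
    t = np.arange(len(signal))
    # Top panel: the beat itself, with points coloured by normalized attention
    ax_sig.plot(t, signal, color='black', linewidth=1)
    sc = ax_sig.scatter(t, signal, c=att_norm, cmap='Reds', s=12)
    fig.colorbar(sc, ax=ax_sig, label='Attention (normalized)')
    ax_sig.set_ylabel('Amplitude')
    ax_sig.set_title(title)
    # Bottom panel: the raw attention weight at each timestep
    ax_att.bar(t, attention, width=1.0, color='tab:red')
    ax_att.set_xlabel('Timestep')
    ax_att.set_ylabel('Weight')
    if save_path:
        fig.savefig(save_path, dpi=150, bbox_inches='tight')
    return fig
```

A call such as `plot_attention_weights(beat, weights, save_path='outputs/attention_sample.png')` would write the figure into `outputs/`; the file name is only an example. Returning the figure instead of calling `plt.show()` keeps the helper usable in scripts that only save images.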
Files to Create/Modify
| File | Action | Description |
|---|---|---|
| src/evaluate.py | Create | Evaluation module |
| src/explainability.py | Create | SHAP and Grad-CAM |
| outputs/ | Create | Generated figures |
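For the per-class precision, recall and F1 task in `src/evaluate.py`, a minimal sketch could look like the following. It assumes integer labels 0 to 4 in the same order as `CLASS_NAMES` from the code reference; `per_class_metrics` and `print_report` are hypothetical helper names.

```python
from sklearn.metrics import classification_report, precision_recall_fscore_support

CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown']
LABELS = list(range(len(CLASS_NAMES)))  # assumes labels 0..4 match CLASS_NAMES order


def per_class_metrics(y_true, y_pred):
    """Return {class_name: {'precision', 'recall', 'f1', 'support'}} for every class."""
    # zero_division=0: a class that is never predicted gets precision 0, not a warning
    precision, recall, f1, support = precision_recall_fscore_support(
        y_true, y_pred, labels=LABELS, zero_division=0)
    return {
        name: {'precision': float(p), 'recall': float(r), 'f1': float(f), 'support': int(s)}
        for name, p, r, f, s in zip(CLASS_NAMES, precision, recall, f1, support)
    }


def print_report(y_true, y_pred):
    # Human-readable version of the same numbers, plus macro and weighted averages
    print(classification_report(y_true, y_pred, labels=LABELS,
                                target_names=CLASS_NAMES, digits=3, zero_division=0))
```

Passing `labels=` explicitly keeps every class in the output even when a rare class such as Fusion is missing from a test split, and the dictionary form is easy to collect across models for the comparison table.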
Code Reference
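```python
from sklearn.metrics import (classification_report, confusion_matrix,
                             roc_curve, auc)
import matplotlib.pyplot as plt
import seaborn as sns
import shap
import numpy as np

CLASS_NAMES = ['Normal', 'Supraventricular', 'Ventricular', 'Fusion', 'Unknown']


def plot_confusion_matrix(y_true, y_pred, save_path=None):
    # Fix the label order so the matrix is always 5x5, even if a class is absent
    cm = confusion_matrix(y_true, y_pred, labels=list(range(len(CLASS_NAMES))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=CLASS_NAMES, yticklabels=CLASS_NAMES)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


def plot_roc_curves(y_true, y_proba, save_path=None):
    # y_proba: array of shape (n_samples, n_classes) with predicted probabilities
    y_true = np.asarray(y_true)
    plt.figure(figsize=(10, 8))
    for i, class_name in enumerate(CLASS_NAMES):
        # One-vs-rest: class i is the positive class, all others are negative
        y_binary = (y_true == i).astype(int)
        fpr, tpr, _ = roc_curve(y_binary, y_proba[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.3f})')
    plt.plot([0, 1], [0, 1], 'k--')  # chance-level diagonal
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC Curves (One-vs-Rest)')
    plt.legend()
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()


def shap_analysis(model, X, feature_names, save_path=None, n_samples=1000):
    # TreeExplainer only supports tree-based models (random forest, XGBoost, ...)
    X_sub = X[:n_samples]  # subset for speed
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(X_sub)
    # Recent SHAP releases may return a single (samples, features, classes) array
    # for multi-class models; convert it to the per-class list summary_plot expects
    if isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
        shap_values = [shap_values[:, :, k] for k in range(shap_values.shape[2])]
    shap.summary_plot(shap_values, X_sub, feature_names=feature_names,
                      class_names=CLASS_NAMES, show=False)
    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
    plt.show()
```

Grad-CAM Implementation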
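```python
import torch
import torch.nn.functional as F

class GradCAM:
    """Grad-CAM for a 1D CNN whose target layer outputs (batch, channels, length)."""

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Forward hook stores the target layer's output on every forward pass
        self.handle = target_layer.register_forward_hook(self.save_activation)

    def save_activation(self, module, input, output):
        self.activations = output.detach()
        # A tensor hook on the output receives d(score)/d(output) during backward;
        # this replaces the deprecated register_backward_hook
        if output.requires_grad:
            output.register_hook(self.save_gradient)

    def save_gradient(self, grad):
        self.gradients = grad.detach()

    def generate(self, input_tensor, target_class):
        self.model.eval()
        output = self.model(input_tensor)  # (1, n_classes) logits
        self.model.zero_grad()
        output[0, target_class].backward()
        # Channel weights: gradients averaged over the time axis -> (1, C, 1)
        weights = self.gradients.mean(dim=-1, keepdim=True)
        cam = F.relu((weights * self.activations).sum(dim=1))  # (1, L')
        # Upsample to the input length so the map can be overlaid on the beat
        cam = F.interpolate(cam.unsqueeze(1), size=input_tensor.shape[-1],
                            mode='linear', align_corners=False).squeeze(1)
        cam = cam / (cam.max() + 1e-8)  # scale to [0, 1] without dividing by zero
        return cam.squeeze().cpu().numpy()
```

Definition of Done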
- Classification report generated for all models
- Confusion matrices saved as images
- ROC curves with AUC values plotted
- SHAP summary plot shows feature importance
- Grad-CAM visualizations for sample predictions
- Attention weights visualized for LSTM
- Results table comparing all models
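For the results table, one option is to derive every column from each model's predicted probabilities, so that the classical models, the CNN and the LSTM are compared on identical metrics. This sketch assumes each model provides an `(n_samples, n_classes)` probability array whose rows sum to 1 (as `predict_proba` or a softmax output would); `model_summary` and `comparison_table` are hypothetical names, and the Markdown output is a formatting choice, not a project requirement.

```python
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score


def model_summary(y_true, y_proba):
    """Accuracy, macro-F1 and macro one-vs-rest ROC-AUC from class probabilities."""
    y_true = np.asarray(y_true)
    y_proba = np.asarray(y_proba)
    y_pred = y_proba.argmax(axis=1)  # predicted class = highest probability
    labels = list(range(y_proba.shape[1]))
    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'macro_f1': f1_score(y_true, y_pred, labels=labels,
                             average='macro', zero_division=0),
        # multi_class='ovr' averages the per-class one-vs-rest AUCs
        'macro_auc_ovr': roc_auc_score(y_true, y_proba, multi_class='ovr',
                                       average='macro', labels=labels),
    }


def comparison_table(results):
    """Render {model_name: (y_true, y_proba)} as a Markdown table, best macro-F1 first."""
    rows = [(name, model_summary(y, p)) for name, (y, p) in results.items()]
    rows.sort(key=lambda r: r[1]['macro_f1'], reverse=True)
    lines = ['| Model | Accuracy | Macro F1 | Macro AUC (OvR) |',
             '|---|---|---|---|']
    for name, m in rows:
        lines.append(f"| {name} | {m['accuracy']:.3f} | {m['macro_f1']:.3f} "
                     f"| {m['macro_auc_ovr']:.3f} |")
    return '\n'.join(lines)
```

Sorting by macro-F1 rather than accuracy is a deliberate choice: when one class dominates the test set, accuracy can look high even if the minority classes are poorly detected. Note that the one-vs-rest AUC is undefined for a class that never appears in `y_true`, so every class must be present in the evaluation split.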
Technical Notes
For junior developers:
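- SHAP is slow on large datasets, so explain a subset (the reference code uses the first 1,000 samples) and check that the subset still contains the minority classes
- Grad-CAM relies on hooks to capture the target layer's activations and gradients, so review how PyTorch autograd and hooks work first
- Multi-class ROC-AUC uses a one-vs-rest strategy: each class is scored against all the others, giving one ROC curve and one AUC per class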
- Healthcare stakeholders value explainability highly
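To build intuition for what the Grad-CAM hooks capture, the experiment below attaches a forward hook to a convolution layer, registers a tensor hook on that layer's output, and runs one backward pass for a single class score. `TinyECGNet` and `capture_activation_and_grad` are made-up names for illustration; the toy model is not the project's CNN.

```python
import torch
import torch.nn as nn


class TinyECGNet(nn.Module):
    """Hypothetical toy 1D CNN: (batch, 1, length) -> (batch, n_classes) logits."""

    def __init__(self, n_classes=5):
        super().__init__()
        self.conv = nn.Conv1d(1, 8, kernel_size=5, padding=2)
        self.pool = nn.AdaptiveAvgPool1d(1)
        self.fc = nn.Linear(8, n_classes)

    def forward(self, x):
        a = torch.relu(self.conv(x))  # (batch, 8, length)
        return self.fc(self.pool(a).squeeze(-1))


def capture_activation_and_grad(model, layer, x, target_class):
    """Run one forward/backward pass; return the layer's output and its gradient."""
    store = {}

    def save_grad(grad):
        store['grad'] = grad.detach()

    def forward_hook(module, inputs, output):
        store['activation'] = output.detach()
        # Autograd calls save_grad with d(score)/d(output) during backward
        output.register_hook(save_grad)

    handle = layer.register_forward_hook(forward_hook)
    try:
        model.zero_grad()
        logits = model(x)
        logits[0, target_class].backward()  # gradient of one class score only
    finally:
        handle.remove()  # module hooks stay active until explicitly removed
    return store['activation'], store['grad']
```

For this particular toy model the captured gradient can be checked by hand: average pooling spreads the class score evenly over the 100 timesteps of a `torch.randn(1, 1, 100)` input, so each position where the ReLU is active receives `fc.weight[target_class, channel] / 100` and inactive positions receive zero. Averaging that gradient over time gives the per-channel weights used in the Grad-CAM formula.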