Skip to content

Commit 1dc9521

Browse files
committed
[feat] plot predictions
1 parent b0497c5 commit 1dc9521

File tree

2 files changed

+465
-45
lines changed

2 files changed

+465
-45
lines changed

scripts/plot_activation.ipynb

Lines changed: 143 additions & 38 deletions
Large diffs are not rendered by default.

src/brain_decoding/utils/plot.py

Lines changed: 322 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,300 @@
11
import os
2-
from typing import List
2+
from time import sleep
3+
from typing import Any, List
34

45
import matplotlib.pyplot as plt
56
import numpy as np
7+
import pandas as pd
68
import seaborn as sb
9+
from matplotlib.patches import Patch
710

811
from brain_decoding.dataloader.patients import Events
912

13+
# from brain_decoding.dataloader.save_clusterless import SECONDS_PER_HOUR
1014

11-
def prediction_curve(predictions: np.ndarray[float], labels: List[str], save_file_name: str) -> None:
15+
# Sampling rate (Hz) of the prediction series; sample indices are divided by it to get seconds.
PREDICTION_FS = 4
# Sampling rate (Hz) of the sleep-score series, i.e. one score row every 30 seconds.
SLEEP_SCORE_FS = 1 / 30
# Constant shift (in prediction samples) added to every start index computed by read_sleep_score.
SLEEP_SCORE_OFFSET = 0
# Conversion factor used to express durations in hours on plot axes and in log output.
SECONDS_PER_HOUR = 3600
19+
20+
21+
def prediction_curve(
    predictions: np.ndarray, sleep_score: pd.DataFrame, labels: List[str], save_file_name: str
) -> None:
    """
    Plot one prediction curve per column, shaded by sleep stage, and save the figure.

    Each column of `predictions` gets its own subplot (shared x-axis, time in hours). A dashed
    grey line marks the column mean, and the background of every subplot is coloured by the
    sleep stage that is active at that time. A legend for the stage colours is attached to the
    last subplot.

    Parameters:
    - predictions (np.ndarray): n by m array of predictions sampled at PREDICTION_FS Hz.
    - sleep_score (pd.DataFrame): one row per stage segment, with a "Score" column (stage name)
      and a "start_index" column (index into `predictions` at which the segment starts).
      A segment ends where the next one starts; the last one ends with the predictions.
    - labels (List[str]): List of labels (subplot titles), one per prediction column.
    - save_file_name (str): The file path to save the plot.

    Returns:
    - None: The function saves the figure to the specified output file and shows it.
    """
    n_samples, n_curves = predictions.shape
    palette = sb.color_palette("husl", n_colors=n_curves)

    # All subplots share the same y-range so the curves are visually comparable.
    y_min = np.min(predictions)
    y_max = np.max(predictions)

    # Assign a unique color for each unique sleep stage
    unique_stages = sleep_score["Score"].unique()
    stage_colors = sb.color_palette("Set2", len(unique_stages))
    stage_color_map = dict(zip(unique_stages, stage_colors))  # Map sleep stages to colors

    # The time axis and the stage spans are identical for every subplot: compute them once.
    total_hours = n_samples / PREDICTION_FS / SECONDS_PER_HOUR
    time = np.arange(n_samples) / PREDICTION_FS / SECONDS_PER_HOUR
    starts = [value / PREDICTION_FS / SECONDS_PER_HOUR for value in sleep_score["start_index"]]
    ends = starts[1:] + [total_hours]
    stage_spans = [
        (start, end, stage_color_map[stage])
        for start, end, stage in zip(starts, ends, sleep_score["Score"])
        if 0 <= start < total_hours  # skip segments that begin outside the plotted range
    ]

    # squeeze=False always yields a 2-D array of Axes, so a single prediction column
    # works the same way as several.
    fig, axes = plt.subplots(nrows=n_curves, ncols=1, figsize=(20, 12), sharex=True, squeeze=False)
    axes = axes[:, 0]

    # Loop through each prediction curve
    for i, ax in enumerate(axes):
        column = predictions[:, i]

        # Plot the prediction curve with time in hours
        sb.lineplot(
            x=time,
            y=column,
            ax=ax,
            color=palette[i],
            linewidth=1.5,
        )
        # Plot the mean as a dashed line spanning the same time range
        sb.lineplot(
            x=time,
            y=np.full(n_samples, np.mean(column)),
            ax=ax,
            color="#808080",
            linestyle="--",
        )

        # Add background color for every sleep-stage segment
        for start, end, color in stage_spans:
            ax.axvspan(xmin=start, xmax=end, color=color, alpha=0.3)

        # Set y-axis limits and title
        ax.set_ylim([y_min, y_max])
        ax.set_title(labels[i], fontsize=14)

    # Create custom legend for the background colors (attached to the last subplot)
    legend_elements = [Patch(facecolor=stage_color_map[stage], label=stage, alpha=0.3) for stage in unique_stages]
    axes[-1].legend(handles=legend_elements, loc="upper right", title="Sleep Stages")

    # Set a common y-label for the figure; the x-label goes on the bottom subplot
    fig.supylabel("Activation", fontsize=14)
    axes[-1].set_xlabel("Time (hours)", fontsize=14)
    plt.tight_layout()

    # Save the figure
    plt.savefig(save_file_name)
    plt.show()
100+
101+
102+
def stage_box_plot(predictions: np.ndarray, sleep_score: pd.DataFrame, labels: List[str], save_file_name: str) -> None:
    """
    Plot, per label, a notched box plot of the predictions in each long sleep-stage segment,
    overlaid with a swarm plot, and save the figure.

    Only segments longer than 600 seconds are plotted, and inside a segment only prediction
    values above 0.5 are kept. The swarm overlay is limited to 200 randomly sampled points per
    segment (fixed seed) for performance. Each segment's x tick shows its row number in
    `sleep_score` and its duration in seconds; boxes are coloured by sleep stage.

    Parameters:
    - predictions (np.ndarray): n by m array of predictions sampled at PREDICTION_FS Hz.
    - sleep_score (pd.DataFrame): one row per stage segment, with a "Score" column (stage name)
      and a "start_index" column (index into `predictions` at which the segment starts).
    - labels (List[str]): List of labels for each prediction column.
    - save_file_name (str): The file path to save the plot.

    Returns:
    - None: The function saves the figure with subplots to the specified output file.

    Raises:
    - ValueError: if no segment is longer than 600 seconds, so there is nothing to plot.
    """
    n_samples, n_labels = predictions.shape
    min_segment_samples = 600 * PREDICTION_FS

    # Segment boundaries do not depend on the label, so resolve them once up front.
    starts = [int(value) for value in sleep_score["start_index"]]
    ends = starts[1:] + [n_samples]
    stage_names = sleep_score["Score"].tolist()
    segments = [
        (j, start, end, stage_names[j])
        for j, (start, end) in enumerate(zip(starts, ends))
        if 0 <= start < n_samples and end - start > min_segment_samples
    ]
    if not segments:
        # Fail before a figure is created, with a message that names the actual cause.
        raise ValueError("No sleep-stage segment is longer than 600 seconds; nothing to plot.")

    # Create subplots for each label (column of predictions); squeeze=False keeps `axes`
    # indexable even when there is a single label.
    fig, axes = plt.subplots(n_labels, 1, figsize=(12, 3 * n_labels), sharex=True, squeeze=False)
    axes = axes[:, 0]

    # Loop through each label (column of predictions)
    for i, label in enumerate(labels):
        show_legend = i == 0  # one legend for the whole figure is enough

        frames = []
        for j, start, end, stage_name in segments:
            stage_data = predictions[start:end, i]
            stage_data = stage_data[stage_data > 0.5]  # Filter values greater than 0.5
            # Calculate stage length (duration in seconds)
            stage_length = (end - start) / PREDICTION_FS
            stage_label = f"Stage: {j} ({stage_length:.1f} sec)"

            frames.append(
                pd.DataFrame(
                    {
                        "Stage": [stage_label] * len(stage_data),
                        "Value(>.5)": stage_data,
                        "Label": [label] * len(stage_data),
                        "Stage Label": [stage_name] * len(stage_data),
                    }
                )
            )

        combined_df = pd.concat(frames, axis=0)
        # Sample a maximum of 200 points per stage for the swarmplot
        combined_df_sample = (
            combined_df.groupby("Stage")
            .apply(lambda x: x.sample(n=min(len(x), 200), random_state=42))
            .reset_index(drop=True)
        )

        # Create a color palette for the stages
        unique_stages = combined_df["Stage Label"].unique()
        palette = sb.color_palette("Set2", len(unique_stages))
        stage_color_map = dict(zip(unique_stages, palette))

        # Plot the box plot for this label on its respective axis
        ax = sb.boxplot(
            x="Stage",
            y="Value(>.5)",
            data=combined_df,
            hue="Stage Label",
            palette=stage_color_map,
            linewidth=1.5,
            color="none",
            width=0.7,
            notch=True,
            ax=axes[i],
            dodge=False,
            legend=False,
        )
        # Overlay the swarmplot with limited points
        ax = sb.swarmplot(
            x="Stage",
            y="Value(>.5)",
            data=combined_df_sample,
            hue="Stage Label",
            palette=stage_color_map,
            size=2,
            dodge=False,
            legend=show_legend,
            ax=axes[i],
        )

        if show_legend:
            ax.legend(
                borderaxespad=0.0,
                loc="right",
                columnspacing=1.2,
                frameon=False,
                markerscale=5,
                handlelength=0.1,
                prop={"size": 10},
                title="",
                bbox_to_anchor=(1, 1.1),
                ncol=2,
            )

        # Turn filled boxes into outlined boxes: move each face colour to the edge.
        # Distinct loop names are used so the label index `i` is never shadowed.
        for box_idx, box in enumerate(ax.patches):
            col = box.get_facecolor()
            box.set_edgecolor(col)
            box.set_facecolor("None")

            # Each box has 6 associated Line2D objects (to make the whiskers, fliers, etc.)
            # Loop over them here, and use the same colour as above
            for line_idx in range(box_idx * 6, box_idx * 6 + 6):
                line = ax.lines[line_idx]
                line.set_color(col)
                line.set_mfc(col)
                line.set_mec(col)

        # Hide the right and top spines
        ax.spines["right"].set_visible(False)
        ax.spines["top"].set_visible(False)

        # Use the label as the y-axis caption of its subplot
        ax.set_ylabel(label, fontsize=12)
        ax.tick_params(axis="x", rotation=45)

    plt.tight_layout()

    # Save the figure
    plt.savefig(save_file_name)
    plt.show()
243+
34244

245+
def correlation_heatmap(data: np.ndarray, column_labels: List[str], output_filename: str) -> None:
    """
    Calculate the correlation among the columns of the data array and plot a heatmap with the
    distribution of correlation values in a subplot.

    The left panel is the annotated correlation matrix (colour scale fixed to [-1, 1]); the
    right panel is a 10-bin histogram of the pairwise correlations, i.e. the entries strictly
    above the diagonal.

    Parameters:
    - data (np.ndarray): n by m array where n is the number of samples and m is the number of columns.
    - column_labels (List[str]): A list of labels for each column.
    - output_filename (str): The file path to save the heatmap image.

    Returns:
    - None: The function saves the figure to the specified output file.
    """
    # Calculate the correlation matrix. For a single column np.corrcoef returns a scalar,
    # so promote it to a 1x1 matrix to keep the code below working.
    corr_matrix = np.atleast_2d(np.corrcoef(data, rowvar=False))
    v_min, v_max = -1, 1

    # Keep only the upper triangle without the diagonal: each pair once, no self-correlation
    corr_values = corr_matrix[np.triu_indices_from(corr_matrix, k=1)]

    # Create a figure with 2 subplots: 1 for the heatmap, 1 for the histogram
    fig, (ax_heatmap, ax_hist) = plt.subplots(1, 2, figsize=(14, 8), gridspec_kw={"width_ratios": [2.5, 1.5]})

    # Plot the heatmap on the first subplot
    sb.heatmap(
        corr_matrix,
        annot=True,
        fmt=".3f",
        cmap="coolwarm",
        xticklabels=column_labels,
        center=0,
        vmin=v_min,
        vmax=v_max,
        cbar=False,
        annot_kws={"size": 12},
        yticklabels=column_labels,
        ax=ax_heatmap,
    )
    ax_heatmap.set_title("Correlation Heatmap")

    ax_heatmap.set_xticklabels(ax_heatmap.get_xticklabels(), rotation=45, horizontalalignment="right", fontsize=12)
    ax_heatmap.set_yticklabels(ax_heatmap.get_yticklabels(), rotation=0, fontsize=12)

    # Plot the distribution of correlation values on the second subplot
    ax_hist.hist(corr_values, bins=10, color="gray", edgecolor="black")
    ax_hist.set_title("Correlation Value Distribution")
    ax_hist.set_xlabel("Correlation")
    ax_hist.set_ylabel("Frequency")

    # Save the figure
    plt.tight_layout()
    plt.savefig(output_filename, bbox_inches="tight")
    plt.show()
37298

38299

39300
def prediction_heatmap(predictions: np.ndarray[float], events: Events, title: str, file_path: str):
@@ -69,3 +330,57 @@ def prediction_heatmap(predictions: np.ndarray[float], events: Events, title: st
69330
plt.tight_layout()
70331
plt.savefig(file_path)
71332
plt.show()
333+
334+
335+
def smooth_data(data: np.ndarray[float], window_size: int = 5) -> np.ndarray[float, Any]:
    """
    Smooth a 1-D signal with an unweighted moving average.

    Parameters:
    - data (np.ndarray): 1-D array of samples.
    - window_size (int): number of consecutive samples averaged for each output value.

    Returns:
    - np.ndarray: array of length len(data) - window_size + 1; only windows that lie fully
      inside `data` are produced, so the output is shorter than the input.

    Raises:
    - ValueError: if `window_size` is smaller than 1 or larger than the number of samples.
    """
    n_samples = np.asarray(data).size
    # np.convolve silently swaps its operands when the kernel is longer than the signal,
    # which would return a wrongly sized result instead of failing; reject that case here.
    if window_size < 1 or window_size > n_samples:
        raise ValueError(f"window_size must be between 1 and {n_samples} (number of samples), got {window_size}")

    kernel = np.ones(window_size) / window_size
    return np.convolve(data, kernel, mode="valid")
337+
338+
339+
def smooth_columns(data: np.ndarray[float], window_size: int = 5) -> np.ndarray[float, Any]:
    """
    Apply `smooth_data` (moving average) to every column of a 2-D array.

    Parameters:
    - data (np.ndarray): n by m array; each column is smoothed independently.
    - window_size (int): number of consecutive rows averaged for each output row.

    Returns:
    - np.ndarray: (n - window_size + 1) by m array of smoothed columns.

    Raises:
    - ValueError: if `data` is not 2-D, or `window_size` is smaller than 1 or larger than
      the number of rows.
    """
    if data.ndim != 2:
        raise ValueError(f"data must be a 2-D array, got {data.ndim} dimension(s)")

    n_rows, n_cols = data.shape
    # Validate up front so a bad window is reported as such, instead of surfacing as a
    # "negative dimensions" or broadcasting error further down.
    if window_size < 1 or window_size > n_rows:
        raise ValueError(f"window_size must be between 1 and {n_rows} (number of rows), got {window_size}")

    # A "valid" moving average drops window_size - 1 rows.
    smoothed_data = np.zeros((n_rows - window_size + 1, n_cols))

    # Smoothing each column
    for col in range(n_cols):
        smoothed_data[:, col] = smooth_data(data[:, col], window_size=window_size)

    return smoothed_data
348+
349+
350+
def combine_continuous_scores(df: pd.DataFrame) -> pd.DataFrame:
    """
    Combine rows with continuous same values in the 'Score' column and keep the first value in the 'start_index' column.

    The input DataFrame is left untouched; the run identifier used for grouping is kept in
    a separate Series instead of being written into `df` as a temporary column.

    Parameters:
    - df (pd.DataFrame): A DataFrame with 'Score' and 'start_index' columns.

    Returns:
    - pd.DataFrame: A new DataFrame with combined rows, keeping the first 'start_index' value for each group.
    """
    # A new run starts wherever 'Score' differs from the previous row; the cumulative sum
    # of those change points gives every run of equal scores its own id.
    run_id = (df["Score"] != df["Score"].shift()).cumsum().rename("group")

    # Collapse each run to its first row's 'Score' and 'start_index'
    combined_df = df.groupby(run_id).agg({"Score": "first", "start_index": "first"}).reset_index(drop=True)

    return combined_df[["Score", "start_index"]]
371+
372+
373+
def read_sleep_score(filename: str) -> pd.DataFrame:
    """
    Load a sleep-score CSV and convert it into stage segments.

    The file is read with its first line as header. Row k is assigned the prediction-sample
    index k * PREDICTION_FS / SLEEP_SCORE_FS + SLEEP_SCORE_OFFSET (truncated to an integer),
    and consecutive rows with the same 'Score' are then merged by combine_continuous_scores.
    The table shape before and after merging is printed.

    Parameters:
    - filename (str): path of the CSV file; it must provide a 'Score' column.

    Returns:
    - pd.DataFrame: one row per run of equal scores, with 'Score' and 'start_index' columns.
    """
    scores = pd.read_csv(filename, header=0)
    n_rows = scores.shape[0]
    duration_hours = n_rows / SLEEP_SCORE_FS / SECONDS_PER_HOUR
    print(f"shape of sleep_score: {scores.shape}, " f"duration: {duration_hours} hours")

    # Same arithmetic as a per-row int(...) conversion, evaluated for all rows at once;
    # astype(int) truncates exactly like int() does.
    row_numbers = np.arange(n_rows)
    start_positions = row_numbers * PREDICTION_FS / SLEEP_SCORE_FS + SLEEP_SCORE_OFFSET
    scores["start_index"] = start_positions.astype(int).tolist()

    merged = combine_continuous_scores(scores)

    print(f"shape of sleep_score after merge: {merged.shape}")

    return merged

0 commit comments

Comments
 (0)