-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path plot_features.py
More file actions
34 lines (29 loc) · 1.02 KB
/
plot_features.py
File metadata and controls
34 lines (29 loc) · 1.02 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
def plot_feature(
    xs,
    sample_rate,
    feats,
    xlabel="time",
    ylabel="",
):
    """Plot each waveform in a batch with its frame-level feature overlaid.

    One subplot row is drawn per batch item. The waveform is drawn
    semi-transparently; the feature curve is stretched over the same time
    span so the two line up visually.

    Args:
        xs: Batch of waveforms, indexed as ``xs[b]``; ``xs.shape[1]`` is
            treated as the number of samples (presumably shape
            ``(batch, n_samples)`` -- confirm against callers).
        sample_rate: Samples per second, used to convert ``xs.shape[1]``
            into a duration for the time axis.
        feats: Per-item feature tracks, presumably shape
            ``(batch, n_frames)``. Each row is scaled by its own peak
            magnitude so it fits the fixed ``[-1.1, 1.1]`` y-range. The
            caller's tensor is NOT modified.
        xlabel: Label for every subplot's x axis.
        ylabel: Label for every subplot's y axis.
    """
    # Detach / move to CPU so matplotlib can consume the data even for
    # GPU tensors or tensors that require grad; as_tensor is a no-op for
    # tensors that already satisfy this.
    xs = torch.as_tensor(xs).detach().cpu()
    feats = torch.as_tensor(feats).detach().cpu()

    batch = xs.shape[0]
    n_samples = xs.shape[1]
    duration = n_samples / sample_rate

    # Normalise each row by its peak magnitude (not its signed max, which
    # flips all-negative rows and lets large negative values escape the
    # symmetric y-limits). All-zero rows are left as zeros instead of NaN.
    # Out-of-place division so the caller's `feats` is not mutated.
    peak = feats.abs().amax(dim=1, keepdim=True)
    peak = torch.where(peak > 0, peak, torch.ones_like(peak))
    feats = feats / peak

    # Time axes are identical for every row; build them once.
    wave_t = torch.linspace(0, duration, n_samples)
    feat_t = torch.linspace(0, duration, feats.shape[1])

    fig = plt.figure(figsize=(12, 6))
    gs = GridSpec(nrows=batch, ncols=1)
    axs = [fig.add_subplot(gs[b, 0]) for b in range(batch)]
    for ax, x, feat in zip(axs, xs, feats):
        ax.plot(wave_t, x, alpha=0.5)
        ax.plot(feat_t, feat)
        ax.tick_params(labelsize="xx-large")
        ax.set_xlabel(xlabel)
        ax.set_ylabel(ylabel)
        ax.set_xlim(0.0, duration)
        ax.set_ylim(-1.1, 1.1)
        ax.minorticks_on()
        ax.grid(True, which="major", alpha=1.0, linewidth=1)
        ax.grid(True, which="minor", alpha=0.3)
    plt.tight_layout()
    plt.show()