Skip to content

Commit 49c27a0

Browse files
add 3d pca scatterplot util
1 parent 61417dc commit 49c27a0

File tree

1 file changed

+20
-0
lines changed

1 file changed

+20
-0
lines changed

lidbox/visualize.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import matplotlib.pyplot as plt
2+
from mpl_toolkits.mplot3d import Axes3D
23
import numpy as np
4+
import pandas as pd
35
import seaborn as sns
46

57

@@ -110,3 +112,21 @@ def plot_embedding_vector(v, cmap="RdBu_r", figsize=None):
110112

111113
plt.gcf().set_size_inches(*figsize)
112114
plt.show()
115+
116+
117+
def draw_3d_pca_scatterplot(pca_data_3d, data_labels):
118+
df = pd.DataFrame.from_dict({
119+
"x": pca_data_3d[:,0],
120+
"y": pca_data_3d[:,1],
121+
"z": pca_data_3d[:,2],
122+
"label": data_labels,
123+
})
124+
125+
fig = plt.figure()
126+
ax = fig.add_subplot(111, projection="3d")
127+
128+
for label, group in df.groupby("label"):
129+
ax.scatter(group.x, group.y, group.z, label=label)
130+
131+
ax.legend()
132+
return fig, ax

0 commit comments

Comments
 (0)