Skip to content

Commit 24b10c9

Browse files
committed
fix feature 3d animation bugs
1 parent d8e5c2f commit 24b10c9

File tree

3 files changed

+44
-27
lines changed

3 files changed

+44
-27
lines changed

dhg/visualization/feature/utils.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,18 @@
66
from sklearn import preprocessing
77
from sklearn.manifold import TSNE
88
from sklearn.decomposition import PCA
9+
import matplotlib.animation as animation
910

1011
eps = 1e-5
1112
min_norm = 1e-15
1213

1314

15+
def on_key_press(event):
16+
print(event.key)
17+
if event.key == "escape" or event.key == "q":
18+
plt.close(event.canvas.figure)
19+
20+
1421
def make_animation(embeddings: np.ndarray, colors: Union[np.ndarray, str], cmap="viridis"):
1522
r"""Make an animation of embeddings.
1623
@@ -22,18 +29,20 @@ def make_animation(embeddings: np.ndarray, colors: Union[np.ndarray, str], cmap=
2229
embeddings = normalize(embeddings)
2330
x, y, z = embeddings[:, 0], embeddings[:, 1], embeddings[:, 2]
2431
fig = plt.figure(figsize=(8, 8))
25-
ax = fig.gca(projection="3d")
26-
plt.ion()
27-
for i in range(30000):
28-
plt.clf()
29-
fig = plt.gcf()
30-
ax = fig.gca(projection="3d")
32+
ax = fig.add_subplot(111, projection="3d")
33+
34+
def init():
3135
if colors is not None:
3236
ax.scatter(x, y, z, c=colors, cmap=cmap)
3337
else:
3438
ax.scatter(x, y, z, cmap=cmap)
39+
return fig
40+
41+
def animate(i):
3542
ax.view_init(elev=20, azim=i % 360)
36-
plt.pause(0.001)
43+
44+
ani = animation.FuncAnimation(fig, animate, init_func=init, frames=360, interval=20, blit=False)
45+
return ani
3746

3847

3948
def plot_2d_embedding(embeddings: np.ndarray, label: Optional[np.ndarray] = None, cmap="viridis"):

dhg/visualization/feature/visualization.py

Lines changed: 28 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -62,43 +62,53 @@ def draw_in_poincare_ball(
6262
def animation_of_3d_poincare_ball(
6363
embeddings: np.ndarray, label: Optional[np.ndarray] = None, reduce_method: str = "pca", cmap="viridis"
6464
):
65-
r"""Play animation of embeddings visualization on Poincare Ball.
65+
r"""Make 3D animation of embeddings visualization on Poincare Ball.
66+
This function will return the animation object `ani <https://matplotlib.org/stable/api/_as_gen/matplotlib.animation.FuncAnimation.html>`_.
67+
You can save the animation by ``ani.save("animation.gif")``.
6668
6769
Args:
6870
``feature`` (``np.ndarray``): The feature matrix. Size :math:`(N, C)`.
6971
``label`` (``np.ndarray``, optional): The label matrix. Size :math:`(N, )`. Defaults to ``None``.
7072
``reduce_method`` (``str``): The method to project the embedding into low-dimensional space. It can be ``pca`` or ``tsne``. Defaults to ``pca``.
7173
``cmap`` (``str``, optional): The `color map <https://matplotlib.org/stable/tutorials/colors/colormaps.html>`_. Defaults to ``"viridis"``.
74+
75+
Example:
76+
>>> import numpy as np
77+
>>> import matplotlib.pyplot as plt
78+
>>> from dhg.visualization import animation_of_3d_poincare_ball
79+
>>> x = np.random.rand(100, 32)
80+
>>> ani = animation_of_3d_poincare_ball(x)
81+
>>> plt.show()
82+
>>> ani.save('a.gif')
7283
"""
7384
emb_low = project_to_poincare_ball(embeddings, 3, reduce_method)
74-
colors = label if label is not None else "b"
75-
make_animation(emb_low, colors, cmap=cmap)
85+
colors = label if label is not None else "r"
86+
return make_animation(emb_low, colors, cmap=cmap)
7687

7788

7889
def animation_of_3d_euclidean_space(
7990
embeddings: np.ndarray, label: Optional[np.ndarray] = None, cmap="viridis",
8091
):
81-
r"""Play animation of embeddings visualization of tSNE algorithm.
92+
r"""Make 3D animation of embeddings visualization of tSNE algorithm.
93+
This function will return the animation object `ani <https://matplotlib.org/stable/api/_as_gen/matplotlib.animation.FuncAnimation.html>`_.
94+
You can save the animation by ``ani.save("animation.gif")``.
8295
8396
Args:
8497
``feature`` (``np.ndarray``): The feature matrix. Size :math:`(N, C)`.
8598
``label`` (``np.ndarray``, optional): The label matrix. Size :math:`(N, )`. Defaults to ``None``.
8699
``cmap`` (``str``, optional): The `color map <https://matplotlib.org/stable/tutorials/colors/colormaps.html>`_. Defaults to ``"viridis"``.
100+
101+
Example:
102+
>>> import numpy as np
103+
>>> import matplotlib.pyplot as plt
104+
>>> from dhg.visualization import animation_of_3d_euclidean_space
105+
>>> x = np.random.rand(100, 32)
106+
>>> ani = animation_of_3d_euclidean_space(x)
107+
>>> plt.show()
108+
>>> ani.save('a.gif')
87109
"""
88110
tsne = TSNE(n_components=3, init="pca")
89111
emb_low = tsne.fit_transform(embeddings)
90-
colors = label if label is not None else "b"
91-
make_animation(emb_low, colors, cmap=cmap)
112+
colors = label if label is not None else "r"
113+
return make_animation(emb_low, colors, cmap=cmap)
92114

93-
94-
if __name__ == "__main__":
95-
file_dir = "data/modelnet40/train_img_feat_4.npy"
96-
# save_dir = "./tmp/figure"
97-
save_dir = None # None for show now or file name to save
98-
low_demen_method = "TSNE" # vis for poincare_ball, PCA or TSNE
99-
show_method = "Rotation" # None for 2d or Rotation and Drag for 3d
100-
label = np.load("data/modelnet40/train_label.npy")
101-
ft = np.load(file_dir)
102-
d = 3
103-
# vis_tsne(ft, save_dir,d)
104-
draw_in_poincare_ball(ft, save_dir, d, label, reduce_method=low_demen_method, auto_play=show_method)

docs/source/api/visualization.rst

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,6 @@ Feature Visualization in Poincare Ball
4545
Make Animations
4646
~~~~~~~~~~~~~~~~~~
4747

48-
.. autofunction:: dhg.visualization.make_animation
49-
5048
.. autofunction:: dhg.visualization.animation_of_3d_euclidean_space
5149

5250
.. autofunction:: dhg.visualization.animation_of_3d_poincare_ball

0 commit comments

Comments
 (0)