Skip to content

Commit 82ade41

Browse files
committed
timevae evaluation
1 parent 2c93017 commit 82ade41

File tree

3 files changed

+95
-14
lines changed

3 files changed

+95
-14
lines changed

dl/notebooks/time_vae.py

Lines changed: 62 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -738,15 +738,17 @@ def configure_optimizers(self) -> dict:
738738

739739
# ## Fitted Model
740740

741-
checkpoint_path = (
742-
"lightning_logs/time_vae_naive/version_29/checkpoints/epoch=1999-step=354000.ckpt"
743-
)
744-
vae_model_reloaded = VAEModel.load_from_checkpoint(checkpoint_path, model=vae)
741+
IS_RELOAD = True
745742

743+
if IS_RELOAD:
744+
checkpoint_path = "lightning_logs/time_vae_naive/version_29/checkpoints/epoch=1999-step=354000.ckpt"
745+
vae_model_reloaded = VAEModel.load_from_checkpoint(checkpoint_path, model=vae)
746+
else:
747+
vae_model_reloaded = vae_model
746748

747-
for i in time_vae_dm.predict_dataloader():
748-
print(i.size())
749-
i_pred = vae_model_reloaded.model(i.float().cuda())
749+
for pred_batch in time_vae_dm.predict_dataloader():
750+
print(pred_batch.size())
751+
i_pred = vae_model_reloaded.model(pred_batch.float().cuda())
750752
break
751753

752754
i_pred[0].size()
@@ -758,23 +760,70 @@ def configure_optimizers(self) -> dict:
758760

759761
element = 4
760762

761-
ax.plot(i.detach().numpy()[element, :, 0])
763+
ax.plot(pred_batch.detach().numpy()[element, :, 0])
762764
ax.plot(i_pred[0].cpu().detach().numpy()[element, :, 0], "x-")
763765
# -
764766

765767
# Data generation using the decoder.
766768

767769
sampling_z = torch.randn(
768-
2, vae_model_reloaded.model.encoder.params.latent_size
770+
pred_batch.size(0), vae_model_reloaded.model.encoder.params.latent_size
769771
).type_as(vae_model_reloaded.model.encoder.z_mean_layer.weight)
770-
sampling_x = vae_model_reloaded.model.decoder(sampling_z)
772+
generated_samples_x = (
773+
vae_model_reloaded.model.decoder(sampling_z).cpu().detach().numpy().squeeze()
774+
)
771775

772-
sampling_x.size()
776+
generated_samples_x.size()
773777

774778
# +
775779
_, ax = plt.subplots()
776780

777-
for i in range(min(len(sampling_x), 4)):
778-
ax.plot(sampling_x.cpu().detach().numpy()[i, :, 0], "x-")
781+
for i in range(min(len(generated_samples_x), 4)):
782+
ax.plot(generated_samples_x[i, :], "x-")
779783

780784
# -
785+
from openTSNE import TSNE
786+
787+
n_tsne_samples = 100
788+
789+
original_samples = pred_batch.cpu().detach().numpy().squeeze()[:n_tsne_samples]
790+
original_samples.shape
791+
792+
tsne = TSNE(
793+
perplexity=30,
794+
metric="euclidean",
795+
n_jobs=8,
796+
random_state=42,
797+
verbose=True,
798+
)
799+
800+
original_samples_embedding = tsne.fit(original_samples)
801+
802+
generated_samples_x[:n_tsne_samples]
803+
804+
generated_samples_embedding = original_samples_embedding.transform(
805+
generated_samples_x[:n_tsne_samples]
806+
)
807+
808+
# +
809+
fig, ax = plt.subplots(figsize=(7, 7))
810+
811+
ax.scatter(
812+
original_samples_embedding[:, 0],
813+
original_samples_embedding[:, 1],
814+
color="black",
815+
marker=".",
816+
label="original",
817+
)
818+
819+
ax.scatter(
820+
generated_samples_embedding[:, 0],
821+
generated_samples_embedding[:, 1],
822+
color="red",
823+
marker="x",
824+
label="generated",
825+
)
826+
827+
ax.set_title("t-SNE of original and generated samples")
828+
ax.set_xlabel("t-SNE 1")
829+
ax.set_ylabel("t-SNE 2")

poetry.lock

Lines changed: 32 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ jupytext = "^1.15.2"
3232

3333
[tool.poetry.group.visualization.dependencies]
3434
seaborn = "^0.13.2"
35+
opentsne = "^1.0.2"
3536

3637

3738
[tool.poetry.group.data.dependencies]

0 commit comments

Comments
 (0)