@@ -738,15 +738,17 @@ def configure_optimizers(self) -> dict:
738738
739739# ## Fitted Model
740740
741- checkpoint_path = (
742- "lightning_logs/time_vae_naive/version_29/checkpoints/epoch=1999-step=354000.ckpt"
743- )
744- vae_model_reloaded = VAEModel .load_from_checkpoint (checkpoint_path , model = vae )
741+ IS_RELOAD = True
745742
743+ if IS_RELOAD :
744+ checkpoint_path = "lightning_logs/time_vae_naive/version_29/checkpoints/epoch=1999-step=354000.ckpt"
745+ vae_model_reloaded = VAEModel .load_from_checkpoint (checkpoint_path , model = vae )
746+ else :
747+ vae_model_reloaded = vae_model
746748
747- for i in time_vae_dm .predict_dataloader ():
748- print (i .size ())
749- i_pred = vae_model_reloaded .model (i .float ().cuda ())
749+ for pred_batch in time_vae_dm .predict_dataloader ():
750+ print (pred_batch .size ())
751+ i_pred = vae_model_reloaded .model (pred_batch .float ().cuda ())
750752 break
751753
752754i_pred [0 ].size ()
@@ -758,23 +760,70 @@ def configure_optimizers(self) -> dict:
758760
759761element = 4
760762
761- ax .plot (i .detach ().numpy ()[element , :, 0 ])
763+ ax .plot (pred_batch .detach ().numpy ()[element , :, 0 ])
762764ax .plot (i_pred [0 ].cpu ().detach ().numpy ()[element , :, 0 ], "x-" )
763765# -
764766
765767# Data generation using the decoder.
766768
767769sampling_z = torch .randn (
768- 2 , vae_model_reloaded .model .encoder .params .latent_size
770+ pred_batch . size ( 0 ) , vae_model_reloaded .model .encoder .params .latent_size
769771).type_as (vae_model_reloaded .model .encoder .z_mean_layer .weight )
770- sampling_x = vae_model_reloaded .model .decoder (sampling_z )
772+ generated_samples_x = (
773+ vae_model_reloaded .model .decoder (sampling_z ).cpu ().detach ().numpy ().squeeze ()
774+ )
771775
772- sampling_x .size ()
776+ generated_samples_x .shape
773777
774778# +
775779_ , ax = plt .subplots ()
776780
777- for i in range (min (len (sampling_x ), 4 )):
778- ax .plot (sampling_x . cpu (). detach (). numpy () [i , :, 0 ], "x-" )
781+ for i in range (min (len (generated_samples_x ), 4 )):
782+ ax .plot (generated_samples_x [i , :], "x-" )
779783
780784# -
785+ from openTSNE import TSNE
786+
787+ n_tsne_samples = 100
788+
789+ original_samples = pred_batch .cpu ().detach ().numpy ().squeeze ()[:n_tsne_samples ]
790+ original_samples .shape
791+
792+ tsne = TSNE (
793+ perplexity = 30 ,
794+ metric = "euclidean" ,
795+ n_jobs = 8 ,
796+ random_state = 42 ,
797+ verbose = True ,
798+ )
799+
800+ original_samples_embedding = tsne .fit (original_samples )
801+
802+ generated_samples_x [:n_tsne_samples ]
803+
804+ generated_samples_embedding = original_samples_embedding .transform (
805+ generated_samples_x [:n_tsne_samples ]
806+ )
807+
808+ # +
809+ fig , ax = plt .subplots (figsize = (7 , 7 ))
810+
811+ ax .scatter (
812+ original_samples_embedding [:, 0 ],
813+ original_samples_embedding [:, 1 ],
814+ color = "black" ,
815+ marker = "." ,
816+ label = "original" ,
817+ )
818+
819+ ax .scatter (
820+ generated_samples_embedding [:, 0 ],
821+ generated_samples_embedding [:, 1 ],
822+ color = "red" ,
823+ marker = "x" ,
824+ label = "generated" ,
825+ )
826+
827+ ax .set_title ("t-SNE of original and generated samples" )
828+ ax .set_xlabel ("t-SNE 1" )
829+ ax .set_ylabel ("t-SNE 2" )
0 commit comments