@@ -748,7 +748,6 @@ def fit(self, train_loader, test_loader, total_loader, first="RNA"):
748748 None.
749749
750750 """
751-
752751 used_cycle = 0
753752
754753 if self .ground_truth1 is not None :
@@ -844,7 +843,6 @@ def score(self, dataloader, metric='clustering'):
844843 Metric eval score for VAE2.
845844
846845 """
847-
848846 if metric == 'clustering' :
849847 self .model1 .eval ()
850848 self .model2 .eval ()
@@ -915,7 +913,6 @@ def _encodeBatch(self, total_loader):
915913 Reconstruction result of modality 2.
916914
917915 """
918-
919916 # processing large-scale datasets
920917 latent_z1 = []
921918 latent_z2 = []
@@ -953,30 +950,30 @@ def _encodeBatch(self, total_loader):
953950 return latent_z1 , latent_z2 , norm_x1 , recon_x1 , norm_x2 , recon_x2
954951
955952 def forward (self , total_loader ):
956- """Forward function for torch.nn.Module. An alias of encode_Batch function.
957-
958- Parameters
959- ----------
960- total_loader : torch.utils.data.DataLoader
961- Dataloader for dataset.
962-
963- Returns
964- -------
965- latent_z1 : numpy.ndarray
966- Latent representation of modality 1.
967- latent_z2 : numpy.ndarray
968- Latent representation of modality 2.
969- norm_x1 : numpy.ndarray
970- Normalized representation of modality 1.
971- recon_x1 : numpy.ndarray
972- Reconstruction result of modality 1.
973- norm_x2 : numpy.ndarray
974- Normalized representation of modality 2.
975- recon_x2 : numpy.ndarray
976- Reconstruction result of modality 2.
953+ """Forward function for torch.nn.Module.
954+
955+ An alias of encode_Batch function.
956+ Parameters
957+ ----------
958+ total_loader : torch.utils.data.DataLoader
959+ Dataloader for dataset.
960+
961+ Returns
962+ -------
963+ latent_z1 : numpy.ndarray
964+ Latent representation of modality 1.
965+ latent_z2 : numpy.ndarray
966+ Latent representation of modality 2.
967+ norm_x1 : numpy.ndarray
968+ Normalized representation of modality 1.
969+ recon_x1 : numpy.ndarray
970+ Reconstruction result of modality 1.
971+ norm_x2 : numpy.ndarray
972+ Normalized representation of modality 2.
973+ recon_x2 : numpy.ndarray
974+ Reconstruction result of modality 2.
977975
978976 """
979-
980977 latent_z1 , latent_z2 , norm_x1 , recon_x1 , norm_x2 , recon_x2 = self ._encodeBatch (total_loader )
981978
982979 return latent_z1 , latent_z2 , norm_x1 , recon_x1 , norm_x2 , recon_x2
0 commit comments