@@ -849,29 +849,53 @@ class TripletMultiOmicDataset(Dataset):
849849 def __init__ (self , mydataset , main_var ):
850850 self .dataset = mydataset
851851 self .main_var = main_var
852- self .labels_set , self .label_to_indices = self .get_label_indices (self .dataset .ann [self .main_var ])
852+
853+ labels = self .dataset .ann [self .main_var ].numpy ()
854+ self .labels_set , self .label_to_indices = self .get_label_indices (labels )
855+
856+ # Valid anchor indices are those without NA labels
857+ self .valid_indices = [i for i , label in enumerate (labels ) if not np .isnan (label )]
858+
853859 def __getitem__ (self , index ):
860+ # We only use valid non-NA indices for anchors
861+ real_index = self .valid_indices [index ]
862+
854863 # get anchor sample and its label
855- anchor , y_dict = self .dataset [index ][0 ], self .dataset [index ][1 ]
864+ anchor , y_dict = self .dataset [real_index ][0 ], self .dataset [real_index ][1 ]
865+
856866 # choose another sample with same label
857867 label = y_dict [self .main_var ].item ()
858- positive_index = index
859- while positive_index == index :
868+ positive_index = real_index
869+ while positive_index == real_index :
860870 positive_index = np .random .choice (self .label_to_indices [label ])
871+
861872 # choose another sample with a different label
862- negative_label = np .random .choice (list (self .labels_set - set ([label ])))
873+ # possible negative labels include NA
874+ import random
875+ negative_label = random .choice (list (self .labels_set - set ([label ])))
863876 negative_index = np .random .choice (self .label_to_indices [negative_label ])
877+
864878 pos = self .dataset [positive_index ][0 ] # positive example
865879 neg = self .dataset [negative_index ][0 ] # negative example
866880 return anchor , pos , neg , y_dict
867881
868882 def __len__ (self ):
869- return len (self .dataset )
883+ return len (self .valid_indices )
884+
885+ def get_label_indices (self , labels_array ):
886+ # Filter out NaNs for a clean set of valid classes
887+ valid_labels = [l for l in labels_array if not np .isnan (l )]
888+ labels_set = set (valid_labels )
889+
890+ label_to_indices = {label : np .where (labels_array == label )[0 ]
891+ for label in labels_set }
892+
893+ # Handle NA as a single separate group (if any exist)
894+ na_indices = np .where (np .isnan (labels_array ))[0 ]
895+ if len (na_indices ) > 0 :
896+ labels_set .add ("NA" )
897+ label_to_indices ["NA" ] = na_indices
870898
871- def get_label_indices (self , labels ):
872- labels_set = set (labels .numpy ())
873- label_to_indices = {label : np .where (labels .numpy () == label )[0 ]
874- for label in labels_set }
875899 return labels_set , label_to_indices
876900
877901
0 commit comments