Skip to content

Commit e56233b

Browse files
authored
Merge pull request #142 from nilchia/fix_triplet
fix triplet anchoring
2 parents 159d142 + 96743f9 commit e56233b

File tree

1 file changed

+34
-10
lines changed

1 file changed

+34
-10
lines changed

flexynesis/data.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -849,29 +849,53 @@ class TripletMultiOmicDataset(Dataset):
849849
def __init__(self, mydataset, main_var):
850850
self.dataset = mydataset
851851
self.main_var = main_var
852-
self.labels_set, self.label_to_indices = self.get_label_indices(self.dataset.ann[self.main_var])
852+
853+
labels = self.dataset.ann[self.main_var].numpy()
854+
self.labels_set, self.label_to_indices = self.get_label_indices(labels)
855+
856+
# Valid anchor indices are those without NA labels
857+
self.valid_indices = [i for i, label in enumerate(labels) if not np.isnan(label)]
858+
853859
def __getitem__(self, index):
860+
# We only use valid non-NA indices for anchors
861+
real_index = self.valid_indices[index]
862+
854863
# get anchor sample and its label
855-
anchor, y_dict = self.dataset[index][0], self.dataset[index][1]
864+
anchor, y_dict = self.dataset[real_index][0], self.dataset[real_index][1]
865+
856866
# choose another sample with same label
857867
label = y_dict[self.main_var].item()
858-
positive_index = index
859-
while positive_index == index:
868+
positive_index = real_index
869+
while positive_index == real_index:
860870
positive_index = np.random.choice(self.label_to_indices[label])
871+
861872
# choose another sample with a different label
862-
negative_label = np.random.choice(list(self.labels_set - set([label])))
873+
# possible negative labels include NA
874+
import random
875+
negative_label = random.choice(list(self.labels_set - set([label])))
863876
negative_index = np.random.choice(self.label_to_indices[negative_label])
877+
864878
pos = self.dataset[positive_index][0] # positive example
865879
neg = self.dataset[negative_index][0] # negative example
866880
return anchor, pos, neg, y_dict
867881

868882
def __len__(self):
869-
return len(self.dataset)
883+
return len(self.valid_indices)
884+
885+
def get_label_indices(self, labels_array):
886+
# Filter out NaNs for a clean set of valid classes
887+
valid_labels = [l for l in labels_array if not np.isnan(l)]
888+
labels_set = set(valid_labels)
889+
890+
label_to_indices = {label: np.where(labels_array == label)[0]
891+
for label in labels_set}
892+
893+
# Handle NA as a single separate group (if any exist)
894+
na_indices = np.where(np.isnan(labels_array))[0]
895+
if len(na_indices) > 0:
896+
labels_set.add("NA")
897+
label_to_indices["NA"] = na_indices
870898

871-
def get_label_indices(self, labels):
872-
labels_set = set(labels.numpy())
873-
label_to_indices = {label: np.where(labels.numpy() == label)[0]
874-
for label in labels_set}
875899
return labels_set, label_to_indices
876900

877901

0 commit comments

Comments
 (0)