Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 10 additions & 11 deletions viscy/data/triplet.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,17 +309,16 @@ def __getitems__(self, indices: list[int]) -> list[TripletSample]:
if self.return_negative:
for sample, negative_patch in zip(samples, negative_patches):
sample["negative"] = negative_patch
else:
for sample, (_, anchor_row) in zip(samples, anchor_rows.iterrows()):
# For new predictions, ensure all INDEX_COLUMNS are included
index_dict = {}
for col in INDEX_COLUMNS:
if col in anchor_row.index:
index_dict[col] = anchor_row[col]
elif col not in ["y", "x", "z"]:
# Skip y and x for legacy data - they weren't part of INDEX_COLUMNS
raise KeyError(f"Required column '{col}' not found in data")
sample["index"] = index_dict
for sample, (_, anchor_row) in zip(samples, anchor_rows.iterrows()):
# For new predictions, ensure all INDEX_COLUMNS are included
index_dict = {}
for col in INDEX_COLUMNS:
if col in anchor_row.index:
index_dict[col] = anchor_row[col]
elif col not in ["y", "x", "z"]:
# Skip y and x for legacy data - they weren't part of INDEX_COLUMNS
raise KeyError(f"Required column '{col}' not found in data")
sample["index"] = index_dict
return samples


Expand Down