diff --git a/viscy/data/triplet.py b/viscy/data/triplet.py index 72828af42..e4167d117 100644 --- a/viscy/data/triplet.py +++ b/viscy/data/triplet.py @@ -309,17 +309,16 @@ def __getitems__(self, indices: list[int]) -> list[TripletSample]: if self.return_negative: for sample, negative_patch in zip(samples, negative_patches): sample["negative"] = negative_patch - else: - for sample, (_, anchor_row) in zip(samples, anchor_rows.iterrows()): - # For new predictions, ensure all INDEX_COLUMNS are included - index_dict = {} - for col in INDEX_COLUMNS: - if col in anchor_row.index: - index_dict[col] = anchor_row[col] - elif col not in ["y", "x", "z"]: - # Skip y and x for legacy data - they weren't part of INDEX_COLUMNS - raise KeyError(f"Required column '{col}' not found in data") - sample["index"] = index_dict + for sample, (_, anchor_row) in zip(samples, anchor_rows.iterrows()): + # For new predictions, ensure all INDEX_COLUMNS are included + index_dict = {} + for col in INDEX_COLUMNS: + if col in anchor_row.index: + index_dict[col] = anchor_row[col] + elif col not in ["y", "x", "z"]: + # Skip y and x for legacy data - they weren't part of INDEX_COLUMNS + raise KeyError(f"Required column '{col}' not found in data") + sample["index"] = index_dict return samples