Skip to content

Commit 3785bb5

Browse files
committed
avoid repeated slicing in loop
1 parent 30ca5f6 commit 3785bb5

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

chebai/preprocessing/datasets/chebi.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -359,18 +359,18 @@ def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, No
359359
"""
360360
with open(input_file_path, "rb") as input_file:
361361
df = pd.read_pickle(input_file)
362-
if self.single_class is not None:
363-
single_cls_index = list(df.columns).index(int(self.single_class))
364-
for row in df.values:
365-
if self.single_class is None:
366-
labels = row[self._LABELS_START_IDX :].astype(bool)
367-
else:
368-
labels = [bool(row[single_cls_index])]
369-
yield dict(
370-
features=row[self._DATA_REPRESENTATION_IDX],
371-
labels=labels,
372-
ident=row[self._ID_IDX],
373-
)
362+
363+
if self.single_class is None:
364+
all_labels = df.iloc[:, self._LABELS_START_IDX :].to_numpy(dtype=bool)
365+
else:
366+
single_cls_index = df.columns.get_loc(int(self.single_class))
367+
all_labels = df.iloc[:, [single_cls_index]].to_numpy(dtype=bool)
368+
369+
features = df.iloc[:, self._DATA_REPRESENTATION_IDX].to_numpy()
370+
idents = df.iloc[:, self._ID_IDX].to_numpy()
371+
372+
for feat, labels, ident in zip(features, all_labels, idents):
373+
yield dict(features=feat, labels=labels, ident=ident)
374374

375375
# ------------------------------ Phase: Dynamic Splits -----------------------------------
376376
def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:

0 commit comments

Comments
 (0)