Skip to content

Commit c8ba061

Browse files
committed
avoid sorting and re-iterating
1 parent 238be4e commit c8ba061

File tree

1 file changed

+13
-15
lines changed

1 file changed

+13
-15
lines changed

chebifier/_custom_cache.py

Lines changed: 13 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -112,18 +112,22 @@ def wrapper(instance, smiles_list: list[str]) -> list[Any]:
112112
model_name = getattr(instance, "model_name", None)
113113
assert model_name is not None, "Instance must have a model_name attribute."
114114

115-
results: list[tuple[int, Any]] = []
116115
missing_smiles: list[str] = []
117116
missing_indices: list[int] = []
117+
ordered_results: list[Any] = [None] * len(smiles_list)
118118

119119
# First: try to fetch all from cache
120-
for i, smiles in enumerate(smiles_list):
121-
result = self.get(smiles=smiles, model_name=model_name)
122-
if result is not None:
123-
results.append((i, result)) # save index for reordering
120+
for idx, smiles in enumerate(smiles_list):
121+
prediction = self.get(smiles=smiles, model_name=model_name)
122+
if prediction is not None:
123+
# For debugging purposes, you can uncomment the print statement below
124+
# print(
125+
# f"[Cache Hit] Prediction for smiles: {smiles} and model: {model_name} are retrieved from cache."
126+
# )
127+
ordered_results[idx] = prediction
124128
else:
125129
missing_smiles.append(smiles)
126-
missing_indices.append(i)
130+
missing_indices.append(idx)
127131

128132
# If some are missing, call original function
129133
if missing_smiles:
@@ -138,15 +142,9 @@ def wrapper(instance, smiles_list: list[str]) -> list[Any]:
138142
):
139143
if prediction is not None:
140144
self.set(smiles, model_name, prediction)
141-
results.append((missing_idx, prediction))
142-
143-
# Reorder results to match original indices
144-
results.sort(key=lambda x: x[0]) # sort by index
145-
ordered = [result for _, result in results]
146-
assert len(ordered) == len(
147-
smiles_list
148-
), "Result length does not match input length."
149-
return ordered
145+
ordered_results[missing_idx] = prediction
146+
147+
return ordered_results
150148

151149
return wrapper
152150

0 commit comments

Comments
 (0)