@@ -112,18 +112,22 @@ def wrapper(instance, smiles_list: list[str]) -> list[Any]:
112112 model_name = getattr (instance , "model_name" , None )
113113 assert model_name is not None , "Instance must have a model_name attribute."
114114
115- results : list [tuple [int , Any ]] = []
116115 missing_smiles : list [str ] = []
117116 missing_indices : list [int ] = []
117+ ordered_results : list [Any ] = [None ] * len (smiles_list )
118118
119119 # First: try to fetch all from cache
120- for i , smiles in enumerate (smiles_list ):
121- result = self .get (smiles = smiles , model_name = model_name )
122- if result is not None :
123- results .append ((i , result )) # save index for reordering
120+ for idx , smiles in enumerate (smiles_list ):
121+ prediction = self .get (smiles = smiles , model_name = model_name )
122+ if prediction is not None :
123+ # For debugging purposes, you can uncomment the print statement below
124+ # print(
125+ # f"[Cache Hit] Prediction for smiles: {smiles} and model: {model_name} are retrieved from cache."
126+ # )
127+ ordered_results [idx ] = prediction
124128 else :
125129 missing_smiles .append (smiles )
126- missing_indices .append (i )
130+ missing_indices .append (idx )
127131
128132 # If some are missing, call original function
129133 if missing_smiles :
@@ -138,15 +142,9 @@ def wrapper(instance, smiles_list: list[str]) -> list[Any]:
138142 ):
139143 if prediction is not None :
140144 self .set (smiles , model_name , prediction )
141- results .append ((missing_idx , prediction ))
142-
143- # Reorder results to match original indices
144- results .sort (key = lambda x : x [0 ]) # sort by index
145- ordered = [result for _ , result in results ]
146- assert len (ordered ) == len (
147- smiles_list
148- ), "Result length does not match input length."
149- return ordered
145+ ordered_results [missing_idx ] = prediction
146+
147+ return ordered_results
150148
151149 return wrapper
152150