Skip to content

Commit 3f18253

Browse files
committed
Update the retry method in the inference API client
1 parent 6365c10 commit 3f18253

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

inference/utils/inference_api_client.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -188,22 +188,26 @@ def wait_for_model_loading(self, model_identifier: str, max_attempts: int = 10,
188188
return False
189189
return False # Timed out without reaching LOADED state
190190

191-
@retry(stop=stop_after_attempt(5), wait=wait_fixed(60), retry=retry_if_result(lambda x: not x))
192191
def load_model(self, model_identifier: str) -> bool:
193192
"""
194-
Load a specific model, first unloading all models, and wait for loading to complete.
195-
196-
Args:
197-
model_identifier: The model to load
198-
199-
Returns:
200-
bool: True if the model is successfully loaded
193+
Load a specific model and avoid unnecessary unloading during retries.
201194
"""
202-
# First unload all models
195+
# First try to check if model is already loaded
196+
status = self.check_model_status(model_identifier)
197+
if status == ModelStatusEnum.LOADED:
198+
return True
199+
200+
# Only unload all models once, then use retries for loading
203201
if not self.unload_all_models():
204202
return False
205203

206-
# Check current status of our model
204+
# Now use retries only for the loading portion
205+
return self._load_model_with_retries(model_identifier)
206+
207+
@retry(stop=stop_after_attempt(5), wait=wait_fixed(30), retry=retry_if_result(lambda x: not x))
208+
def _load_model_with_retries(self, model_identifier: str) -> bool:
209+
"""Internal method that handles retries for loading a model without unloading first."""
210+
# Check current status
207211
status = self.check_model_status(model_identifier)
208212

209213
# If already loaded, we're done
@@ -216,7 +220,7 @@ def load_model(self, model_identifier: str) -> bool:
216220
if not load_request_success:
217221
return False
218222

219-
# Now wait for loading to complete
223+
# Wait for loading to complete
220224
return self.wait_for_model_loading(model_identifier)
221225

222226
return False

0 commit comments

Comments
 (0)