Skip to content

Commit 7340cd8

Browse files
committed
pass api_client to external job and allow specification of api_url
1 parent 63bb8a0 commit 7340cd8

File tree

1 file changed

+6
-7
lines changed

1 file changed

+6
-7
lines changed

inference/models/inference.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# inference/models/inference.py
2+
from django.conf import settings
23
from django.db import models
34
from django.utils import timezone
45

@@ -119,11 +120,9 @@ def log_error_and_set_status_failed(self, error_msg: str) -> None:
119120
self.completed_at = timezone.now()
120121
self.save(update_fields=["error_message", "status", "completed_at", "updated_at"])
121122

122-
def create_external_job(self, batch_data) -> "ExternalJob":
123+
def _create_external_job(self, batch_data, api_client) -> "ExternalJob":
123124
"""Create and submit an external job for a batch"""
124125
try:
125-
api_client = InferenceAPIClient()
126-
127126
# Submit batch to API using model version identifier
128127
job_id = api_client.submit_batch(self.model_version.api_identifier, batch_data)
129128

@@ -144,22 +143,22 @@ def create_external_job(self, batch_data) -> "ExternalJob":
144143
self.log_error_and_set_status_failed(f"Failed to create external job: {str(e)}")
145144
return None
146145

147-
def initiate(self) -> None:
146+
def initiate(self, inference_api_url=settings.INFERENCE_API_URL) -> None:
148147
"""Initialize job and create batches"""
149148
try:
150149
# Load model using the refactored API client
151-
api_client = InferenceAPIClient()
150+
api_client = InferenceAPIClient(inference_api_url=inference_api_url)
152151
if not api_client.load_model(self.model_version.api_identifier):
153152
# TODO: should refactor to get an exact error out of the api client
154153
self.log_error_and_set_status_failed("Failed to load model")
155154
return
156155

157156
batch_processor = BatchProcessor()
158-
urls = self.collection.curated_urls.all()
157+
urls = self.collection.dump_urls.all()
159158
created_batch = False
160159

161160
for batch in batch_processor.iter_url_batches(urls):
162-
external_job = self.create_external_job(batch)
161+
external_job = self._create_external_job(batch, api_client)
163162
if external_job:
164163
created_batch = True
165164
else:

0 commit comments

Comments
 (0)