Skip to content

Commit ed7a2ee

Browse files
authored
Add auth headers to infer_from_parameters (#1674)
1 parent aab6c21 commit ed7a2ee

File tree

2 files changed

+65
-7
lines changed

2 files changed

+65
-7
lines changed

ansible_ai_connect/ai/api/wca/model_id_views.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,15 @@ def validate(api_key, model_id):
217217
model_mesh_client: ModelPipelineCompletions = apps.get_app_config("ai").get_model_pipeline(
218218
ModelPipelineCompletions
219219
)
220+
headers = model_mesh_client.get_request_headers(
221+
api_key=api_key, identifier=None, lightspeed_user_uuid=None
222+
)
220223
model_mesh_client.infer_from_parameters(
221-
model_id, "", "---\n- hosts: all\n tasks:\n - name: install ssh\n"
224+
model_id,
225+
"",
226+
"---\n- hosts: all\n tasks:\n - name: install ssh\n",
227+
user=None,
228+
headers=headers,
222229
)
223230

224231

ansible_ai_connect/ai/api/wca/tests/test_model_id_views.py

Lines changed: 57 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
WisdomLogAwareMixin,
4848
)
4949

50+
VALIDATE_PROMPT = "---\n- hosts: all\n tasks:\n - name: install ssh\n"
51+
5052

5153
@override_settings(DEPLOYMENT_MODE="saas")
5254
@override_settings(WCA_SECRET_BACKEND_TYPE="aws_sm")
@@ -183,6 +185,9 @@ def _test_set_model_id(self, has_seat):
183185
self.user.organization = Organization.objects.get_or_create(id=123)[0]
184186
self.user.rh_user_has_seat = has_seat
185187
mock_secret_manager = apps.get_app_config("ai").get_wca_secret_manager()
188+
mock_wca_client: ModelPipelineCompletions = apps.get_app_config("ai").get_model_pipeline(
189+
ModelPipelineCompletions
190+
)
186191
self.client.force_authenticate(user=self.user)
187192

188193
# ModelId should initially not exist
@@ -192,7 +197,15 @@ def _test_set_model_id(self, has_seat):
192197
mock_secret_manager.get_secret.assert_called_with(123, Suffixes.MODEL_ID)
193198

194199
# Set ModelId
195-
mock_secret_manager.get_secret.return_value = {"SecretString": "someAPIKey"}
200+
api_key_value = "someAPIKey"
201+
model_id_value = "secret_model_id"
202+
mock_secret_manager.get_secret.return_value = {"SecretString": api_key_value}
203+
204+
expected_headers = {"Authorization": f"Bearer {api_key_value}", "X-Test-Header-Set": "true"}
205+
mock_wca_client.get_request_headers.return_value = expected_headers
206+
mock_wca_client.infer_from_parameters.reset_mock(side_effect=True)
207+
mock_wca_client.infer_from_parameters.side_effect = None
208+
196209
with self.assertLogs(logger="ansible_ai_connect.users.signals", level="DEBUG") as signals:
197210
with self.assertLogs(logger="root", level="DEBUG") as log:
198211
r = self.client.post(
@@ -202,8 +215,19 @@ def _test_set_model_id(self, has_seat):
202215
)
203216

204217
self.assertEqual(r.status_code, HTTPStatus.NO_CONTENT)
218+
219+
mock_wca_client.get_request_headers.assert_called_once_with(
220+
api_key=api_key_value, identifier=None, lightspeed_user_uuid=None
221+
)
222+
mock_wca_client.infer_from_parameters.assert_called_once_with(
223+
model_id_value,
224+
"",
225+
VALIDATE_PROMPT,
226+
user=None,
227+
headers=expected_headers,
228+
)
205229
mock_secret_manager.save_secret.assert_called_with(
206-
123, Suffixes.MODEL_ID, "secret_model_id"
230+
123, Suffixes.MODEL_ID, model_id_value
207231
)
208232
self.assert_segment_log(log, "modelIdSet", None)
209233

@@ -432,18 +456,45 @@ def _test_validate_ok(self, has_seat):
432456
self.user.organization = Organization.objects.get_or_create(id=123)[0]
433457
self.user.rh_user_has_seat = has_seat
434458
mock_secret_manager = apps.get_app_config("ai").get_wca_secret_manager()
459+
mock_wca_client: ModelPipelineCompletions = apps.get_app_config("ai").get_model_pipeline(
460+
ModelPipelineCompletions
461+
)
435462
self.client.force_authenticate(user=self.user)
436463

437-
def mock_get_secret_model_id(*args, **kwargs):
464+
api_key_value = "some_api_key_for_validate"
465+
model_id_value = "model_id_for_validate"
466+
467+
def mock_get_secret_side_effect(*args, **kwargs):
438468
if args[1] == Suffixes.API_KEY:
439-
return {"SecretString": "some_api_key"}
440-
return {"SecretString": "model_id"}
469+
return {"SecretString": api_key_value}
470+
if args[1] == Suffixes.MODEL_ID:
471+
return {"SecretString": model_id_value}
472+
return None
441473

442-
mock_secret_manager.get_secret.side_effect = mock_get_secret_model_id
474+
mock_secret_manager.get_secret.side_effect = mock_get_secret_side_effect
475+
476+
expected_headers = {
477+
"Authorization": f"Bearer {api_key_value}",
478+
"X-Test-Header-Validate": "true",
479+
}
480+
mock_wca_client.get_request_headers.return_value = expected_headers
481+
mock_wca_client.infer_from_parameters.reset_mock(side_effect=True)
482+
mock_wca_client.infer_from_parameters.side_effect = None
443483

444484
with self.assertLogs(logger="root", level="DEBUG") as log:
445485
r = self.client.get(self.api_version_reverse("wca_model_id_validator"))
446486
self.assertEqual(r.status_code, HTTPStatus.OK)
487+
488+
mock_wca_client.get_request_headers.assert_called_once_with(
489+
api_key=api_key_value, identifier=None, lightspeed_user_uuid=None
490+
)
491+
mock_wca_client.infer_from_parameters.assert_called_once_with(
492+
model_id_value,
493+
"",
494+
VALIDATE_PROMPT,
495+
user=None,
496+
headers=expected_headers,
497+
)
447498
self.assert_segment_log(log, "modelIdValidate", None)
448499

449500
@override_settings(SEGMENT_WRITE_KEY="DUMMY_KEY_VALUE")

0 commit comments

Comments
 (0)