47
47
WisdomLogAwareMixin ,
48
48
)
49
49
50
+ VALIDATE_PROMPT = "---\n - hosts: all\n tasks:\n - name: install ssh\n "
51
+
50
52
51
53
@override_settings (DEPLOYMENT_MODE = "saas" )
52
54
@override_settings (WCA_SECRET_BACKEND_TYPE = "aws_sm" )
@@ -183,6 +185,9 @@ def _test_set_model_id(self, has_seat):
183
185
self .user .organization = Organization .objects .get_or_create (id = 123 )[0 ]
184
186
self .user .rh_user_has_seat = has_seat
185
187
mock_secret_manager = apps .get_app_config ("ai" ).get_wca_secret_manager ()
188
+ mock_wca_client : ModelPipelineCompletions = apps .get_app_config ("ai" ).get_model_pipeline (
189
+ ModelPipelineCompletions
190
+ )
186
191
self .client .force_authenticate (user = self .user )
187
192
188
193
# ModelId should initially not exist
@@ -192,7 +197,15 @@ def _test_set_model_id(self, has_seat):
192
197
mock_secret_manager .get_secret .assert_called_with (123 , Suffixes .MODEL_ID )
193
198
194
199
# Set ModelId
195
- mock_secret_manager .get_secret .return_value = {"SecretString" : "someAPIKey" }
200
+ api_key_value = "someAPIKey"
201
+ model_id_value = "secret_model_id"
202
+ mock_secret_manager .get_secret .return_value = {"SecretString" : api_key_value }
203
+
204
+ expected_headers = {"Authorization" : f"Bearer { api_key_value } " , "X-Test-Header-Set" : "true" }
205
+ mock_wca_client .get_request_headers .return_value = expected_headers
206
+ mock_wca_client .infer_from_parameters .reset_mock (side_effect = True )
207
+ mock_wca_client .infer_from_parameters .side_effect = None
208
+
196
209
with self .assertLogs (logger = "ansible_ai_connect.users.signals" , level = "DEBUG" ) as signals :
197
210
with self .assertLogs (logger = "root" , level = "DEBUG" ) as log :
198
211
r = self .client .post (
@@ -202,8 +215,19 @@ def _test_set_model_id(self, has_seat):
202
215
)
203
216
204
217
self .assertEqual (r .status_code , HTTPStatus .NO_CONTENT )
218
+
219
+ mock_wca_client .get_request_headers .assert_called_once_with (
220
+ api_key = api_key_value , identifier = None , lightspeed_user_uuid = None
221
+ )
222
+ mock_wca_client .infer_from_parameters .assert_called_once_with (
223
+ model_id_value ,
224
+ "" ,
225
+ VALIDATE_PROMPT ,
226
+ user = None ,
227
+ headers = expected_headers ,
228
+ )
205
229
mock_secret_manager .save_secret .assert_called_with (
206
- 123 , Suffixes .MODEL_ID , "secret_model_id"
230
+ 123 , Suffixes .MODEL_ID , model_id_value
207
231
)
208
232
self .assert_segment_log (log , "modelIdSet" , None )
209
233
@@ -432,18 +456,45 @@ def _test_validate_ok(self, has_seat):
432
456
self .user .organization = Organization .objects .get_or_create (id = 123 )[0 ]
433
457
self .user .rh_user_has_seat = has_seat
434
458
mock_secret_manager = apps .get_app_config ("ai" ).get_wca_secret_manager ()
459
+ mock_wca_client : ModelPipelineCompletions = apps .get_app_config ("ai" ).get_model_pipeline (
460
+ ModelPipelineCompletions
461
+ )
435
462
self .client .force_authenticate (user = self .user )
436
463
437
- def mock_get_secret_model_id (* args , ** kwargs ):
464
+ api_key_value = "some_api_key_for_validate"
465
+ model_id_value = "model_id_for_validate"
466
+
467
+ def mock_get_secret_side_effect (* args , ** kwargs ):
438
468
if args [1 ] == Suffixes .API_KEY :
439
- return {"SecretString" : "some_api_key" }
440
- return {"SecretString" : "model_id" }
469
+ return {"SecretString" : api_key_value }
470
+ if args [1 ] == Suffixes .MODEL_ID :
471
+ return {"SecretString" : model_id_value }
472
+ return None
441
473
442
- mock_secret_manager .get_secret .side_effect = mock_get_secret_model_id
474
+ mock_secret_manager .get_secret .side_effect = mock_get_secret_side_effect
475
+
476
+ expected_headers = {
477
+ "Authorization" : f"Bearer { api_key_value } " ,
478
+ "X-Test-Header-Validate" : "true" ,
479
+ }
480
+ mock_wca_client .get_request_headers .return_value = expected_headers
481
+ mock_wca_client .infer_from_parameters .reset_mock (side_effect = True )
482
+ mock_wca_client .infer_from_parameters .side_effect = None
443
483
444
484
with self .assertLogs (logger = "root" , level = "DEBUG" ) as log :
445
485
r = self .client .get (self .api_version_reverse ("wca_model_id_validator" ))
446
486
self .assertEqual (r .status_code , HTTPStatus .OK )
487
+
488
+ mock_wca_client .get_request_headers .assert_called_once_with (
489
+ api_key = api_key_value , identifier = None , lightspeed_user_uuid = None
490
+ )
491
+ mock_wca_client .infer_from_parameters .assert_called_once_with (
492
+ model_id_value ,
493
+ "" ,
494
+ VALIDATE_PROMPT ,
495
+ user = None ,
496
+ headers = expected_headers ,
497
+ )
447
498
self .assert_segment_log (log , "modelIdValidate" , None )
448
499
449
500
@override_settings (SEGMENT_WRITE_KEY = "DUMMY_KEY_VALUE" )
0 commit comments