
Commit aab6c21

Reapply "Revert 'Adds X-Request-LightspeedUser to WCA requests"
Adds more healthcheck request testing. Removes api_key from infer_from_parameters.
1 parent ee5ea55 commit aab6c21

15 files changed: +110 additions, −26 deletions

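In short: this commit moves header construction out of infer_from_parameters. Callers now resolve the API key, build the request headers (including the reinstated X-Request-LightspeedUser value) up front, and pass them in explicitly; the api_key parameter is removed from every pipeline. A minimal before/after sketch of a call site (the pipeline and argument names are illustrative, not taken verbatim from the diff):

    # Before: infer_from_parameters received the API key and built headers itself.
    result = pipeline.infer_from_parameters(api_key, model_id, context, prompt, suggestion_id)

    # After: headers are prepared first (deriving the user UUID when present),
    # then passed explicitly; api_key no longer appears in the signature.
    headers = pipeline._prepare_request_headers(request.user, api_key, suggestion_id)
    result = pipeline.infer_from_parameters(model_id, context, prompt, suggestion_id, headers)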

ansible_ai_connect/ai/api/model_pipelines/dummy/pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def invoke(self, params: CompletionsParameters) -> CompletionsResponse:
         response_body["model_id"] = "_"
         return response_body
 
-    def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_id=None):
+    def infer_from_parameters(self, model_id, context, prompt, suggestion_id=None, headers=None):
         raise NotImplementedError
 
     def self_test(self) -> HealthCheckSummary:

ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -122,7 +122,7 @@ def self_test(self) -> HealthCheckSummary:
         )
         return summary
 
-    def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_id=None):
+    def infer_from_parameters(self, model_id, context, prompt, suggestion_id=None, headers=None):
         raise NotImplementedError
 

ansible_ai_connect/ai/api/model_pipelines/langchain/pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -233,7 +233,7 @@ def self_test(self) -> HealthCheckSummary:
     def get_chat_model(self, model_id):
         raise NotImplementedError
 
-    def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_id=None):
+    def infer_from_parameters(self, model_id, context, prompt, suggestion_id=None, headers=None):
         raise NotImplementedError
 

ansible_ai_connect/ai/api/model_pipelines/llamacpp/pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -126,7 +126,7 @@ def invoke(self, params: CompletionsParameters) -> CompletionsResponse:
         except requests.exceptions.Timeout:
             raise ModelTimeoutError
 
-    def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_id=None):
+    def infer_from_parameters(self, model_id, context, prompt, suggestion_id=None, headers=None):
         raise NotImplementedError
 
     def self_test(self) -> HealthCheckSummary:

ansible_ai_connect/ai/api/model_pipelines/nop/pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -64,7 +64,7 @@ def __init__(self, config: NopConfiguration):
     def invoke(self, params: CompletionsParameters) -> CompletionsResponse:
         raise FeatureNotAvailable
 
-    def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_id=None):
+    def infer_from_parameters(self, model_id, context, prompt, suggestion_id=None, headers=None):
         raise NotImplementedError
 
     def self_test(self) -> HealthCheckSummary:

ansible_ai_connect/ai/api/model_pipelines/ollama/pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ class OllamaCompletionsPipeline(LangchainCompletionsPipeline[OllamaConfiguration
     def __init__(self, config: OllamaConfiguration):
         super().__init__(config=config)
 
-    def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_id=None):
+    def infer_from_parameters(self, model_id, context, prompt, suggestion_id=None, headers=None):
         raise NotImplementedError
 
     def self_test(self) -> HealthCheckSummary:

ansible_ai_connect/ai/api/model_pipelines/pipelines.py

Lines changed: 1 addition & 1 deletion
@@ -335,7 +335,7 @@ def alias():
         return "model-server"
 
     @abstractmethod
-    def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_id=None):
+    def infer_from_parameters(self, model_id, context, prompt, suggestion_id=None, headers=None):
         raise NotImplementedError
 

ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py

Lines changed: 20 additions & 4 deletions
@@ -58,6 +58,7 @@
 from ansible_ai_connect.ai.api.model_pipelines.tests import mock_pipeline_config
 from ansible_ai_connect.ai.api.model_pipelines.wca.pipelines_base import (
     WCA_REQUEST_ID_HEADER,
+    WCA_REQUEST_USER_UUID_HEADER,
     ibm_cloud_identity_token_hist,
     ibm_cloud_identity_token_retry_counter,
     wca_codegen_hist,
@@ -846,6 +847,7 @@ def _do_inference(
     ):
         model_id = "zavala"
         api_key = "abc123"
+        user_uuid = str(uuid.uuid4())
         context = ""
         prompt = prompt if prompt else "- name: install ffmpeg on Red Hat Enterprise Linux"
 
@@ -874,13 +876,14 @@ def _do_inference(
         response = MockResponse(
             json=predictions,
             status_code=200,
-            headers={WCA_REQUEST_ID_HEADER: request_id},
+            headers={WCA_REQUEST_ID_HEADER: request_id, WCA_REQUEST_USER_UUID_HEADER: user_uuid},
         )
 
         requestHeaders = {
             "Content-Type": "application/json",
             "Authorization": f"Bearer {token['access_token']}",
             WCA_REQUEST_ID_HEADER: suggestion_id,
+            WCA_REQUEST_USER_UUID_HEADER: user_uuid,
         }
 
         model_client = WCASaaSCompletionsPipeline(self.config)
@@ -889,9 +892,13 @@ def _do_inference(
         model_client.get_model_id = Mock(return_value=model_id)
         model_client.get_api_key = Mock(return_value=api_key)
 
+        mock_request = Mock()
+        mock_request.user = Mock()
+        mock_request.user.uuid = user_uuid
+
         result = model_client.invoke(
             CompletionsParameters.init(
-                request=Mock(),
+                request=mock_request,
                 model_input=model_input,
                 model_id=model_id,
                 suggestion_id=suggestion_id,
@@ -1485,6 +1492,7 @@ def test_get_model_id_without_setting(self):
 class TestWCAOnPremCodegen(WisdomServiceLogAwareTestCase):
     prompt = "- name: install ffmpeg on Red Hat Enterprise Linux"
     suggestion_id = "suggestion_id"
+    user_uuid = str(uuid.uuid4())
     token = base64.b64encode(bytes("username:12345", "ascii")).decode("ascii")
     codegen_data = {
         "model_id": "model-name",
@@ -1493,6 +1501,7 @@ class TestWCAOnPremCodegen(WisdomServiceLogAwareTestCase):
     request_headers = {
         "Authorization": f"ZenApiKey {token}",
         WCA_REQUEST_ID_HEADER: suggestion_id,
+        WCA_REQUEST_USER_UUID_HEADER: user_uuid,
     }
     model_input = {
         "instances": [
@@ -1519,9 +1528,13 @@ def setUp(self):
         self.model_client.session.post = Mock(return_value=MockResponse(json={}, status_code=200))
 
     def test_headers(self):
+        mock_request = Mock()
+        mock_request.user = Mock()
+        mock_request.user.uuid = self.user_uuid
+
         self.model_client.invoke(
             CompletionsParameters.init(
-                request=Mock(), model_input=self.model_input, suggestion_id=self.suggestion_id
+                request=mock_request, model_input=self.model_input, suggestion_id=self.suggestion_id
             ),
         )
         self.model_client.session.post.assert_called_once_with(
@@ -1533,10 +1546,13 @@ def test_headers(self):
         )
 
     def test_disabled_model_server_ssl(self):
+        mock_request = Mock()
+        mock_request.user = Mock()
+        mock_request.user.uuid = self.user_uuid
         self.config.verify_ssl = False
         self.model_client.invoke(
             CompletionsParameters.init(
-                request=Mock(), model_input=self.model_input, suggestion_id=self.suggestion_id
+                request=mock_request, model_input=self.model_input, suggestion_id=self.suggestion_id
             ),
         )
         self.model_client.session.post.assert_called_once_with(
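A side note on the test changes: the three-line mock setup repeated above can be collapsed into one expression with identical unittest.mock semantics, since keyword arguments to Mock() become attributes:

    mock_request = Mock(user=Mock(uuid=self.user_uuid))

Either form gives request.user.uuid the fixed UUID the header assertions expect.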

ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_base.py

Lines changed: 30 additions & 8 deletions
@@ -81,6 +81,8 @@
 
 WCA_REQUEST_ID_HEADER = "X-Request-ID"
 
+WCA_REQUEST_USER_UUID_HEADER = "X-Request-LightspeedUser"
+
 # from django_prometheus.middleware.DEFAULT_LATENCY_BUCKETS
 DEFAULT_LATENCY_BUCKETS = (
     0.01,
@@ -243,6 +245,20 @@ class WCABasePipeline(
     def __init__(self, config: WCA_PIPELINE_CONFIGURATION):
         super().__init__(config=config)
 
+    def _prepare_request_headers(
+        self, request_user: Optional[User], api_key: str, identifier: Optional[str]
+    ) -> dict[str, Optional[str]]:
+        """
+        Helper method to extract user UUID and get request headers.
+        """
+        lightspeed_user_uuid_str: Optional[str] = None
+        if request_user and hasattr(request_user, "uuid"):
+            lightspeed_user_uuid_str = str(request_user.uuid)
+
+        return self.get_request_headers(
+            api_key, identifier, lightspeed_user_uuid=lightspeed_user_uuid_str
+        )
+
     @staticmethod
     def log_backoff_exception(details):
         _, exc, _ = sys.exc_info()
@@ -284,7 +300,7 @@ def on_backoff_explain_role(details):
 
     @abstractmethod
     def get_request_headers(
-        self, api_key: str, identifier: Optional[str]
+        self, api_key: str, identifier: Optional[str], lightspeed_user_uuid: Optional[str] = None
     ) -> dict[str, Optional[str]]:
         raise NotImplementedError
 
@@ -318,7 +334,10 @@ def invoke(self, params: CompletionsParameters) -> CompletionsResponse:
         try:
             api_key = self.get_api_key(request.user)
             model_id = self.get_model_id(request.user, model_id)
-            result = self.infer_from_parameters(api_key, model_id, context, prompt, suggestion_id)
+
+            headers = self._prepare_request_headers(request.user, api_key, suggestion_id)
+
+            result = self.infer_from_parameters(model_id, context, prompt, suggestion_id, headers)
 
             response = result.json()
             response["model_id"] = model_id
@@ -328,14 +347,13 @@ def invoke(self, params: CompletionsParameters) -> CompletionsResponse:
         except requests.exceptions.Timeout:
             raise ModelTimeoutError(model_id=model_id)
 
-    def infer_from_parameters(self, api_key, model_id, context, prompt, suggestion_id=None):
+    def infer_from_parameters(self, model_id, context, prompt, suggestion_id=None, headers=None):
         data = {
             "model_id": model_id,
             "prompt": f"{context}{prompt}",
         }
         logger.debug(f"Inference API request payload: {json.dumps(data)}")
 
-        headers = self.get_request_headers(api_key, suggestion_id)
         task_count = len(get_task_names_from_prompt(prompt))
         prediction_url = f"{self.config.inference_url}/v1/wca/codegen/ansible"
 
@@ -471,7 +489,8 @@ def invoke(self, params: PlaybookGenerationParameters) -> PlaybookGenerationResp
         api_key = self.get_api_key(request.user)
         model_id = self.get_model_id(request.user, model_id)
 
-        headers = self.get_request_headers(api_key, generation_id)
+        headers = self._prepare_request_headers(request.user, api_key, generation_id)
+
         data = {
             "model_id": model_id,
             "text": text,
@@ -553,7 +572,8 @@ def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse:
         api_key = self.get_api_key(request.user)
         model_id = self.get_model_id(request.user, model_id)
 
-        headers = self.get_request_headers(api_key, generation_id)
+        headers = self._prepare_request_headers(request.user, api_key, generation_id)
+
         data = {
             "model_id": model_id,
             "text": text,
@@ -634,7 +654,8 @@ def invoke(self, params: PlaybookExplanationParameters) -> PlaybookExplanationRe
         api_key = self.get_api_key(request.user)
         model_id = self.get_model_id(request.user, model_id)
 
-        headers = self.get_request_headers(api_key, explanation_id)
+        headers = self._prepare_request_headers(request.user, api_key, explanation_id)
+
         data = {
             "model_id": model_id,
             "playbook": content,
@@ -696,7 +717,8 @@ def invoke(self, params: RoleExplanationParameters) -> RoleExplanationResponse:
         api_key = self.get_api_key(request.user)
         model_id = self.get_model_id(request.user, model_id)
 
-        headers = self.get_request_headers(api_key, explanation_id)
+        headers = self._prepare_request_headers(request.user, api_key, explanation_id)
+
         data = {
             "role_name": params.role_name,
             "model_id": model_id,

ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_onprem.py

Lines changed: 7 additions & 2 deletions
@@ -47,6 +47,7 @@
 )
 from ansible_ai_connect.ai.api.model_pipelines.wca.pipelines_base import (
     WCA_REQUEST_ID_HEADER,
+    WCA_REQUEST_USER_UUID_HEADER,
     WCABaseCompletionsPipeline,
     WCABaseContentMatchPipeline,
     WCABaseMetaData,
@@ -114,12 +115,13 @@ def __init__(self, config: WCAOnPremConfiguration):
     # User may provide an override value if the setting is not defined.
 
     def get_request_headers(
-        self, api_key: str, identifier: Optional[str]
+        self, api_key: str, identifier: Optional[str], lightspeed_user_uuid: Optional[str] = None
     ) -> dict[str, Optional[str]]:
         base_headers = self._get_base_headers(api_key)
         return {
             **base_headers,
             WCA_REQUEST_ID_HEADER: str(identifier) if identifier else None,
+            WCA_REQUEST_USER_UUID_HEADER: lightspeed_user_uuid if lightspeed_user_uuid else None,
         }
 
     def _get_base_headers(self, api_key: str) -> dict[str, str]:
@@ -150,11 +152,14 @@ def self_test(self) -> HealthCheckSummary:
             }
         )
         try:
+            headers = self.get_request_headers(wca_api_key, None)
+
             self.infer_from_parameters(
-                wca_api_key,
                 wca_model_id,
                 "",
                 "- name: install ffmpeg on Red Hat Enterprise Linux",
+                None,
+                headers,
             )
         except Exception as e:
             logger.exception(str(e))
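Putting the on-prem pieces together, here is a condensed reconstruction of the header dict this implementation now produces. It is a sketch, not the project's code: the ZenApiKey Authorization format is inferred from the request_headers test fixture above, and _get_base_headers is collapsed into a literal:

    WCA_REQUEST_ID_HEADER = "X-Request-ID"
    WCA_REQUEST_USER_UUID_HEADER = "X-Request-LightspeedUser"

    def build_onprem_headers(api_key, identifier=None, lightspeed_user_uuid=None):
        # Condensed sketch of the on-prem get_request_headers; the real code
        # obtains base_headers from self._get_base_headers(api_key).
        base_headers = {"Authorization": f"ZenApiKey {api_key}"}
        return {
            **base_headers,
            WCA_REQUEST_ID_HEADER: str(identifier) if identifier else None,
            WCA_REQUEST_USER_UUID_HEADER: lightspeed_user_uuid if lightspeed_user_uuid else None,
        }

    print(build_onprem_headers("dG9rZW4=", "suggestion-1"))
    # {'Authorization': 'ZenApiKey dG9rZW4=', 'X-Request-ID': 'suggestion-1',
    #  'X-Request-LightspeedUser': None}

Note that the self_test path above calls get_request_headers(wca_api_key, None) with no user UUID, so a health-check request carries None for both tracking headers.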
