Skip to content

Commit e2fa900

Browse files
authored
To use openshift SA cert when required for SSL verfication for TLS enablement (#1751)
* to use openshift SA cert when provided for SSL verfication * add utc for the change * fix pre-commit * address review * fix review * fix ci
1 parent 18094e4 commit e2fa900

File tree

6 files changed

+1223
-6
lines changed

6 files changed

+1223
-6
lines changed

ansible_ai_connect/ai/api/model_pipelines/http/configuration.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,15 +48,18 @@ def __init__(
4848
verify_ssl: bool,
4949
stream: bool = False,
5050
mcp_servers: Optional[list[dict[str, str]]] = None,
51+
ca_cert_file: Optional[str] = None,
5152
):
5253
super().__init__(inference_url, model_id, timeout, enable_health_check)
5354
self.verify_ssl = verify_ssl
5455
self.stream = stream
5556
self.mcp_servers = mcp_servers or []
57+
self.ca_cert_file = ca_cert_file
5658

5759
verify_ssl: bool
5860
stream: bool
5961
mcp_servers: Optional[list[dict[str, str]]] = None
62+
ca_cert_file: Optional[str] = None
6063

6164

6265
@Register(api_type="http")
@@ -73,6 +76,7 @@ def __init__(self, **kwargs):
7376
verify_ssl=kwargs["verify_ssl"],
7477
stream=kwargs["stream"],
7578
mcp_servers=kwargs["mcp_servers"],
79+
ca_cert_file=kwargs.get("ca_cert_file"),
7680
),
7781
)
7882

@@ -81,6 +85,9 @@ def __init__(self, **kwargs):
8185
class HttpConfigurationSerializer(BaseConfigSerializer):
8286
verify_ssl = serializers.BooleanField(required=False, default=True)
8387
stream = serializers.BooleanField(required=False, default=False)
88+
ca_cert_file = serializers.CharField(
89+
required=False, default=None, allow_blank=True, allow_null=True
90+
)
8491
mcp_servers = serializers.ListSerializer(
8592
child=MCPServersSerializer(), allow_empty=True, required=False, default=None
8693
)

ansible_ai_connect/ai/api/model_pipelines/http/pipelines.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import copy
1616
import json
1717
import logging
18+
import ssl
1819
from json import JSONDecodeError
1920
from typing import Any, AsyncGenerator
2021

@@ -96,7 +97,9 @@ def invoke(self, params: CompletionsParameters) -> CompletionsResponse:
9697
headers=self.headers,
9798
json=model_input,
9899
timeout=self.task_gen_timeout(task_count),
99-
verify=self.config.verify_ssl,
100+
verify=(
101+
self.config.ca_cert_file if self.config.ca_cert_file else self.config.verify_ssl
102+
),
100103
)
101104
result.raise_for_status()
102105
response = json.loads(result.text)
@@ -114,7 +117,13 @@ def self_test(self) -> HealthCheckSummary:
114117
}
115118
)
116119
try:
117-
res = requests.get(url, verify=self.config.verify_ssl, timeout=1)
120+
res = requests.get(
121+
url,
122+
verify=(
123+
self.config.ca_cert_file if self.config.ca_cert_file else self.config.verify_ssl
124+
),
125+
timeout=1,
126+
)
118127
res.raise_for_status()
119128
except Exception as e:
120129
logger.exception(str(e))
@@ -146,7 +155,9 @@ def self_test(self) -> HealthCheckSummary:
146155
self.config.inference_url + "/readiness",
147156
headers=headers,
148157
timeout=1,
149-
verify=self.config.verify_ssl,
158+
verify=(
159+
self.config.ca_cert_file if self.config.ca_cert_file else self.config.verify_ssl
160+
),
150161
)
151162
r.raise_for_status()
152163

@@ -203,7 +214,7 @@ def invoke(self, params: ChatBotParameters) -> ChatBotResponse:
203214
headers=self.headers,
204215
json=data,
205216
timeout=self.task_gen_timeout(1),
206-
verify=self.config.verify_ssl,
217+
verify=self.config.ca_cert_file if self.config.ca_cert_file else self.config.verify_ssl,
207218
)
208219

209220
if response.status_code == 200:
@@ -267,8 +278,11 @@ def send_schema1_event(self, ev):
267278
async def async_invoke(self, params: StreamingChatBotParameters) -> AsyncGenerator:
268279

269280
# Configure SSL context based on verify_ssl setting
270-
ssl_context = self.config.verify_ssl
271-
connector = aiohttp.TCPConnector(ssl=ssl_context)
281+
if self.config.ca_cert_file:
282+
ssl_context = ssl.create_default_context(cafile=self.config.ca_cert_file)
283+
connector = aiohttp.TCPConnector(ssl=ssl_context)
284+
else:
285+
connector = aiohttp.TCPConnector(ssl=self.config.verify_ssl)
272286

273287
async with aiohttp.ClientSession(raise_for_status=True, connector=connector) as session:
274288
headers = {

0 commit comments

Comments
 (0)