Skip to content

Commit 156aa6b

Browse files
authored
Lint the results from role generation (#1513)
1 parent b2e04ee commit 156aa6b

File tree

10 files changed

+176
-11
lines changed

10 files changed

+176
-11
lines changed

ansible_ai_connect/ai/api/model_pipelines/dummy/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ def __init__(self, config: DummyConfiguration):
181181

182182
def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse:
183183
create_outline = params.create_outline
184-
return "install_nginx", ROLE_FILES, ROLE_OUTLINE.strip() if create_outline else ""
184+
return "install_nginx", ROLE_FILES, ROLE_OUTLINE.strip() if create_outline else "", []
185185

186186
def self_test(self) -> Optional[HealthCheckSummary]:
187187
raise NotImplementedError

ansible_ai_connect/ai/api/model_pipelines/langchain/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -374,7 +374,7 @@ def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse:
374374
else:
375375
files[1]["content"] = unwrap_message_with_yaml_answer(content)
376376

377-
return role, files, outline
377+
return role, files, outline, []
378378

379379
def self_test(self) -> Optional[HealthCheckSummary]:
380380
raise NotImplementedError

ansible_ai_connect/ai/api/model_pipelines/langchain/tests/test_pipeline.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -316,7 +316,7 @@ def fake_get_chat_mode(model_id=None):
316316
self.my_client.get_chat_model = fake_get_chat_mode
317317

318318
def test_generate_role(self):
319-
role, files, outline = self.my_client.invoke(
319+
role, files, outline, warnings = self.my_client.invoke(
320320
RoleGenerationParameters.init(
321321
request=Mock(),
322322
text="foo",
@@ -326,9 +326,10 @@ def test_generate_role(self):
326326
self.assertEqual(role, "vpc_subnet_ec2")
327327
self.assertEqual(files[1]["content"], TestUnwrapRoleAnswer._expected_second_request_content)
328328
self.assertEqual(outline, "")
329+
self.assertEqual(warnings, [])
329330

330331
def test_generate_role_with_outline(self):
331-
role, files, outline = self.my_client.invoke(
332+
role, files, outline, warnings = self.my_client.invoke(
332333
RoleGenerationParameters.init(
333334
request=Mock(),
334335
text="foo",
@@ -343,6 +344,7 @@ def test_generate_role_with_outline(self):
343344
2. Display EC2 Instance ID
344345
""",
345346
)
347+
self.assertEqual(warnings, [])
346348

347349

348350
class TestLangChainPlaybookExplanationPipeline(WisdomServiceLogAwareTestCase):

ansible_ai_connect/ai/api/model_pipelines/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def init(
152152
)
153153

154154

155-
RoleGenerationResponse = tuple[str, list, str]
155+
RoleGenerationResponse = tuple[str, list, str, list]
156156

157157

158158
@define

ansible_ai_connect/ai/api/model_pipelines/tests/test_wca_client.py

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
ContentMatchParameters,
5454
PlaybookExplanationParameters,
5555
PlaybookGenerationParameters,
56+
RoleGenerationParameters,
5657
)
5758
from ansible_ai_connect.ai.api.model_pipelines.tests import mock_pipeline_config
5859
from ansible_ai_connect.ai.api.model_pipelines.wca.pipelines_base import (
@@ -63,6 +64,7 @@
6364
wca_codegen_playbook_hist,
6465
wca_codegen_playbook_retry_counter,
6566
wca_codegen_retry_counter,
67+
wca_codegen_role_hist,
6668
wca_codematch_hist,
6769
wca_codematch_retry_counter,
6870
wca_explain_playbook_hist,
@@ -77,6 +79,7 @@
7779
WCASaaSContentMatchPipeline,
7880
WCASaaSPlaybookExplanationPipeline,
7981
WCASaaSPlaybookGenerationPipeline,
82+
WCASaaSRoleGenerationPipeline,
8083
)
8184
from ansible_ai_connect.test_utils import (
8285
WisdomAppsBackendMocking,
@@ -335,7 +338,7 @@ def test_get_api_key_with_wca_configured(self):
335338

336339
@override_settings(WCA_SECRET_BACKEND_TYPE="dummy")
337340
@override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=False)
338-
class TestWCAClientGeneration(WisdomAppsBackendMocking, WisdomServiceLogAwareTestCase):
341+
class TestWCAClientPlaybookGeneration(WisdomAppsBackendMocking, WisdomServiceLogAwareTestCase):
339342
def setUp(self):
340343
super().setUp()
341344
wca_client = WCASaaSPlaybookGenerationPipeline(
@@ -510,6 +513,64 @@ def test_playbook_gen_request_id_correlation_failure(self):
510513
)
511514

512515

516+
@override_settings(WCA_SECRET_BACKEND_TYPE="dummy")
517+
@override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=False)
518+
class TestWCAClientRoleGeneration(WisdomAppsBackendMocking, WisdomServiceLogAwareTestCase):
519+
def setUp(self):
520+
super().setUp()
521+
wca_client = WCASaaSRoleGenerationPipeline(
522+
mock_pipeline_config("wca", api_key=None, model_id=None)
523+
)
524+
wca_client.get_api_key = Mock(return_value="some-key")
525+
wca_client.get_token = Mock(return_value={"access_token": "a-token"})
526+
wca_client.get_model_id = Mock(return_value="a-random-model")
527+
wca_client.session = Mock()
528+
response = Mock
529+
response.text = (
530+
'{"name": "foo_bar", "outline": "Ahh!", "files": [{"path": '
531+
'"roles/foo_bar/tasks/main.yml", "content": "some content", '
532+
'"file_type": "task"}, {"path": "roles/foo_bar/defaults/main.yml", '
533+
'"content": "some content", "file_type": "default"}], "warnings": []}'
534+
)
535+
response.status_code = 200
536+
response.headers = {WCA_REQUEST_ID_HEADER: WCA_REQUEST_ID_HEADER}
537+
response.raise_for_status = Mock()
538+
wca_client.session.post.return_value = response
539+
self.wca_client = wca_client
540+
541+
@assert_call_count_metrics(metric=wca_codegen_role_hist)
542+
@override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=True)
543+
def test_role_gen_with_lint(self):
544+
fake_linter = Mock()
545+
fake_linter.run_linter.return_value = "I'm super fake!"
546+
self.mock_ansible_lint_caller_with(fake_linter)
547+
name, files, outline, warnings = self.wca_client.invoke(
548+
RoleGenerationParameters.init(
549+
request=Mock(), text="Install Wordpress", create_outline=True
550+
)
551+
)
552+
self.assertEqual(name, "foo_bar")
553+
self.assertEqual(outline, "Ahh!")
554+
self.assertEqual(warnings, [])
555+
for file in files:
556+
self.assertEqual(file["content"], "I'm super fake!")
557+
558+
@assert_call_count_metrics(metric=wca_codegen_role_hist)
559+
@override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=True)
560+
def test_role_gen_when_is_not_initialized(self):
561+
self.mock_ansible_lint_caller_with(None)
562+
name, files, outline, warnings = self.wca_client.invoke(
563+
RoleGenerationParameters.init(
564+
request=Mock(), text="Install Wordpress", create_outline=True
565+
)
566+
)
567+
self.assertEqual(name, "foo_bar")
568+
self.assertEqual(outline, "Ahh!")
569+
self.assertEqual(warnings, [])
570+
for file in files:
571+
self.assertEqual(file["content"], "some content")
572+
573+
513574
@override_settings(WCA_SECRET_BACKEND_TYPE="dummy")
514575
@override_settings(ENABLE_ANSIBLE_LINT_POSTPROCESS=False)
515576
class TestWCAClientExplanation(WisdomAppsBackendMocking, WisdomServiceLogAwareTestCase):

ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_base.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,12 @@
120120
namespace=NAMESPACE,
121121
buckets=DEFAULT_LATENCY_BUCKETS,
122122
)
123+
wca_codegen_role_hist = Histogram(
124+
"wca_codegen_role_latency_seconds",
125+
"Histogram of WCA codegen-role API processing time",
126+
namespace=NAMESPACE,
127+
buckets=DEFAULT_LATENCY_BUCKETS,
128+
)
123129
wca_explain_playbook_hist = Histogram(
124130
"wca_explain_playbook_latency_seconds",
125131
"Histogram of WCA explain-playbook API processing time",
@@ -147,6 +153,11 @@
147153
"Counter of WCA codegen-playbook API invocation retries",
148154
namespace=NAMESPACE,
149155
)
156+
wca_codegen_role_retry_counter = Counter(
157+
"wca_codegen_role_retries",
158+
"Counter of WCA codegen-role API invocation retries",
159+
namespace=NAMESPACE,
160+
)
150161
wca_explain_playbook_retry_counter = Counter(
151162
"wca_explain_playbook_retries",
152163
"Counter of WCA explain-playbook API invocation retries",
@@ -241,6 +252,11 @@ def on_backoff_codegen_playbook(details):
241252
WCABasePipeline.log_backoff_exception(details)
242253
wca_codegen_playbook_retry_counter.inc()
243254

255+
@staticmethod
256+
def on_backoff_codegen_role(details):
257+
WCABasePipeline.log_backoff_exception(details)
258+
wca_codegen_role_retry_counter.inc()
259+
244260
@staticmethod
245261
def on_backoff_explain_playbook(details):
246262
WCABasePipeline.log_backoff_exception(details)

ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_dummy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ def __init__(self, config: WCADummyConfiguration):
9898
super().__init__(config=config)
9999

100100
def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse:
101-
return "wca_dummy_role", [], "wca_dummy_outline"
101+
return "wca_dummy_role", [], "wca_dummy_outline", []
102102

103103
def self_test(self) -> Optional[HealthCheckSummary]:
104104
raise NotImplementedError

ansible_ai_connect/ai/api/model_pipelines/wca/pipelines_saas.py

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import json
1516
import logging
1617
from abc import ABCMeta
17-
from typing import TYPE_CHECKING, Generic, Optional
18+
from typing import TYPE_CHECKING, Generic, Optional, cast
1819

1920
import backoff
2021
from django.apps import apps
@@ -32,6 +33,7 @@
3233
WcaKeyNotFound,
3334
WcaModelIdNotFound,
3435
WcaNoDefaultModelId,
36+
WcaRequestIdCorrelationFailure,
3537
WcaTokenFailure,
3638
)
3739
from ansible_ai_connect.ai.api.model_pipelines.pipelines import (
@@ -65,8 +67,11 @@
6567
WcaModelRequestException,
6668
WcaTokenRequestException,
6769
ibm_cloud_identity_token_hist,
70+
wca_codegen_role_hist,
6871
)
6972
from ansible_ai_connect.ai.api.model_pipelines.wca.wca_utils import (
73+
Context,
74+
InferenceResponseChecks,
7075
TokenContext,
7176
TokenResponseChecks,
7277
)
@@ -336,6 +341,73 @@ class WCASaaSRoleGenerationPipeline(
336341
def __init__(self, config: WCASaaSConfiguration):
337342
super().__init__(config=config)
338343

344+
# This should be moved to the base WCA class when it becomes available on-prem
345+
def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse:
346+
request = params.request
347+
text = params.text
348+
create_outline = params.create_outline
349+
outline = params.outline
350+
model_id = params.model_id
351+
generation_id = params.generation_id
352+
353+
organization_id = request.user.organization.id if request.user.organization else None
354+
api_key = self.get_api_key(request.user, organization_id)
355+
model_id = self.get_model_id(request.user, organization_id, model_id)
356+
357+
headers = self.get_request_headers(api_key, generation_id)
358+
data = {
359+
"model_id": model_id,
360+
"text": text,
361+
"create_outline": create_outline,
362+
}
363+
if outline:
364+
data["outline"] = outline
365+
366+
@backoff.on_exception(
367+
backoff.expo,
368+
Exception,
369+
max_tries=self.retries + 1,
370+
giveup=self.fatal_exception,
371+
on_backoff=self.on_backoff_codegen_role,
372+
)
373+
@wca_codegen_role_hist.time()
374+
def post_request():
375+
return self.session.post(
376+
f"{self.config.inference_url}/v1/wca/codegen/ansible/roles",
377+
headers=headers,
378+
json=data,
379+
verify=self.config.verify_ssl,
380+
)
381+
382+
result = post_request()
383+
384+
x_request_id = result.headers.get(WCA_REQUEST_ID_HEADER)
385+
if generation_id and x_request_id:
386+
# request/payload suggestion_id is a UUID not a string whereas
387+
# HTTP headers are strings.
388+
if x_request_id != str(generation_id):
389+
raise WcaRequestIdCorrelationFailure(model_id=model_id, x_request_id=x_request_id)
390+
391+
context = Context(model_id, result, False)
392+
InferenceResponseChecks().run_checks(context)
393+
result.raise_for_status()
394+
395+
response = json.loads(result.text)
396+
397+
name = response["name"]
398+
files = response["files"]
399+
outline = response["outline"]
400+
warnings = response["warnings"] if "warnings" in response else []
401+
402+
from ansible_ai_connect.ai.apps import AiConfig
403+
404+
ai_config = cast(AiConfig, apps.get_app_config("ai"))
405+
if ansible_lint_caller := ai_config.get_ansible_lint_caller():
406+
for file in files:
407+
file["content"] = ansible_lint_caller.run_linter(file["content"])
408+
409+
return name, files, outline, warnings
410+
339411
def self_test(self) -> Optional[HealthCheckSummary]:
340412
raise NotImplementedError
341413

ansible_ai_connect/ai/api/tests/test_role_generation_view.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,14 +47,26 @@ def __init__(self):
4747

4848
class MockedPipelineRoleGeneration(ModelPipelineRoleGeneration[MockedConfig]):
4949

50-
def __init__(self, response_roles: str, response_files: list, response_outline: str):
50+
def __init__(
51+
self,
52+
response_roles: str,
53+
response_files: list,
54+
response_outline: str,
55+
response_warnings: list,
56+
):
5157
super().__init__(MockedConfig())
5258
self.response_roles = response_roles
5359
self.response_files = response_files
5460
self.response_outline = response_outline
61+
self.response_warnings = response_warnings
5562

5663
def invoke(self, params: RoleGenerationParameters) -> RoleGenerationResponse:
57-
return self.response_roles, self.response_files, self.response_outline
64+
return (
65+
self.response_roles,
66+
self.response_files,
67+
self.response_outline,
68+
self.response_warnings,
69+
)
5870

5971
def self_test(self) -> Optional[HealthCheckSummary]:
6072
raise NotImplementedError
@@ -127,6 +139,7 @@ def test_with_anonymized_response(self):
127139
}
128140
],
129141
"Install mysql and email [email protected]",
142+
[],
130143
)
131144
),
132145
):

ansible_ai_connect/ai/api/views.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -933,7 +933,7 @@ def post(self, request) -> Response:
933933
ModelPipelineRoleGeneration
934934
)
935935

936-
roles, files, outline = llm.invoke(
936+
roles, files, outline, warnings = llm.invoke(
937937
RoleGenerationParameters.init(
938938
request=request,
939939
text=self.validated_data["text"],
@@ -964,6 +964,7 @@ def post(self, request) -> Response:
964964
"files": anonymized_files,
965965
"format": "plaintext",
966966
"generationId": self.validated_data["generationId"],
967+
"warnings": warnings,
967968
}
968969

969970
return Response(

0 commit comments

Comments
 (0)