Skip to content

Commit c294594

Browse files
committed
added ideal template, improved robustness
1 parent 2e02bf7 commit c294594

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

tools/aaz-flow/prompt_templates.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,32 @@
55

66
"""Prompt templates and static guidance for AAZ Flow tools."""
77

8+
IDEAL_STYLE = """
9+
@ResourceGroupPreparer(name_prefix="cli_test_ag_with_non_v2_sku_", location="westus")
10+
def test_ag_with_non_v2_sku(self):
11+
self.kwargs.update({
12+
"ag_name": self.create_random_name("ag-", 12),
13+
"port_name": self.create_random_name("port-", 12),
14+
"lisener_name": self.create_random_name("lisener-", 12),
15+
"rule_name": self.create_random_name("rule-", 12),
16+
})
17+
18+
self.cmd("network application-gateway create -n {ag_name} -g {rg} --sku WAF_Medium")
19+
20+
self.kwargs["front_ip"] = self.cmd("network application-gateway show -n {ag_name} -g {rg}").get_output_in_json()["frontendIPConfigurations"][0]["name"]
21+
self.cmd("network application-gateway frontend-port create -n {port_name} -g {rg} --gateway-name {ag_name} --port 8080")
22+
self.cmd("network application-gateway http-listener create -n {lisener_name} -g {rg} --gateway-name {ag_name} --frontend-ip {front_ip} --frontend-port {port_name}")
23+
24+
self.cmd(
25+
"network application-gateway rule create -n {rule_name} -g {rg} --gateway-name {ag_name} --http-listener {lisener_name}",
26+
checks=[
27+
self.check("name", "{ag_name}"),
28+
self.check("sku.tier", "WAF")
29+
]
30+
)
31+
32+
self.cmd("network application-gateway delete -n {ag_name} -g {rg}")
33+
"""
834

935
def get_testgen_static_instructions() -> str:
1036
return (
@@ -13,13 +39,10 @@ def get_testgen_static_instructions() -> str:
1339
"Generate tests that achieve at least 80%% coverage of methods and parameters covering primary commands for the target module.\n"
1440
"To understand the primary commands that need to be tested, read through and understand the target module's generated AAZ commands.\n"
1541
"Constraints: \n"
16-
"- Include necessary imports: azure.cli.testsdk imports and others only as required and seen in reference.\n"
17-
"- Use self.kwargs for dynamic values.\n"
18-
"- Use ResourceGroupPreparer if a resource group is implied.\n"
19-
"- Add minimal checks (e.g., self.check) where sensible.\n"
2042
"- Keep tests safe-by-default; avoid destructive operations unless clearly required.\n"
2143
"- Ensure tests can run in parallel without conflicts.\n"
2244
"- If tests are large and can be safely and logically split, create multiple test methods (i.e. avoid a single CRUD test if possible, split it into multiple tests if logically and safely separable).\n"
45+
"- It is highly preferred that all CRUD operations are not coupled in a single test.\n"
2346
"- Output only valid Python code for the test file, nothing else."
2447
)
2548

tools/aaz-flow/testgen.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
import asyncio
99
import random
1010
from fastmcp import Context
11-
from prompt_templates import get_testgen_static_instructions, REF_STYLE_LABEL
11+
from prompt_templates import get_testgen_static_instructions, REF_STYLE_LABEL, IDEAL_STYLE
1212

1313

1414
async def check_module_status(ctx: Context, paths: dict):
@@ -120,6 +120,8 @@ def build_testgen_prompt(
120120

121121
if reference_snippet:
122122
parts.append(REF_STYLE_LABEL + reference_snippet)
123+
124+
parts.append("Here is the ideal test style example to follow:\n" + IDEAL_STYLE)
123125

124126
return "\n\n".join(parts)
125127

@@ -147,6 +149,11 @@ def strip_code_fences(text: str) -> str:
147149
return max(blocks, key=len)
148150
return text.strip()
149151

152+
def strip_shebang(text: str) -> str:
153+
lines = text.strip().splitlines()
154+
if lines and lines[0].startswith("#!"):
155+
return "\n".join(lines[1:]).strip()
156+
return text.strip()
150157

151158
async def generate_tests(ctx: Context, paths: dict):
152159
module_name, commands, test_file = await check_path_status(ctx, paths)
@@ -171,7 +178,6 @@ async def generate_tests(ctx: Context, paths: dict):
171178
sampling_prompt = build_testgen_prompt(
172179
module_name, commands, reference_snippet, extracted_examples
173180
)
174-
await ctx.info("Constructed test generation prompt as follows:\n" + sampling_prompt)
175181
max_retries = int(os.getenv("TESTGEN_RETRIES", "5"))
176182
base_delay = float(os.getenv("TESTGEN_RETRY_BASE_DELAY", "2"))
177183

@@ -186,7 +192,7 @@ async def generate_tests(ctx: Context, paths: dict):
186192
)
187193
sampled = await ctx.sample(sampling_prompt)
188194
raw_content = (getattr(sampled, "text", "") or "").strip()
189-
content = strip_code_fences(raw_content)
195+
content = strip_shebang(strip_code_fences(raw_content))
190196
if content:
191197
break
192198
last_err = RuntimeError("Empty content returned from provider")

0 commit comments

Comments
 (0)