Skip to content

Commit 8480b0b

Browse files
committed
priliminary test generation
1 parent 1fedeb5 commit 8480b0b

File tree

4 files changed

+206
-54
lines changed

4 files changed

+206
-54
lines changed

tools/aaz-flow/helpers.py

Lines changed: 18 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ async def get_name(ctx: Context) -> str:
9090
"This list is fetched directly from the Azure REST API Specs repository. " \
9191
"Ask the user in a professional manner to select a module/extension from the list. " \
9292
"The list is provided when they click on the Respond button so do not give them any options in the questions itself. " \
93-
"The result of this option selection will determine which module's code will be generated using AAZ."
93+
"The result of this option selection will determine which module's code will be generated for Azure CLI."
9494
)
9595
extension_choice = await ctx.elicit(
9696
message=choice_prompt.text,
@@ -142,7 +142,6 @@ async def browse_specs(ctx: Context, base_path: str):
142142
dirs = [e for e in entries if os.path.isdir(os.path.join(current_path, e))]
143143
files = [e for e in entries if os.path.isfile(os.path.join(current_path, e)) and e.endswith((".json", ".yaml", ".yml"))]
144144

145-
# Labels shown to the user vs actual values
146145
labels = [".."] + [f"> {d}" for d in dirs] + files
147146
mapping = dict(zip(labels, [".."] + dirs + files))
148147

@@ -164,11 +163,9 @@ async def browse_specs(ctx: Context, base_path: str):
164163
current_path = os.path.join(current_path, selected)
165164
await ctx.info(f"az_cli : Entered directory: {selected}")
166165
else:
167-
# A spec file was chosen
168166
selected_file_path = os.path.join(current_path, selected)
169167
await ctx.info(f"az_cli : Selected spec file: {selected_file_path}")
170168

171-
# Relative path for extracting metadata
172169
rel_path = os.path.relpath(selected_file_path, base_path)
173170
parts = rel_path.split(os.sep)
174171

@@ -282,59 +279,30 @@ async def execute_commands(ctx: Context, paths: dict, request: AAZRequest):
282279

283280
cmd2 = (
284281
f"{aaz_dev} cli generate-by-swagger-tag "
285-
f"--aaz-path {paths['aaz']} "
286-
f"--cli-path {paths['cli']} "
287-
f"--cli-extension-path {paths['cli_extension']} "
288-
f"--extension-or-module-name {request.name} "
289-
f"--swagger-module-path {request.swagger_module_path} "
290-
f"--resource-provider {request.resource_provider} "
291-
f"--swagger-tag {request.swagger_tag} "
282+
f"-a {paths['aaz']} "
283+
f"-e {paths['cli_extension']} "
284+
f"--name {request.name} "
285+
f"--sm {request.swagger_module_path} "
286+
f"--rp {request.resource_provider} "
287+
f"--tag {request.swagger_tag} "
292288
f"--profile latest"
293289
)
294290

295291
try:
296292
await run_command(ctx, cmd1, "Generate command model from Swagger", 50, 80)
293+
generated_init_path = f"{paths['cli']}/src/azure-cli/azure/cli/command_modules/{request.name}/aaz/__init__.py"
294+
max_wait = 30
295+
waited = 0
296+
while not os.path.exists(generated_init_path) and waited < max_wait:
297+
await ctx.info(f"az_cli : Waiting for {generated_init_path} to be created...")
298+
await asyncio.sleep(1)
299+
waited += 1
300+
if not os.path.exists(generated_init_path):
301+
await ctx.info(f"az_cli : Timed out waiting for {generated_init_path}")
302+
return f"Code generation failed: {generated_init_path} was not created."
297303
await run_command(ctx, cmd2, "Generate CLI from Swagger tag", 80, 100)
298304
except Exception as e:
299305
await ctx.info(f"az_cli : Code generation failed: {str(e)}")
300306
return f"Code generation failed: {str(e)}"
301307

302-
return "Azure CLI code generation completed successfully!"
303-
304-
async def generate_tests(ctx: "Context"):
305-
await ctx.info("Starting test generation workflow.")
306-
307-
module_name = getattr(ctx, "generated_module", None)
308-
if not module_name:
309-
response = await ctx.elicit("Enter the module/extension name to generate tests for:")
310-
if response.action != "accept" or not response.data:
311-
return "Test generation cancelled."
312-
module_name = response.data
313-
else:
314-
await ctx.info(f"Detected generated module: {module_name}")
315-
316-
aaz_path = Path(f"{paths['cli']}/src/azure-cli/azure/cli/command_modules/{module_name}/aaz")
317-
if not aaz_path.exists():
318-
return f"AAZ path not found for module '{module_name}'"
319-
320-
commands = []
321-
for file in aaz_path.rglob("*.py"):
322-
with open(file, "r", encoding="utf-8") as f:
323-
for line in f:
324-
if line.strip().startswith("def "):
325-
commands.append(line.strip().replace("def ", "").split("(")[0])
326-
327-
test_dir = Path(f"{paths['cli']}/src/azure-cli/azure/cli/command_modules/{module_name}/tests/latest")
328-
test_dir.mkdir(parents=True, exist_ok=True)
329-
test_file = test_dir / f"test_{module_name}.py"
330-
331-
with open(test_file, "w", encoding="utf-8") as f:
332-
f.write("import unittest\n")
333-
f.write("from azure.cli.testsdk import ScenarioTest\n\n")
334-
f.write(f"class {module_name.capitalize()}ScenarioTest(ScenarioTest):\n\n")
335-
for cmd in commands:
336-
f.write(f" def test_{cmd}(self):\n")
337-
f.write(f" self.cmd('az {module_name} {cmd} --resource-name test-resource')\n\n")
338-
339-
await ctx.info(f"Generated test file: {test_file}")
340-
return f"Test generation completed for module '{module_name}'."
308+
return "Azure CLI code generation completed successfully!"

tools/aaz-flow/main.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,10 @@
33
# Licensed under the MIT License. See License.txt in the project root for license information.
44
# --------------------------------------------------------------------------------------------
55

6-
import os
7-
from pathlib import Path
86
from fastmcp import FastMCP, Context
97
from models import AAZRequest
10-
from helpers import generate_tests, execute_commands, validate_paths, get_name, get_swagger_config
8+
from helpers import execute_commands, validate_paths, get_name, get_swagger_config
9+
from testgen import generate_tests
1110

1211
mcp = FastMCP("AAZ Flow")
1312

@@ -65,7 +64,7 @@ async def generate_code(ctx: Context):
6564

6665
await ctx.info("Automatically generating tests for the newly generated module...")
6766
try:
68-
test_result = await generate_tests(ctx)
67+
test_result = await generate_tests(ctx, paths)
6968
await ctx.info(f"Automatic test generation result: {test_result}")
7069
except Exception as e:
7170
await ctx.info(f"Automatic test generation failed: {str(e)}")

tools/aaz-flow/prompt_templates.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
6+
"""Prompt templates and static guidance for AAZ Flow tools."""
7+
8+
def get_testgen_static_instructions() -> str:
9+
return (
10+
"You are generating Azure CLI scenario tests for a new module.\n"
11+
"Follow the style used by azure-cli scenario tests. Keep tests idempotent and light.\n"
12+
"Generate tests that achieve at least 80%% coverage of methods and parameters covering primary commands for the target module.\n"
13+
"To understand the primary commands that need to be tested, read through and understand the target module's generated AAZ commands.\n"
14+
"Constraints: \n"
15+
"- Include necessary imports: azure.cli.testsdk imports and others only as required and seen in reference.\n"
16+
"- Use self.kwargs for dynamic values.\n"
17+
"- Use ResourceGroupPreparer if a resource group is implied.\n"
18+
"- Add minimal checks (e.g., self.check) where sensible.\n"
19+
"- Keep tests safe-by-default; avoid destructive operations unless clearly required.\n"
20+
"- Output only valid Python code for the test file, nothing else."
21+
)
22+
23+
REF_STYLE_LABEL = (
24+
"Read and reference the following test files (do not copy verbatim, just follow structure):\n"
25+
)

tools/aaz-flow/testgen.py

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
# --------------------------------------------------------------------------------------------
2+
# Copyright (c) Microsoft Corporation. All rights reserved.
3+
# Licensed under the MIT License. See License.txt in the project root for license information.
4+
# --------------------------------------------------------------------------------------------
5+
6+
from pathlib import Path
7+
import os
8+
import asyncio
9+
import random
10+
from fastmcp import Context
11+
from prompt_templates import get_testgen_static_instructions, REF_STYLE_LABEL
12+
13+
14+
async def check_module_status(ctx: Context, paths: dict):
15+
await ctx.info("Starting test generation workflow.")
16+
17+
module_name = getattr(ctx, "generated_module", None)
18+
if not module_name:
19+
response = await ctx.elicit("Enter the module/extension name to generate tests for:")
20+
if response.action != "accept" or not response.data:
21+
return "Test generation cancelled."
22+
module_name = response.data
23+
else:
24+
await ctx.info(f"Detected generated module: {module_name}")
25+
26+
ctx.generated_module = module_name
27+
return module_name
28+
29+
def find_test_dir(cli_extension_path: str, module_name: str) -> Path | None:
30+
base = Path(cli_extension_path) / "src" / module_name
31+
for path in base.rglob("tests/latest"):
32+
if path.is_dir():
33+
return path
34+
return None
35+
36+
async def check_path_status(ctx: Context, paths: dict):
37+
module_name = await check_module_status(ctx, paths)
38+
if not module_name or "cancelled" in str(module_name).lower():
39+
return str(module_name), [], None
40+
41+
aaz_path = Path(
42+
f"{paths['cli_extension']}/src/{module_name}")
43+
if not aaz_path.exists():
44+
return f"AAZ path not found for module '{module_name}'", [], None
45+
46+
commands = []
47+
for file in aaz_path.rglob("*.py"):
48+
with open(file, "r", encoding="utf-8") as f:
49+
for line in f:
50+
if line.strip().startswith("def "):
51+
commands.append(line.strip().replace(
52+
"def ", "").split("(")[0])
53+
54+
test_dir = find_test_dir(paths["cli_extension"], module_name)
55+
test_dir.mkdir(parents=True, exist_ok=True)
56+
test_file = test_dir / f"test_{module_name}.py"
57+
return module_name, commands, test_file
58+
59+
def build_testgen_prompt(module_name: str, commands: list[str], reference_snippet: str = "") -> str:
60+
parts = [
61+
get_testgen_static_instructions(),
62+
(
63+
f"Module name: '{module_name}'. Generate a single test class named "
64+
f"'{module_name.capitalize()}ScenarioTest' deriving from ScenarioTest."
65+
),
66+
"Discovered AAZ functions (potential commands):\n" +
67+
", ".join(commands[:30])
68+
]
69+
if reference_snippet:
70+
parts.append(REF_STYLE_LABEL + reference_snippet)
71+
return "\n\n".join(parts)
72+
73+
def strip_code_fences(text: str) -> str:
74+
lines = text.strip().splitlines()
75+
blocks = []
76+
inside = False
77+
current = []
78+
79+
for line in lines:
80+
if line.strip().startswith("```"):
81+
if inside:
82+
blocks.append("\n".join(current).strip())
83+
current = []
84+
inside = False
85+
else:
86+
inside = True
87+
elif inside:
88+
current.append(line)
89+
if current:
90+
blocks.append("\n".join(current).strip())
91+
92+
if blocks:
93+
return max(blocks, key=len)
94+
return text.strip()
95+
96+
97+
async def generate_tests(ctx: Context, paths: dict):
98+
module_name, commands, test_file = await check_path_status(ctx, paths)
99+
100+
if test_file is None and not commands:
101+
return str(module_name)
102+
103+
if not commands:
104+
return f"No commands found to generate tests for module '{module_name}'."
105+
106+
reference_snippet = "\n".join([
107+
"azure-cli/src/azure-cli/azure/cli/command_modules/resource/tests/latest/test_resource.py",
108+
"azure-cli/src/azure-cli/azure/cli/command_modules/keyvault/tests/latest/test_keyvault_commands.py"
109+
])
110+
111+
sampling_prompt = build_testgen_prompt(
112+
module_name, commands, reference_snippet)
113+
ctx.info("Constructed test generation prompt as follows:\n" + sampling_prompt)
114+
max_retries = int(os.getenv("TESTGEN_RETRIES", "5"))
115+
base_delay = float(os.getenv("TESTGEN_RETRY_BASE_DELAY", "2"))
116+
117+
attempt = 0
118+
content = ""
119+
last_err = None
120+
while attempt <= max_retries:
121+
try:
122+
if attempt > 0:
123+
await ctx.info(f"Retrying test generation (attempt {attempt}/{max_retries})...")
124+
sampled = await ctx.sample(sampling_prompt)
125+
raw_content = (getattr(sampled, "text", "") or "").strip()
126+
content = strip_code_fences(raw_content)
127+
if content:
128+
break
129+
last_err = RuntimeError("Empty content returned from provider")
130+
raise last_err
131+
except Exception as ex:
132+
last_err = ex
133+
message = str(ex).lower()
134+
retriable = any(k in message for k in [
135+
"rate limit", "overloaded", "timeout", "temporarily unavailable", "429"
136+
])
137+
if attempt >= max_retries or not retriable:
138+
break
139+
delay = base_delay * (2 ** attempt)
140+
delay = min(delay, 30)
141+
jitter = random.uniform(0.7, 1.3)
142+
wait_time = delay * jitter
143+
await ctx.info(f"Transient error encountered: {ex}. Waiting {wait_time:.1f}s before retry.")
144+
await asyncio.sleep(wait_time)
145+
attempt += 1
146+
continue
147+
148+
if not content:
149+
if last_err:
150+
return (
151+
f"Test generation failed after {max_retries} retries for module '{module_name}': "
152+
f"{last_err}"
153+
)
154+
return f"Test generation failed: no content generated for module '{module_name}'."
155+
156+
with open(test_file, "w", encoding="utf-8") as f:
157+
f.write(content)
158+
159+
await ctx.info(f"Generated test file: {test_file}")
160+
return f"Test generation completed for module '{module_name}'."

0 commit comments

Comments
 (0)