Skip to content

Commit 2e02bf7

Browse files
committed
added new features and comprehensive test generation
1 parent 8480b0b commit 2e02bf7

File tree

5 files changed

+213
-89
lines changed

5 files changed

+213
-89
lines changed

tools/aaz-flow/helpers.py

Lines changed: 100 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,13 @@
1515
paths = {
1616
"aaz": os.getenv("AAZ_PATH", "/workspaces/aaz"),
1717
"cli": os.getenv("CLI_PATH", "/workspaces/azure-cli"),
18-
"cli_extension": os.getenv("CLI_EXTENSION_PATH", "/workspaces/azure-cli-extensions"),
19-
"swagger_path": os.getenv("SWAGGER_PATH", "/workspaces/azure-rest-api-specs")
18+
"cli_extension": os.getenv(
19+
"CLI_EXTENSION_PATH", "/workspaces/azure-cli-extensions"
20+
),
21+
"swagger_path": os.getenv("SWAGGER_PATH", "/workspaces/azure-rest-api-specs"),
2022
}
2123

24+
2225
async def fetch_available_services():
2326
"""Retrieve available services by parsing local azure-rest-api-specs/specification directory"""
2427
spec_path = os.path.join(paths["swagger_path"], "specification")
@@ -27,101 +30,131 @@ async def fetch_available_services():
2730
return ["storage", "compute", "network", "keyvault", "monitor"]
2831

2932
try:
30-
directories = [d for d in os.listdir(spec_path) if os.path.isdir(os.path.join(spec_path, d))]
33+
directories = [
34+
d
35+
for d in os.listdir(spec_path)
36+
if os.path.isdir(os.path.join(spec_path, d))
37+
]
3138
directories.sort()
3239
return directories
3340
except Exception:
3441
return ["storage", "compute", "network", "keyvault", "monitor"]
3542

43+
3644
async def validate_paths(ctx: Context) -> dict:
3745
"""Validate and get correct paths for required directories."""
3846

3947
await ctx.info("az_cli : Validating local paths...")
4048
await ctx.report_progress(progress=5, total=100)
4149

42-
for i, (key, path) in enumerate(paths.items(), 1):
43-
progress = 5 + (i * 5)
44-
await ctx.report_progress(progress=progress, total=100)
50+
combined_check = await ctx.sample(
51+
"Ask the user to confirm if the detected paths for AAZ, Azure CLI, Azure CLI Extensions and Swagger specs are correct. The detected paths are as follows:\n"
52+
f"- AAZ path: `{paths['aaz']}`\n"
53+
f"- Azure CLI path: `{paths['cli']}`\n"
54+
f"- Azure CLI Extensions path: `{paths['cli_extension']}`\n"
55+
f"- Swagger specifications path: `{paths['swagger_path']}`\n"
56+
"If any path is incorrect, ask the user to answer with 'no'."
57+
)
4558

46-
display_name = key.replace('_', ' ')
47-
phrased_question = await ctx.sample(
48-
f"Ask the user to confirm the path for {display_name} directory: {path}. Use `` around the path when displaying it."
49-
)
50-
check_result = await ctx.elicit(
51-
message=phrased_question.text,
52-
response_type=Literal["yes", "no"]
53-
)
59+
check_result = await ctx.elicit(
60+
message=combined_check.text, response_type=Literal["yes", "no"]
61+
)
5462

55-
if check_result.action != "accept":
56-
return None
63+
if check_result.action != "accept":
64+
return None
65+
66+
if check_result.data != "yes":
67+
for i, (key, path) in enumerate(paths.items(), 1):
68+
progress = 5 + (i * 5)
69+
await ctx.report_progress(progress=progress, total=100)
5770

58-
if check_result.data == "no":
59-
elicit_question = await ctx.sample(
60-
f"Ask the user to provide the correct path for the {display_name} directory."
71+
display_name = key.replace("_", " ")
72+
phrased_question = await ctx.sample(
73+
f"Ask the user to confirm the path for {display_name} directory: {path}. Use `` around the path when displaying it."
6174
)
62-
new_path_result = await ctx.elicit(
63-
message=elicit_question.text,
64-
response_type=str
75+
check_result = await ctx.elicit(
76+
message=phrased_question.text, response_type=Literal["yes", "no"]
6577
)
66-
if new_path_result.action != "accept":
78+
79+
if check_result.action != "accept":
6780
return None
68-
paths[key] = new_path_result.data.strip('"')
69-
await ctx.info(f"az_cli : Updated {display_name} path to: {paths[key]}")
81+
82+
if check_result.data == "no":
83+
elicit_question = await ctx.sample(
84+
f"Ask the user to provide the correct path for the {display_name} directory."
85+
)
86+
new_path_result = await ctx.elicit(
87+
message=elicit_question.text, response_type=str
88+
)
89+
if new_path_result.action != "accept":
90+
return None
91+
paths[key] = new_path_result.data.strip('"')
92+
await ctx.info(f"az_cli : Updated {display_name} path to: {paths[key]}")
7093

7194
await ctx.info("az_cli : Verifying path existence...")
7295
await ctx.report_progress(progress=30, total=100)
7396

7497
for key, path in paths.items():
7598
if not os.path.exists(path):
76-
raise FileNotFoundError(f"{key.replace('_', ' ')} path does not exist: {path}")
99+
raise FileNotFoundError(
100+
f"{key.replace('_', ' ')} path does not exist: {path}"
101+
)
77102

78103
await ctx.info("az_cli : Path validation completed.")
79104
await ctx.report_progress(progress=35, total=100)
80105
return paths
81106

107+
82108
async def get_name(ctx: Context) -> str:
83109
"""Get the extension or module name from user."""
84110
await ctx.info("az_cli : Fetching available services...")
85111
common_extensions = await fetch_available_services()
86112
await ctx.report_progress(progress=40, total=100)
87113

88114
choice_prompt = await ctx.sample(
89-
"When the user clicks on the Respond button, the user will receive a list of Azure CLI modules and extensions to choose from." \
90-
"This list is fetched directly from the Azure REST API Specs repository. " \
91-
"Ask the user in a professional manner to select a module/extension from the list. " \
92-
"The list is provided when they click on the Respond button so do not give them any options in the questions itself. " \
115+
"When the user clicks on the Respond button, the user will receive a list of Azure CLI modules and extensions to choose from."
116+
"This list is fetched directly from the Azure REST API Specs repository. "
117+
"Ask the user in a professional manner to select a module/extension from the list. "
118+
"The list is provided when they click on the Respond button so do not give them any options in the questions itself. "
93119
"The result of this option selection will determine which module's code will be generated for Azure CLI."
94120
)
95121
extension_choice = await ctx.elicit(
96-
message=choice_prompt.text,
97-
response_type=Literal[tuple(common_extensions)]
122+
message=choice_prompt.text, response_type=Literal[tuple(common_extensions)]
98123
)
99124

100125
if extension_choice.action != "accept":
101126
return None
102127

103128
if extension_choice.data == "other":
104129
custom_extension = await ctx.elicit(
105-
"Enter custom extension/module name:",
106-
response_type=str
130+
"Enter custom extension/module name:", response_type=str
107131
)
108132
if custom_extension.action != "accept":
109133
return None
110134
return custom_extension.data
111135

112136
return extension_choice.data
113137

114-
async def get_swagger_config(ctx: Context, paths: dict, service_name: str = None) -> dict:
138+
139+
async def get_swagger_config(
140+
ctx: Context, paths: dict, service_name: str = None
141+
) -> dict:
115142
"""Get Swagger configuration details from user."""
116143
await ctx.info("az_cli : Browsing Swagger specifications...")
117144
await ctx.report_progress(progress=60, total=100)
118145

119-
spec_result = await browse_specs(ctx, os.path.join(paths["swagger_path"], "specification", service_name, "resource-manager"))
146+
spec_result = await browse_specs(
147+
ctx,
148+
os.path.join(
149+
paths["swagger_path"], "specification", service_name, "resource-manager"
150+
),
151+
)
120152
if not spec_result:
121153
return None
122154
else:
123155
return spec_result
124156

157+
125158
async def browse_specs(ctx: Context, base_path: str):
126159
"""Interactive browser for Swagger specifications with clean labels and correct metadata extraction."""
127160
await ctx.info("az_cli : Starting specification browser...")
@@ -140,14 +173,19 @@ async def browse_specs(ctx: Context, base_path: str):
140173
return None
141174

142175
dirs = [e for e in entries if os.path.isdir(os.path.join(current_path, e))]
143-
files = [e for e in entries if os.path.isfile(os.path.join(current_path, e)) and e.endswith((".json", ".yaml", ".yml"))]
176+
files = [
177+
e
178+
for e in entries
179+
if os.path.isfile(os.path.join(current_path, e))
180+
and e.endswith((".json", ".yaml", ".yml"))
181+
]
144182

145183
labels = [".."] + [f"> {d}" for d in dirs] + files
146184
mapping = dict(zip(labels, [".."] + dirs + files))
147185

148186
choice = await ctx.elicit(
149187
message="Click on the respond button to browse through the sub-folders of the chosen service and select the appropriate spec file.",
150-
response_type=Literal[tuple(labels)]
188+
response_type=Literal[tuple(labels)],
151189
)
152190

153191
if choice.action != "accept":
@@ -177,20 +215,21 @@ async def browse_specs(ctx: Context, base_path: str):
177215
"file": os.path.dirname(base_path),
178216
"resource_provider": resource_provider,
179217
"release": release,
180-
"swagger_tag": swagger_tag
218+
"swagger_tag": swagger_tag,
181219
}
182220

183221
await ctx.info(
184222
f"az_cli : Extracted: Resource Provider={resource_provider}, Release={release}, Tag={swagger_tag}"
185223
)
186224
return result
187225

188-
async def run_command(ctx: Context, command: str, step_name: str, progress_start: int, progress_end: int):
226+
227+
async def run_command(
228+
ctx: Context, command: str, step_name: str, progress_start: int, progress_end: int
229+
):
189230
await ctx.info(f"az_cli : Starting: {step_name}")
190231
process = await asyncio.create_subprocess_shell(
191-
command,
192-
stdout=asyncio.subprocess.PIPE,
193-
stderr=asyncio.subprocess.STDOUT
232+
command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT
194233
)
195234

196235
progress_range = progress_end - progress_start
@@ -206,7 +245,9 @@ async def run_command(ctx: Context, command: str, step_name: str, progress_start
206245
continue
207246
lines_count += 1
208247
await ctx.info(f"az_cli : {line.decode().rstrip()}")
209-
progress = progress_start + min(progress_range, int((lines_count / total_lines_estimate) * progress_range))
248+
progress = progress_start + min(
249+
progress_range, int((lines_count / total_lines_estimate) * progress_range)
250+
)
210251
await ctx.report_progress(progress, 100)
211252

212253
await process.wait()
@@ -245,25 +286,36 @@ def _resolve_aaz_dev_prefix() -> str:
245286
for py in _resolve_python_candidates():
246287
try:
247288
import subprocess
289+
248290
code = (
249291
"import importlib.util, sys; "
250292
"spec = importlib.util.find_spec('aaz_dev.__main__'); "
251293
"sys.exit(0 if spec else 1)"
252294
)
253-
res = subprocess.run([py, "-c", code], stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
295+
res = subprocess.run(
296+
[py, "-c", code],
297+
stdout=subprocess.PIPE,
298+
stderr=subprocess.PIPE,
299+
text=True,
300+
)
254301
if res.returncode == 0:
255302
return f"{py} -m aaz_dev"
256303
except Exception:
257304
pass
258305
for maybe in [
259306
"/workspaces/.venv/bin/aaz-dev",
260-
str(Path(os.environ.get("VIRTUAL_ENV", "")) / "bin" / "aaz-dev") if os.environ.get("VIRTUAL_ENV") else None,
261-
shutil.which("aaz-dev")
307+
(
308+
str(Path(os.environ.get("VIRTUAL_ENV", "")) / "bin" / "aaz-dev")
309+
if os.environ.get("VIRTUAL_ENV")
310+
else None
311+
),
312+
shutil.which("aaz-dev"),
262313
]:
263314
if maybe and os.path.exists(maybe):
264315
return maybe
265316
return "aaz-dev"
266317

318+
267319
async def execute_commands(ctx: Context, paths: dict, request: AAZRequest):
268320
aaz_dev = _resolve_aaz_dev_prefix()
269321
await ctx.info(f"az_cli : Using aaz-dev invocation: {aaz_dev}")
@@ -290,19 +342,9 @@ async def execute_commands(ctx: Context, paths: dict, request: AAZRequest):
290342

291343
try:
292344
await run_command(ctx, cmd1, "Generate command model from Swagger", 50, 80)
293-
generated_init_path = f"{paths['cli']}/src/azure-cli/azure/cli/command_modules/{request.name}/aaz/__init__.py"
294-
max_wait = 30
295-
waited = 0
296-
while not os.path.exists(generated_init_path) and waited < max_wait:
297-
await ctx.info(f"az_cli : Waiting for {generated_init_path} to be created...")
298-
await asyncio.sleep(1)
299-
waited += 1
300-
if not os.path.exists(generated_init_path):
301-
await ctx.info(f"az_cli : Timed out waiting for {generated_init_path}")
302-
return f"Code generation failed: {generated_init_path} was not created."
303345
await run_command(ctx, cmd2, "Generate CLI from Swagger tag", 80, 100)
304346
except Exception as e:
305347
await ctx.info(f"az_cli : Code generation failed: {str(e)}")
306348
return f"Code generation failed: {str(e)}"
307349

308-
return "Azure CLI code generation completed successfully!"
350+
return "Azure CLI code generation completed successfully!"

tools/aaz-flow/main.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -10,23 +10,25 @@
1010

1111
mcp = FastMCP("AAZ Flow")
1212

13+
1314
@mcp.tool(
1415
name="az_cli_generate_help",
15-
description="Explains how to correctly call the az_cli_generate tool."
16+
description="Explains how to correctly call the az_cli_generate tool.",
1617
)
1718
async def generate_help(ctx: Context):
1819
help_message = {
1920
"tool": "az_cli_generate",
2021
"description": "Generate Azure CLI commands from Swagger specs.",
2122
"parameters": {},
22-
"usage": "Call with no parameters, e.g. {}"
23+
"usage": "Call with no parameters, e.g. {}",
2324
}
2425
await ctx.info("az_cli_generate_help retrieved.")
2526
return help_message
2627

28+
2729
@mcp.tool(
2830
name="az_cli_generate",
29-
description="Generate Azure CLI commands from Swagger specs."
31+
description="Generate Azure CLI commands from Swagger specs.",
3032
)
3133
async def generate_code(ctx: Context):
3234
await ctx.info("Initiating Azure CLI code generation workflow.")
@@ -53,7 +55,7 @@ async def generate_code(ctx: Context):
5355
name=name,
5456
swagger_module_path=swagger_config["file"],
5557
resource_provider=swagger_config["resource_provider"],
56-
swagger_tag=swagger_config["swagger_tag"]
58+
swagger_tag=swagger_config["swagger_tag"],
5759
)
5860

5961
await execute_commands(ctx, paths, request)
@@ -69,7 +71,10 @@ async def generate_code(ctx: Context):
6971
except Exception as e:
7072
await ctx.info(f"Automatic test generation failed: {str(e)}")
7173

72-
return f"Code generation and test generation completed for extension/module '{name}'."
74+
return (
75+
f"Code generation and test generation completed for extension/module '{name}'."
76+
)
77+
7378

7479
if __name__ == "__main__":
75-
mcp.run(transport="stdio")
80+
mcp.run(transport="stdio")

tools/aaz-flow/models.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
from dataclasses import dataclass
77

8+
89
@dataclass
910
class AAZRequest:
1011
name: str

tools/aaz-flow/prompt_templates.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
"""Prompt templates and static guidance for AAZ Flow tools."""
77

8+
89
def get_testgen_static_instructions() -> str:
910
return (
1011
"You are generating Azure CLI scenario tests for a new module.\n"
@@ -17,9 +18,10 @@ def get_testgen_static_instructions() -> str:
1718
"- Use ResourceGroupPreparer if a resource group is implied.\n"
1819
"- Add minimal checks (e.g., self.check) where sensible.\n"
1920
"- Keep tests safe-by-default; avoid destructive operations unless clearly required.\n"
21+
"- Ensure tests can run in parallel without conflicts.\n"
22+
"- If tests are large and can be safely and logically split, create multiple test methods (i.e. avoid a single CRUD test if possible, split it into multiple tests if logically and safely separable).\n"
2023
"- Output only valid Python code for the test file, nothing else."
2124
)
2225

23-
REF_STYLE_LABEL = (
24-
"Read and reference the following test files (do not copy verbatim, just follow structure):\n"
25-
)
26+
27+
REF_STYLE_LABEL = "Read and reference the following test files (do not copy verbatim, just follow structure):\n"

0 commit comments

Comments
 (0)